note/api/SQLAlchemy.py
2026-04-23 20:42:16 +08:00

242 lines
8.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from sqlalchemy import create_engine, Column, Integer, String, DateTime, Boolean, Text, cast
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
from sqlalchemy.dialects.postgresql import ARRAY
from datetime import datetime
from config import settings
# Declarative base class shared by every ORM model in this module.
# Imported from sqlalchemy.orm: the sqlalchemy.ext.declarative location is
# deprecated since SQLAlchemy 1.4 and warns under 2.0. This module already
# needs >= 1.4 because sessions are used as context managers.
from sqlalchemy.orm import declarative_base as _orm_declarative_base  # noqa: E402

Base = _orm_declarative_base()
class Note(Base):
    """ORM mapping for the ``note`` table (one row per note)."""
    __tablename__ = "note"
    # Auto-incrementing integer primary key.
    id = Column(Integer, primary_key=True, autoincrement=True)
    # Note body; NOT NULL.
    content = Column(String, nullable=False)
    # Tag list stored as a PostgreSQL array; the column is nullable.
    tags = Column(ARRAY(String))
    # Naive local timestamps from datetime.now: set on insert, and updated_at is
    # also refreshed by onupdate whenever SQLAlchemy issues an UPDATE for the row.
    created_at = Column(DateTime, default=datetime.now)
    updated_at = Column(DateTime, default=datetime.now, onupdate=datetime.now)
    # Soft-delete flag: NoteManager marks rows deleted instead of removing them.
    is_deleted = Column(Boolean, default=False)
    # Owner of the note; NoteManager scopes its queries and ownership checks by it.
    name = Column(String)
class NoteManager:
    """Data-access helper for the ``note`` table.

    Each public method opens its own session via ``get_session`` and closes it
    before returning. Failures are printed and reported through the return
    value (``None`` / ``False`` / ``[]`` / ``0`` / ``set()``) instead of being
    raised.
    """

    # Placeholder tag ("uncategorized") stored for notes without real tags.
    DEFAULT_TAG = "未分类"
    # Display format applied to the timestamps returned by get_note_by_page.
    TIMESTAMP_FORMAT = "%Y-%m-%d %H:%M:%S"

    def __init__(self, database_url=settings.DATABASE_URL, echo=True):
        """Create the engine, ensure the tables exist and build a session factory.

        Args:
            database_url: SQLAlchemy database URL.
            echo: Log every generated SQL statement. Defaults to True, which
                matches the previous hard-coded setting.
        """
        self.engine = create_engine(
            database_url,
            echo=echo,
            pool_pre_ping=True,  # test pooled connections before handing them out
            pool_size=5,  # connections kept in the pool
            max_overflow=10,  # extra connections allowed beyond pool_size
        )
        # Create any missing tables; existing tables are left as they are.
        Base.metadata.create_all(self.engine)
        # Factory for the per-call sessions used by the methods below.
        self.Session = sessionmaker(bind=self.engine)

    def get_session(self):
        """Return a new session from the session factory."""
        return self.Session()

    # ------------------------------------------------------------------ helpers

    @classmethod
    def _normalize_tags(cls, tags):
        """Return a new tag list for storage; *tags* itself is never mutated.

        The placeholder tag is dropped whenever a real tag is present, and a
        missing or empty list becomes ``[DEFAULT_TAG]``. This way the stored
        value is never NULL or empty (NULL tags used to break ``get_all_tags``).
        """
        real_tags = [tag for tag in (tags or ()) if tag != cls.DEFAULT_TAG]
        return real_tags or [cls.DEFAULT_TAG]

    @staticmethod
    def _get_live_note(session, note_id):
        """Return the non-deleted note with *note_id*, or None if there is none."""
        return (
            session.query(Note)
            # `== False` builds a SQL expression; `is False` would not.
            .filter(Note.id == note_id, Note.is_deleted == False)  # noqa: E712
            .first()
        )

    @staticmethod
    def _soft_delete(note):
        """Mark *note* as deleted in the current session; the caller commits."""
        note.is_deleted = True
        note.updated_at = datetime.now()

    @staticmethod
    def _filtered_query(session, name, note_filter="", tag_filter=None):
        """Build the query for *name*'s non-deleted notes matching the filters.

        Args:
            session: Open session to query with.
            name: Owner whose notes are selected.
            note_filter: Substring that the note content must contain (optional).
            tag_filter: Tags that must all be present on the note (optional).
        """
        query = session.query(Note).filter(
            Note.name == name, Note.is_deleted == False  # noqa: E712
        )
        if note_filter:
            query = query.filter(Note.content.like(f"%{note_filter}%"))
        if tag_filter:
            # PostgreSQL array containment: the note's tags must include every
            # requested tag.
            query = query.filter(Note.tags.op("@>")(cast(tag_filter, ARRAY(Text))))
        return query

    # --------------------------------------------------------------- public API

    def create_note(self, name, content, tags=None):
        """Create a new note.

        Args:
            name: Owner of the note.
            content: Note body.
            tags: Tag list (optional). Normalized before storing: an empty or
                missing list becomes ``[DEFAULT_TAG]``, and the placeholder tag
                is dropped when real tags are present.

        Returns:
            int | None: ID of the new note, or None if the insert failed.
        """
        with self.get_session() as session:
            try:
                note = Note(content=content, tags=self._normalize_tags(tags), name=name)
                session.add(note)
                session.commit()
                return note.id
            except Exception as e:
                session.rollback()
                print(f"CREATE NOTE FAILED: {e}")
                return None

    def update_note(self, name, note_id, content=None, tags=None, created_at=None):
        """Update a note owned by *name*.

        Args:
            name: Owner; must match the stored ``Note.name``.
            note_id: ID of the note to update.
            content: New note body. ``None`` soft-deletes the note instead of
                updating it.
            tags: New tag list (optional), normalized as in ``create_note``.
            created_at: New creation timestamp. ``None`` sets it to now.

        Returns:
            bool: True on success (including the soft delete), False otherwise.
        """
        with self.get_session() as session:
            try:
                note = self._get_live_note(session, note_id)
                # Check existence before ownership so an unknown id reports
                # "not found" instead of raising AttributeError on None.
                if note is None:
                    print(f"NOTE ID NOT FOUND: {note_id} ")
                    return False
                if note.name != name:
                    print(f"UPDATE NOTE FAILED, NAME NOT MATCH: {note_id}")
                    return False
                if content is None:
                    # Existing contract: an update without content deletes the
                    # note. Done in this session rather than via delete_note.
                    self._soft_delete(note)
                    session.commit()
                    return True
                note.content = content
                note.tags = self._normalize_tags(tags)
                # NOTE(review): when created_at is omitted, the creation time is
                # reset to now on every edit. Confirm this is intended before
                # changing it to keep the stored value.
                note.created_at = datetime.now() if created_at is None else created_at
                note.updated_at = datetime.now()
                session.commit()
                return True
            except Exception as e:
                session.rollback()
                print(f"UPDATE NOTE FAILED: {e}")
                return False

    def delete_note(self, name, note_id):
        """Soft-delete a note owned by *name* by setting its ``is_deleted`` flag.

        Args:
            name: Owner; must match the stored ``Note.name``.
            note_id: ID of the note to delete.

        Returns:
            bool: True if the note was marked deleted, False otherwise.
        """
        with self.get_session() as session:
            try:
                note = self._get_live_note(session, note_id)
                if note is None:
                    print(f"NOTE ID NOT FOUND: {note_id} ")
                    return False
                if note.name != name:
                    print(f"DELETE NOTE FAILED, NAME NOT MATCH: {note_id}")
                    return False
                self._soft_delete(note)
                session.commit()
                return True
            except Exception as e:
                session.rollback()
                print(f"DELETE NOTE FAILED: {e}")
                return False

    def get_note_by_page(self, name, page=1, page_size=10, note_filter="", tag_filter=None):
        """Return one page of *name*'s notes, most recently updated first.

        Args:
            name: Owner whose notes are listed.
            page: 1-based page number; values below 1 are treated as 1.
            page_size: Number of notes per page.
            note_filter: Content substring filter (optional).
            tag_filter: Tags every returned note must carry (optional).

        Returns:
            list[Note]: Detached Note objects whose ``created_at`` and
            ``updated_at`` are formatted strings (None if unset). Returns an
            empty list on failure.
        """
        with self.get_session() as session:
            try:
                offset = max(page - 1, 0) * page_size
                notes = (
                    self._filtered_query(session, name, note_filter, tag_filter)
                    .order_by(Note.updated_at.desc())
                    .limit(page_size)
                    .offset(offset)
                    .all()
                )
                # Detach first, so that replacing the datetimes with display
                # strings can never be flushed back to the database.
                session.expunge_all()
                for note in notes:
                    note.created_at = (
                        note.created_at.strftime(self.TIMESTAMP_FORMAT)
                        if note.created_at is not None
                        else None
                    )
                    note.updated_at = (
                        note.updated_at.strftime(self.TIMESTAMP_FORMAT)
                        if note.updated_at is not None
                        else None
                    )
                return notes
            except Exception as e:
                print(f"GET NOTE BY PAGE FAILED: {e}")
                return []

    def get_total_note_num(self, name, note_filter="", tag_filter=None):
        """Count *name*'s non-deleted notes matching the optional filters.

        Args:
            name: Owner whose notes are counted.
            note_filter: Content substring filter (optional).
            tag_filter: Tags every counted note must carry (optional).

        Returns:
            int: Number of matching notes, or 0 on failure.
        """
        with self.get_session() as session:
            try:
                return self._filtered_query(session, name, note_filter, tag_filter).count()
            except Exception as e:
                print(f"GET TOTAL NOTE NUM FAILED: {e}")
                return 0

    def get_all_tags(self, name, note_filter="", tag_filter=None):
        """Collect the distinct tags used by *name*'s matching notes.

        Args:
            name: Owner whose notes are scanned.
            note_filter: Content substring filter (optional).
            tag_filter: Tags every scanned note must carry (optional).

        Returns:
            set[str]: Union of the notes' tags, or an empty set on failure.
        """
        with self.get_session() as session:
            try:
                # Only the tags column is needed, so skip loading full rows.
                rows = (
                    self._filtered_query(session, name, note_filter, tag_filter)
                    .with_entities(Note.tags)
                    .all()
                )
                tags = set()
                for (note_tags,) in rows:
                    # Legacy rows may have NULL tags; skip them instead of failing.
                    tags.update(note_tags or ())
                return tags
            except Exception as e:
                print(f"GET ALL TAGS FAILED: {e}")
                return set()