from datetime import datetime

from sqlalchemy import Boolean, Column, DateTime, Integer, String, Text, cast, create_engine
from sqlalchemy.dialects.postgresql import ARRAY
from sqlalchemy.orm import declarative_base, sessionmaker

from config import settings

# Declarative base class for ORM models.
Base = declarative_base()

# Placeholder tag stored when a note has no real tags ("Uncategorized").
_DEFAULT_TAG = "未分类"
# Display format used by get_note_by_page for timestamp columns.
_TS_FORMAT = "%Y-%m-%d %H:%M:%S"


def _normalize_tags(tags):
    """Return a NEW tag list with the placeholder tag applied consistently.

    Empty/None input (or input containing only the placeholder) becomes
    ``[_DEFAULT_TAG]``; otherwise every placeholder occurrence is dropped so
    it never sits next to real tags. The caller's list is never mutated.
    """
    real_tags = [tag for tag in (tags or ()) if tag != _DEFAULT_TAG]
    return real_tags if real_tags else [_DEFAULT_TAG]


def _format_ts(value):
    """Format a datetime for display, passing None through unchanged."""
    return value.strftime(_TS_FORMAT) if value is not None else value


class Note(Base):
    """ORM mapping for the ``note`` table."""

    __tablename__ = "note"

    id = Column(Integer, primary_key=True, autoincrement=True)
    content = Column(String, nullable=False)
    # ARRAY comes from the PostgreSQL dialect, so this model is PostgreSQL-only.
    tags = Column(ARRAY(String))
    # Naive local timestamps (datetime.now carries no tzinfo).
    created_at = Column(DateTime, default=datetime.now)
    updated_at = Column(DateTime, default=datetime.now, onupdate=datetime.now)
    # Soft-delete flag; NoteManager never physically removes rows.
    is_deleted = Column(Boolean, default=False)
    # Owner name; NoteManager scopes reads and checks writes against it.
    name = Column(String)


class NoteManager:
    """CRUD helper for the ``note`` table."""

    def __init__(self, database_url=settings.DATABASE_URL):
        """Create the engine, ensure the table exists and build a session factory.

        Args:
            database_url: SQLAlchemy database URL.
        """
        self.engine = create_engine(
            database_url,
            echo=True,  # log every generated SQL statement
            pool_pre_ping=True,  # test connections before handing them out
            pool_size=5,  # connections kept open in the pool
            max_overflow=10,  # extra connections allowed beyond pool_size
        )
        # Create tables that do not exist yet (does not alter existing ones).
        Base.metadata.create_all(self.engine)
        self.Session = sessionmaker(bind=self.engine)

    def get_session(self):
        """Return a new database session."""
        return self.Session()

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _active_note(session, note_id):
        """Return the non-deleted note with ``note_id``, or None."""
        return (
            session.query(Note)
            .filter(Note.id == note_id, Note.is_deleted == False)  # noqa: E712
            .first()
        )

    @staticmethod
    def _soft_delete(note):
        """Mark ``note`` as deleted and bump its update timestamp (no commit)."""
        note.is_deleted = True
        note.updated_at = datetime.now()

    @staticmethod
    def _filtered_query(session, name, note_filter="", tag_filter=None):
        """Build the owner/content/tag filtered query shared by the read methods.

        Args:
            session: Open session to query with.
            name: Owner name; only this owner's non-deleted notes match.
            note_filter: Substring matched against content via SQL LIKE.
            tag_filter: Tags that must ALL be present (PostgreSQL ``@>``).
        """
        query = session.query(Note).filter(Note.name == name, Note.is_deleted == False)  # noqa: E712
        if note_filter:
            query = query.filter(Note.content.like(f"%{note_filter}%"))
        if tag_filter:
            query = query.filter(Note.tags.op("@>")(cast(tag_filter, ARRAY(Text))))
        return query

    # ---------------------------------------------------------------- public API

    def create_note(self, name, content, tags=None):
        """Create a new note.

        Args:
            name: Owner name.
            content: Note content.
            tags: List of tags; empty/None stores the placeholder tag.

        Returns:
            int | None: Id of the new note, or None if the insert failed.
        """
        with self.get_session() as session:
            try:
                note = Note(content=content, tags=_normalize_tags(tags), name=name)
                session.add(note)
                session.commit()
                return note.id
            except Exception as e:
                session.rollback()
                print(f"CREATE NOTE FAILED: {e}")
                return None

    def update_note(self, name, note_id, content=None, tags=None, created_at=None):
        """Update a note owned by ``name``.

        Args:
            name: Owner name; the update is refused if it does not match.
            note_id: Id of the note.
            content: New content. None soft-deletes the note instead.
            tags: New tag list; empty/None stores the placeholder tag.
            created_at: Value for created_at; None sets it to the current time.

        Returns:
            bool: True on success, False otherwise.
        """
        with self.get_session() as session:
            try:
                note = self._active_note(session, note_id)
                # Existence must be checked before touching note attributes.
                if note is None:
                    print(f"NOTE ID NOT FOUND: {note_id} ")
                    return False
                if note.name != name:
                    print(f"UPDATE NOTE FAILED, NAME NOT MATCH: {note_id}")
                    return False

                if content is None:
                    # Missing content means the caller wants the note removed.
                    self._soft_delete(note)
                    session.commit()
                    return True

                note.content = content
                note.tags = _normalize_tags(tags)
                # NOTE(review): created_at is reset to "now" unless supplied;
                # kept from the original behaviour — confirm this is intended.
                note.created_at = created_at if created_at is not None else datetime.now()
                note.updated_at = datetime.now()
                session.commit()
                return True
            except Exception as e:
                session.rollback()
                print(f"UPDATE NOTE FAILED: {e}")
                return False

    def delete_note(self, name, note_id):
        """Soft-delete a note owned by ``name``.

        Args:
            name: Owner name; deletion is refused if it does not match.
            note_id: Id of the note.

        Returns:
            bool: True on success, False otherwise.
        """
        with self.get_session() as session:
            try:
                note = self._active_note(session, note_id)
                if note is None:
                    print(f"NOTE ID NOT FOUND: {note_id} ")
                    return False
                if note.name != name:
                    print(f"DELETE NOTE FAILED, NAME NOT MATCH: {note_id}")
                    return False
                self._soft_delete(note)
                session.commit()
                return True
            except Exception as e:
                session.rollback()
                print(f"DELETE NOTE FAILED: {e}")
                return False

    def get_note_by_page(self, name, page=1, page_size=10, note_filter="", tag_filter=None):
        """Return one page of notes, most recently updated first.

        Args:
            name: Owner name.
            page: 1-based page number (values below 1 are treated as 1).
            page_size: Number of notes per page.
            note_filter: Optional content substring filter.
            tag_filter: Optional list of tags that must all be present.

        Returns:
            list[Note]: Detached notes whose created_at/updated_at have been
            replaced by display strings; [] on error.
        """
        page = max(page, 1)  # avoid a negative OFFSET
        with self.get_session() as session:
            try:
                notes = (
                    self._filtered_query(session, name, note_filter, tag_filter)
                    .order_by(Note.updated_at.desc())
                    .limit(page_size)
                    .offset((page - 1) * page_size)
                    .all()
                )
                # Detach first so the string values below can never be flushed
                # back into the DateTime columns.
                session.expunge_all()
                for note in notes:
                    note.created_at = _format_ts(note.created_at)
                    note.updated_at = _format_ts(note.updated_at)
                return notes
            except Exception as e:
                print(f"GET NOTE BY PAGE FAILED: {e}")
                return []

    def get_total_note_num(self, name, note_filter="", tag_filter=None):
        """Count the notes matching the filters.

        Args:
            name: Owner name.
            note_filter: Optional content substring filter.
            tag_filter: Optional list of tags that must all be present.

        Returns:
            int: Number of matching notes; 0 on error.
        """
        with self.get_session() as session:
            try:
                return self._filtered_query(session, name, note_filter, tag_filter).count()
            except Exception as e:
                print(f"GET TOTAL NOTE NUM FAILED: {e}")
                return 0

    def get_all_tags(self, name, note_filter="", tag_filter=None):
        """Collect the distinct tags used by the matching notes.

        Args:
            name: Owner name.
            note_filter: Optional content substring filter.
            tag_filter: Optional list of tags that must all be present.

        Returns:
            set[str]: Tag set; empty on error.
        """
        with self.get_session() as session:
            try:
                query = self._filtered_query(session, name, note_filter, tag_filter)
                tags = set()
                # Only the tags column is needed; NULL arrays contribute nothing.
                for (note_tags,) in query.with_entities(Note.tags):
                    tags.update(note_tags or ())
                return tags
            except Exception as e:
                print(f"GET ALL TAGS FAILED: {e}")
                return set()