242 lines
8.2 KiB
Python
242 lines
8.2 KiB
Python
from sqlalchemy import create_engine, Column, Integer, String, DateTime, Boolean, Text, cast
|
||
from sqlalchemy.ext.declarative import declarative_base
|
||
from sqlalchemy.orm import sessionmaker
|
||
from sqlalchemy.dialects.postgresql import ARRAY
|
||
from datetime import datetime
|
||
from config import settings
|
||
|
||
|
||
# Declarative base class that every ORM model in this module inherits from.
Base = declarative_base()
|
||
|
||
|
||
class Note(Base):
    """ORM mapping for the ``note`` table.

    Rows are never physically removed: ``is_deleted`` marks them as deleted.
    ``name`` holds the owner the notes are filtered by.
    """

    __tablename__ = "note"

    # Integer surrogate key, auto-incremented by the database.
    id = Column(Integer, autoincrement=True, primary_key=True)
    # Body text of the note; NOT NULL.
    content = Column(String, nullable=False)
    # PostgreSQL array of tag strings.
    tags = Column(ARRAY(String))
    # Local naive timestamps (datetime.now); updated_at is also refreshed by
    # SQLAlchemy on every ORM-issued UPDATE.
    created_at = Column(DateTime, default=datetime.now)
    updated_at = Column(DateTime, onupdate=datetime.now, default=datetime.now)
    # Soft-delete flag.
    is_deleted = Column(Boolean, default=False)
    # Owner of the note.
    name = Column(String)
|
||
|
||
|
||
class NoteManager:
    """CRUD and query helper for the ``note`` table.

    Every public method opens its own short-lived session. Failures are
    reported with ``print`` and turned into a neutral return value
    (``None`` / ``False`` / ``0`` / empty collection) instead of being raised.
    """

    # Tag assigned to notes that have no tags (stored verbatim in the DB).
    DEFAULT_TAG = "未分类"

    def __init__(self, database_url=settings.DATABASE_URL):
        """Create the engine, ensure tables exist and build a session factory.

        Args:
            database_url: SQLAlchemy database URL.
        """
        self.engine = create_engine(
            database_url,
            echo=True,  # log every emitted SQL statement
            pool_pre_ping=True,  # test pooled connections before handing them out
            pool_size=5,  # connections kept open in the pool
            max_overflow=10,  # extra connections allowed beyond pool_size
        )

        # Create missing tables; existing tables are left untouched.
        Base.metadata.create_all(self.engine)

        # Factory for new sessions bound to this engine.
        self.Session = sessionmaker(bind=self.engine)

    def get_session(self):
        """Return a new database session."""
        return self.Session()

    # ------------------------------------------------------------------ #
    # Internal helpers
    # ------------------------------------------------------------------ #

    @classmethod
    def _normalize_tags(cls, tags):
        """Return a fresh tag list derived from *tags* (input is never mutated).

        Empty/None input yields ``[DEFAULT_TAG]``. When real tags are present,
        the default tag is dropped from the result.
        """
        if not tags:
            return [cls.DEFAULT_TAG]
        real_tags = [tag for tag in tags if tag != cls.DEFAULT_TAG]
        return real_tags or [cls.DEFAULT_TAG]

    @staticmethod
    def _active_note(session, note_id):
        """Return the non-deleted note with *note_id*, or None."""
        return (
            session.query(Note)
            .filter(Note.id == note_id, Note.is_deleted == False)  # noqa: E712 - SQL expression
            .first()
        )

    @staticmethod
    def _filtered_query(session, name, note_filter="", tag_filter=None):
        """Build the query shared by the listing methods.

        Args:
            session: open session to build the query on.
            name: owner whose non-deleted notes are selected.
            note_filter: substring the content must contain. It is matched
                literally: ``%`` and ``_`` are escaped, not wildcards.
            tag_filter: tags that must all be present on the note (array ``@>``).
        """
        query = session.query(Note).filter(
            Note.name == name, Note.is_deleted == False  # noqa: E712 - SQL expression
        )
        if note_filter:
            query = query.filter(Note.content.contains(note_filter, autoescape=True))
        if tag_filter:
            query = query.filter(Note.tags.op("@>")(cast(tag_filter, ARRAY(Text))))
        return query

    # ------------------------------------------------------------------ #
    # Public API
    # ------------------------------------------------------------------ #

    def create_note(self, name, content, tags=None):
        """Create a new note.

        Args:
            name: owner stored on the note.
            content: note body.
            tags: list of tags. Empty/None falls back to the default tag.

        Returns:
            int | None: id of the new note, or None if the insert failed.
        """
        with self.get_session() as session:
            try:
                note = Note(content=content, tags=self._normalize_tags(tags), name=name)
                session.add(note)
                session.commit()
                # Reading the id after commit reloads the expired instance
                # while the session is still open.
                return note.id
            except Exception as e:
                session.rollback()
                print(f"CREATE NOTE FAILED: {e}")
                return None

    def update_note(self, name, note_id, content=None, tags=None, created_at=None):
        """Update a note owned by *name*.

        Args:
            name: owner name; must match the stored note.
            note_id: id of the note to update.
            content: new content. ``None`` soft-deletes the note instead.
            tags: new tag list. Empty/None becomes the default tag, and the
                default tag is dropped when other tags are present. The
                caller's list is not modified.
            created_at: new creation time. ``None`` keeps the stored value.

        Returns:
            bool: True on success. False if the note is missing, belongs to
            another owner, or the update failed.
        """
        with self.get_session() as session:
            try:
                note = self._active_note(session, note_id)
                # Check existence before touching attributes of the result.
                if note is None:
                    print(f"NOTE ID NOT FOUND: {note_id} ")
                    return False
                if note.name != name:
                    print(f"UPDATE NOTE FAILED, NAME NOT MATCH: {note_id}")
                    return False

                if content is None:
                    # No content means "delete": soft-delete within this session.
                    note.is_deleted = True
                else:
                    note.content = content
                    note.tags = self._normalize_tags(tags)
                    if created_at is not None:
                        note.created_at = created_at

                note.updated_at = datetime.now()
                session.commit()
                return True
            except Exception as e:
                session.rollback()
                print(f"UPDATE NOTE FAILED: {e}")
                return False

    def delete_note(self, name, note_id):
        """Soft-delete a note owned by *name*.

        Args:
            name: owner name; must match the stored note.
            note_id: id of the note to delete.

        Returns:
            bool: True on success. False if the note is missing, belongs to
            another owner, or the update failed.
        """
        with self.get_session() as session:
            try:
                note = self._active_note(session, note_id)
                if note is None:
                    print(f"NOTE ID NOT FOUND: {note_id} ")
                    return False
                if note.name != name:
                    print(f"DELETE NOTE FAILED, NAME NOT MATCH: {note_id}")
                    return False

                # Soft delete: keep the row but hide it from every query.
                note.is_deleted = True
                note.updated_at = datetime.now()
                session.commit()
                return True
            except Exception as e:
                session.rollback()
                print(f"DELETE NOTE FAILED: {e}")
                return False

    def get_note_by_page(self, name, page=1, page_size=10, note_filter="", tag_filter=None):
        """Return one page of *name*'s notes, newest update first.

        Args:
            name: owner whose notes are listed.
            page: 1-based page number. Values below 1 are treated as 1.
            page_size: number of notes per page.
            note_filter: literal substring the content must contain (optional).
            tag_filter: tags that must all be present (optional).

        Returns:
            list: detached Note objects whose ``created_at``/``updated_at``
            have been replaced by ``"%Y-%m-%d %H:%M:%S"`` strings. Returns an
            empty list on failure.
        """
        with self.get_session() as session:
            try:
                # Clamp so a bad page number cannot produce a negative OFFSET.
                start_idx = (max(page, 1) - 1) * page_size
                notes = (
                    self._filtered_query(session, name, note_filter, tag_filter)
                    .order_by(Note.updated_at.desc())
                    .limit(page_size)
                    .offset(start_idx)
                    .all()
                )
                # Detach first so the string values written below can never be
                # flushed into the DateTime columns.
                session.expunge_all()
                for note in notes:
                    note.created_at = note.created_at.strftime("%Y-%m-%d %H:%M:%S")
                    note.updated_at = note.updated_at.strftime("%Y-%m-%d %H:%M:%S")
                return notes
            except Exception as e:
                print(f"GET NOTE BY PAGE FAILED: {e}")
                return []

    def get_total_note_num(self, name, note_filter="", tag_filter=None):
        """Count *name*'s non-deleted notes matching the filters.

        Args:
            name: owner whose notes are counted.
            note_filter: literal substring the content must contain (optional).
            tag_filter: tags that must all be present (optional).

        Returns:
            int: number of matching notes, or 0 on failure.
        """
        with self.get_session() as session:
            try:
                return self._filtered_query(session, name, note_filter, tag_filter).count()
            except Exception as e:
                print(f"GET TOTAL NOTE NUM FAILED: {e}")
                return 0

    def get_all_tags(self, name, note_filter="", tag_filter=None):
        """Collect the distinct tags used by *name*'s matching notes.

        Args:
            name: owner whose notes are scanned.
            note_filter: literal substring the content must contain (optional).
            tag_filter: tags that must all be present (optional).

        Returns:
            set: union of all tags, or an empty set on failure.
        """
        with self.get_session() as session:
            try:
                query = self._filtered_query(session, name, note_filter, tag_filter)
                tags = set()
                # Load only the tags column; the content is not needed here.
                for (note_tags,) in query.with_entities(Note.tags):
                    # Rows stored with NULL tags contribute nothing.
                    tags.update(note_tags or ())
                return tags
            except Exception as e:
                print(f"GET ALL TAGS FAILED: {e}")
                return set()
|