from sqlalchemy import create_engine, Column, Integer, String, DateTime, Boolean, Text, cast, Numeric
from sqlalchemy.orm import declarative_base, sessionmaker
from sqlalchemy.dialects.postgresql import ARRAY
from datetime import datetime
from config import settings
from decimal import Decimal

# Declarative base class for the ORM models in this module.
# NOTE: imported from sqlalchemy.orm (the sqlalchemy.ext.declarative location is
# deprecated); `with Session()` below already requires SQLAlchemy >= 1.4.
Base = declarative_base()


class ItemTrack(Base):
    """ORM mapping for the ``itemtrack`` table."""

    __tablename__ = "itemtrack"

    id = Column(Integer, primary_key=True, autoincrement=True)
    name = Column(String, nullable=False)
    price = Column(Numeric(10, 2), default="0")
    # Callable default: evaluated per INSERT, not once at import.
    start_at = Column(DateTime, default=datetime.now)
    discontinued_at = Column(DateTime, nullable=True)
    discontinued_price = Column(Numeric(10, 2), default="0")
    notes = Column(Text, default="")
    # Soft-delete flag; get_all_items/get_item_by_id skip rows where it is True.
    is_deleted = Column(Boolean, default=False)


class ItemTrackManager:
    """Data-access helper for the ``itemtrack`` table.

    Every public method opens its own short-lived session. Failures are printed
    and signalled through the return value ([], None or False) instead of
    being raised.
    """

    def __init__(self, database_url=settings.DATABASE_URL):
        """Create the engine, ensure the table exists and build a session factory.

        Args:
            database_url: SQLAlchemy database URL. The default is read from
                ``settings.DATABASE_URL`` once, when this class is defined.
        """
        self.engine = create_engine(
            database_url,
            echo=True,  # log every emitted SQL statement; set False to silence
            pool_pre_ping=True,  # test pooled connections before reuse
            pool_size=5,  # connections kept open in the pool
            max_overflow=10,  # extra connections allowed beyond pool_size
        )
        # Create any missing tables; existing tables are left as they are.
        Base.metadata.create_all(self.engine)
        # Factory producing new Session objects bound to this engine.
        self.Session = sessionmaker(bind=self.engine)

    def get_session(self):
        """Return a new database session (caller is responsible for closing it)."""
        return self.Session()

    def get_all_items(self):
        """Return every item that is not soft-deleted, ordered by ``start_at`` ascending.

        Returns:
            list: ItemTrack objects (detached once the session closes; their
            loaded column attributes remain readable), or [] on failure.
        """
        with self.get_session() as session:
            try:
                query = session.query(ItemTrack).filter(ItemTrack.is_deleted == False)  # noqa: E712 (SQL expression)
                return query.order_by(ItemTrack.start_at.asc()).all()
            except Exception as e:
                print(f"GET ALL ITEMS FAILED: {e}")
                return []

    def get_item_by_id(self, item_id):
        """Look up a single non-deleted item by primary key.

        Args:
            item_id: the item's id.

        Returns:
            ItemTrack | None: the item, or None if it does not exist, is
            soft-deleted, or the query fails.
        """
        with self.get_session() as session:
            try:
                return (
                    session.query(ItemTrack)
                    .filter(ItemTrack.id == item_id, ItemTrack.is_deleted == False)  # noqa: E712 (SQL expression)
                    .first()
                )
            except Exception as e:
                print(f"GET ITEM BY ID FAILED: {e}")
                return None

    def create_item(
        self,
        name,
        price="0",
        start_at=None,
        discontinued_at=None,
        discontinued_price=Decimal("0"),
        notes="",
    ):
        """Insert a new item.

        Args:
            name: item name (required).
            price: initial price.
            start_at: start time; when None, the current local time at the
                moment of the call is used.
            discontinued_at: discontinuation time, or None.
            discontinued_price: price at discontinuation.
            notes: free-form notes.

        Returns:
            int | None: the new item's id, or None if the insert failed.
        """
        if start_at is None:
            # Computed per call: a `datetime.now()` default in the signature is
            # evaluated only once, giving every item the same stale timestamp.
            start_at = datetime.now()
        with self.get_session() as session:
            try:
                item = ItemTrack(
                    name=name,
                    price=price,
                    start_at=start_at,
                    discontinued_at=discontinued_at,
                    discontinued_price=discontinued_price,
                    notes=notes,
                )
                session.add(item)
                session.commit()
                # Still inside the session, so reading the expired id reloads it.
                return item.id
            except Exception as e:
                session.rollback()
                print(f"CREATE ITEM FAILED: {e}")
                return None

    def delete_item(self, item_id):
        """Soft-delete an item by setting ``is_deleted`` to True.

        Args:
            item_id: the item's id.

        Returns:
            bool: True if a row with that id was found and flagged, False if
            none exists or the update failed.
        """
        with self.get_session() as session:
            try:
                item = session.query(ItemTrack).filter(ItemTrack.id == item_id).first()
                if not item:
                    return False
                item.is_deleted = True
                session.commit()
                return True
            except Exception as e:
                session.rollback()
                print(f"DELETE ITEM FAILED: {e}")
                return False

    def update_item(
        self, id, name=None, price=None, start_at=None, discontinued_at=None, discontinued_price=None, notes=None
    ):
        """Update selected fields of an item.

        Any argument left as None is not changed. Falsy values such as a price
        of 0 or empty notes ("") ARE applied.

        Args:
            id: the item's id.
            name: new name (optional).
            price: new price (optional).
            start_at: new start time (optional).
            discontinued_at: new discontinuation time (optional).
            discontinued_price: new discontinuation price (optional).
            notes: new notes (optional).

        Returns:
            bool: True if the item was found and committed, False if it does
            not exist or the update failed.
        """
        updates = {
            "name": name,
            "price": price,
            "start_at": start_at,
            "discontinued_at": discontinued_at,
            "discontinued_price": discontinued_price,
            "notes": notes,
        }
        with self.get_session() as session:
            try:
                item = session.query(ItemTrack).filter(ItemTrack.id == id).first()
                if not item:
                    return False
                for field, value in updates.items():
                    # `is not None` (not truthiness) so 0 / Decimal("0") / "" can be set.
                    if value is not None:
                        setattr(item, field, value)
                session.commit()
                return True
            except Exception as e:
                session.rollback()
                print(f"UPDATE ITEM FAILED: {e}")
                return False