item-track/api/SQLAlchemy.py
2026-04-23 20:41:32 +08:00

185 lines
5.7 KiB
Python

from sqlalchemy import create_engine, Column, Integer, String, DateTime, Boolean, Text, cast, Numeric
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
from sqlalchemy.dialects.postgresql import ARRAY
from datetime import datetime
from config import settings
from decimal import Decimal
# Declarative base class: ORM models such as ItemTrack inherit from it, and
# its metadata is what ItemTrackManager passes to create_all().
Base = declarative_base()


class ItemTrack(Base):
    """ORM mapping for the ``itemtrack`` table (one row per tracked item)."""

    __tablename__ = "itemtrack"

    id = Column(Integer, primary_key=True, autoincrement=True)
    name = Column(String, nullable=False)
    # Numeric(10, 2): precision 10, scale 2. The default is a Decimal (not the
    # string "0") so the Python-side value has the same type Numeric columns
    # load as, and matches the Decimal defaults used by ItemTrackManager.
    price = Column(Numeric(10, 2), default=Decimal("0"))
    # The callable (not its result) is passed, so it is invoked for each INSERT;
    # datetime.now() yields a naive local timestamp.
    start_at = Column(DateTime, default=datetime.now)
    discontinued_at = Column(DateTime, nullable=True)
    discontinued_price = Column(Numeric(10, 2), default=Decimal("0"))
    notes = Column(Text, default="")
    # Soft-delete flag: set to True instead of removing the row.
    is_deleted = Column(Boolean, default=False)
class ItemTrackManager:
    """CRUD operations for the ``itemtrack`` table (ItemTrack model)."""

    def __init__(self, database_url=settings.DATABASE_URL, echo=True):
        """
        Create the engine, create missing tables and build a session factory.

        Args:
            database_url: Database connection URL. The default is read from
                ``settings.DATABASE_URL`` once, when this class is defined.
            echo: If True (the default, matching the previous hard-coded
                behaviour) the engine logs every SQL statement it emits.
                Pass False to silence this, e.g. in production.
        """
        self.engine = create_engine(
            database_url,
            echo=echo,
            pool_pre_ping=True,  # check a pooled connection is alive before using it
            pool_size=5,  # connections kept open in the pool
            max_overflow=10,  # extra connections allowed beyond pool_size (15 total)
        )
        # Create any tables that do not exist yet.
        Base.metadata.create_all(self.engine)
        # Factory for new Session objects bound to this engine.
        self.Session = sessionmaker(bind=self.engine)

    def get_session(self):
        """Return a new database session; the caller is responsible for closing it."""
        return self.Session()

    def get_all_items(self):
        """
        Return all items that are not soft-deleted, ordered by ``start_at`` ascending.

        Returns:
            list: ItemTrack objects, or an empty list if the query fails. The
            session is closed when this method returns, so the objects are
            detached; only their already-loaded column attributes are usable.
        """
        with self.get_session() as session:
            try:
                return (
                    session.query(ItemTrack)
                    # `== False` builds a SQL comparison; `not` would not work here.
                    .filter(ItemTrack.is_deleted == False)  # noqa: E712
                    .order_by(ItemTrack.start_at.asc())
                    .all()
                )
            except Exception as e:
                print(f"GET ALL ITEMS FAILED: {e}")
                return []

    def get_item_by_id(self, item_id):
        """
        Return the non-deleted item with the given ID.

        Args:
            item_id: Item primary key.

        Returns:
            ItemTrack | None: The (detached) item, or None if it does not exist,
            is soft-deleted, or the query fails.
        """
        with self.get_session() as session:
            try:
                return (
                    session.query(ItemTrack)
                    .filter(ItemTrack.id == item_id, ItemTrack.is_deleted == False)  # noqa: E712
                    .first()
                )
            except Exception as e:
                print(f"GET ITEM BY ID FAILED: {e}")
                return None

    def create_item(
        self,
        name,
        price="0",
        start_at=None,
        discontinued_at=None,
        discontinued_price=Decimal("0"),
        notes="",
    ):
        """
        Insert a new item.

        Args:
            name: Item name (required).
            price: Item price.
            start_at: Start time; None (the default) means "now", evaluated
                when this method is called.
            discontinued_at: Discontinuation time, or None.
            discontinued_price: Price at discontinuation.
            notes: Free-text notes.

        Returns:
            int | None: ID of the new item, or None if the insert failed
            (the transaction is rolled back in that case).
        """
        # Resolve "now" per call: a `datetime.now()` default in the signature is
        # evaluated only once, at class definition, and reused for every call.
        if start_at is None:
            start_at = datetime.now()

        with self.get_session() as session:
            try:
                item = ItemTrack(
                    name=name,
                    price=price,
                    start_at=start_at,
                    discontinued_at=discontinued_at,
                    discontinued_price=discontinued_price,
                    notes=notes,
                )
                session.add(item)
                session.commit()
                # Read the ID while the session is still open (commit expires attributes).
                return item.id
            except Exception as e:
                session.rollback()
                print(f"CREATE ITEM FAILED: {e}")
                return None

    def delete_item(self, item_id):
        """
        Soft-delete an item by setting ``is_deleted`` to True.

        Args:
            item_id: Item primary key.

        Returns:
            bool: True if the item exists and was marked deleted, False if it
            does not exist or the update failed. Already-deleted items are also
            matched, so deleting one again returns True.
        """
        with self.get_session() as session:
            try:
                item = session.query(ItemTrack).filter(ItemTrack.id == item_id).first()
                if item is None:
                    return False
                item.is_deleted = True
                session.commit()
                return True
            except Exception as e:
                session.rollback()
                print(f"DELETE ITEM FAILED: {e}")
                return False

    def update_item(
        self, id, name=None, price=None, start_at=None, discontinued_at=None, discontinued_price=None, notes=None
    ):
        """
        Update the given fields of an item; fields left as None are unchanged.

        Falsy-but-meaningful values (e.g. ``price=0``, ``notes=""``) ARE
        applied. Because None means "leave unchanged", this method cannot
        reset a field to NULL. Soft-deleted items are matched as well.

        Args:
            id: Item primary key.
            name: New name (optional).
            price: New price (optional).
            start_at: New start time (optional).
            discontinued_at: New discontinuation time (optional).
            discontinued_price: New discontinuation price (optional).
            notes: New notes (optional).

        Returns:
            bool: True if the item exists and the update was committed, False
            if it does not exist or the update failed (rolled back).
        """
        updates = {
            "name": name,
            "price": price,
            "start_at": start_at,
            "discontinued_at": discontinued_at,
            "discontinued_price": discontinued_price,
            "notes": notes,
        }
        with self.get_session() as session:
            try:
                item = session.query(ItemTrack).filter(ItemTrack.id == id).first()
                if item is None:
                    return False
                for field, value in updates.items():
                    # `is not None` rather than truthiness, so 0 / "" are not silently dropped.
                    if value is not None:
                        setattr(item, field, value)
                session.commit()
                return True
            except Exception as e:
                session.rollback()
                print(f"UPDATE ITEM FAILED: {e}")
                return False