#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""Generic CRUD helpers for SQLAlchemy models.

``CRUDBase`` works on a synchronous ``Session``; ``CrudManager`` works on an
``AsyncSession``.
"""
from typing import Any, Dict, List, Optional, TypeVar, Tuple, Union, Type, Generic

from fastapi.encoders import jsonable_encoder
from pydantic import BaseModel
from sqlalchemy import select, func, delete
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Bundle, Session
from sqlalchemy.exc import IntegrityError
from sqlalchemy import asc, desc, text

from db.base import Base

ModelType = TypeVar("ModelType", bound=Base)
CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel)
UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel)


class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
    """Synchronous CRUD operations for a single SQLAlchemy model."""

    def __init__(self, model: Type[ModelType]):
        """
        CRUD object with default methods to Create, Read, Update, Delete (CRUD).

        **Parameters**

        * `model`: A SQLAlchemy model class
        """
        self.model = model

    def get(self, db: Session, id: Any) -> Optional[ModelType]:
        """Return the first row whose ``id`` column equals *id*, or ``None``."""
        return db.query(self.model).filter(self.model.id == id).first()

    def get_multi(
        self, db: Session, *, skip: int = 0, limit: int = 100
    ) -> List[ModelType]:
        """Return up to *limit* rows after skipping the first *skip* rows."""
        return db.query(self.model).offset(skip).limit(limit).all()

    def create(self, db: Session, *, obj_in: CreateSchemaType) -> ModelType:
        """Build a model instance from *obj_in*, commit it and return it refreshed."""
        obj_in_data = jsonable_encoder(obj_in)
        db_obj = self.model(**obj_in_data)  # type: ignore
        db.add(db_obj)
        db.commit()
        db.refresh(db_obj)
        return db_obj

    def update(
        self,
        db: Session,
        *,
        db_obj: ModelType,
        obj_in: Union[UpdateSchemaType, Dict[str, Any]],
    ) -> ModelType:
        """Copy fields from *obj_in* onto *db_obj*, commit and return it refreshed.

        Only keys that appear in the JSON-encoded form of *db_obj* are applied;
        any other key in *obj_in* is ignored. When *obj_in* is a schema object,
        only the fields that were explicitly set are considered.
        """
        obj_data = jsonable_encoder(db_obj)
        if isinstance(obj_in, dict):
            update_data = obj_in
        else:
            update_data = obj_in.dict(exclude_unset=True)
        for field in obj_data:
            if field in update_data:
                setattr(db_obj, field, update_data[field])
        db.add(db_obj)
        db.commit()
        db.refresh(db_obj)
        return db_obj

    def remove(self, db: Session, *, id: int) -> ModelType:
        """Delete the row with primary key *id*, commit, and return the object.

        NOTE(review): when no such row exists ``get`` returns ``None`` and
        ``db.delete(None)`` raises; callers should check existence first.
        """
        obj = db.query(self.model).get(id)
        db.delete(obj)
        db.commit()
        return obj


class CrudManager(object):
    """Asynchronous CRUD operations for a single SQLAlchemy model.

    Most methods close the ``AsyncSession`` they receive once they finish
    (``find_one``, ``fetch_all``, ``find_all``, ``insert_one``,
    ``insert_many``, ``update``, ``delete``); ``execute``, ``increase`` and
    ``count`` leave it open.
    """

    def __init__(self, model: Type[ModelType]):
        """Bind the manager to *model*, a SQLAlchemy model class."""
        self.model = model

    def parse_filter(
        self, filters: Optional[Union[List[Any], Dict[str, Any]]]
    ) -> List[Any]:
        """Normalise *filters* into a list of WHERE clause expressions.

        * ``None`` -> empty list (no filtering).
        * ``dict`` -> one ``column == value`` clause per item, the column being
          looked up by name on the model.
        * anything else -> returned unchanged; it is expected to already be a
          list of clause expressions.
        """
        if filters is None:
            # find_all() and count() default to filters=None; unpacking None in
            # ``.where(*clauses)`` would raise TypeError.
            return []
        if isinstance(filters, dict):
            return [getattr(self.model, k) == v for k, v in filters.items()]
        return filters

    async def find_one(
        self,
        db: AsyncSession,
        filters: Union[List[Any], Dict[str, Any]],
        return_fields: Optional[List[str]] = None,
    ) -> Optional[ModelType]:
        """Fetch a single match for *filters* (``LIMIT 1``) and close the session.

        Returns ``Result.scalar()`` of the query: the model instance, or the
        ``Bundle`` named ``"data"`` holding *return_fields* when those are
        given, or ``None`` when nothing matched.
        """
        where_clauses = self.parse_filter(filters)
        if return_fields:
            fields = [getattr(self.model, k) for k in return_fields]
            stmt = select(Bundle("data", *fields)).where(*where_clauses).limit(1)
        else:
            stmt = select(self.model).where(*where_clauses).limit(1)

        try:
            item = await db.execute(stmt)
        finally:
            await db.close()
        return item.scalar()

    async def fetch_all(
        self,
        db: AsyncSession,
        *,
        filters: Union[List[Any], Dict[str, Any]],
        order_by: Optional[List[Any]] = None,
        return_fields: Optional[List[Any]] = None,
    ) -> List[ModelType]:
        """Fetch every match for *filters* (no pagination) and close the session.

        *order_by* items are unpacked straight into ``stmt.order_by``.

        NOTE(review): with *return_fields* the columns are selected
        individually and the result goes through ``.scalars()``, so only the
        FIRST requested column of each row is returned. ``find_all`` wraps the
        fields in a ``Bundle`` instead; confirm which shape callers expect
        before aligning the two.
        """
        where_clauses = self.parse_filter(filters)
        try:
            if return_fields:
                fields = [getattr(self.model, k) for k in return_fields]
                stmt = select(*fields).where(*where_clauses)
            else:
                stmt = select(self.model).where(*where_clauses)
            if order_by:
                stmt = stmt.order_by(*order_by)
            items = (await db.execute(stmt)).scalars().all()
        finally:
            await db.close()
        return items

    async def find_all(
        self,
        db: AsyncSession,
        *,
        filters: Optional[Union[List[Any], Dict[str, Any]]] = None,
        offset: Optional[int] = None,
        limit: Optional[int] = None,
        order_by: Optional[List[Any]] = None,
        return_fields: Optional[List[Any]] = None,
    ) -> Tuple[int, List[ModelType]]:
        """Return ``(total, items)`` for a paginated query and close the session.

        *total* counts every row matching *filters*, ignoring *offset* and
        *limit*. Rows are ordered by *order_by*, or by ``id`` ascending when no
        ordering is given. A falsy *offset* / *limit* (``None`` or ``0``)
        means "not applied".
        """
        where_clauses = self.parse_filter(filters)
        try:
            # Total number of matching rows. Kept inside the try block so the
            # session is closed even if the count query fails.
            count_stmt = select(func.count(self.model.id)).where(*where_clauses)
            total = (await db.execute(count_stmt)).first()[0]

            # Restrict to the requested fields (returned as a Bundle per row).
            if return_fields:
                fields = [getattr(self.model, k) for k in return_fields]
                stmt = select(Bundle("data", *fields)).where(*where_clauses)
            else:
                stmt = select(self.model).where(*where_clauses)

            # Ordering: caller-supplied, else a stable default.
            if order_by:
                stmt = stmt.order_by(*order_by)
            else:
                stmt = stmt.order_by(*[asc("id")])

            # Pagination and row limit.
            if offset:
                stmt = stmt.offset(offset)
            if limit:
                stmt = stmt.limit(limit)

            # Run the query and collect the rows.
            items = (await db.execute(stmt)).scalars().all()
        finally:
            await db.close()

        return total, items

    async def insert_one(
        self, db: AsyncSession, obj_in: CreateSchemaType
    ) -> Union[None, ModelType]:
        """Insert one row built from *obj_in* and close the session.

        ``None``-valued fields of *obj_in* are dropped before the model is
        constructed. Returns the refreshed instance, or ``None`` (after
        printing the error) when the commit or refresh fails.
        """
        db_obj = self.model(
            **jsonable_encoder(obj_in, by_alias=False, exclude_none=True))
        db.add(db_obj)

        try:
            await db.commit()
            await db.refresh(db_obj)
        except Exception as ex:
            print(f"[InsertOne Error] {str(ex)}")
            return None
        else:
            return db_obj
        finally:
            await db.close()

    async def insert_many(
        self, db: AsyncSession, obj_in: List[CreateSchemaType]
    ) -> Union[None, int]:
        """Insert several rows in one commit and close the session.

        Returns the number of objects added, or ``None`` (after printing the
        error) on ``IntegrityError``. Other exceptions propagate.
        """
        objs = [
            self.model(**jsonable_encoder(obj, by_alias=False))
            for obj in obj_in
        ]
        db.add_all(objs)

        try:
            await db.commit()
        except IntegrityError as ex:
            print(f"[InsertMany Error] {str(ex)}")
            return None
        else:
            return len(objs)
        finally:
            await db.close()

    async def update(
        self,
        db: AsyncSession,
        db_obj: ModelType,
        obj_in: Union[UpdateSchemaType, Dict[str, Any]],
    ) -> ModelType:
        """Apply *obj_in* to *db_obj*, commit, refresh, and close the session.

        A dict is applied as-is; a schema object contributes only its
        non-``None`` fields. Every key is set on *db_obj* with ``setattr``.
        """
        if isinstance(obj_in, dict):
            update_data = obj_in
        else:
            update_data = obj_in.dict(exclude_none=True)

        # Copy the new values onto the instance.
        for field, value in update_data.items():
            setattr(db_obj, field, value)

        # Persist the change.
        try:
            db.add(db_obj)
            await db.commit()
            await db.refresh(db_obj)
        finally:
            await db.close()
        return db_obj

    async def delete(
        self,
        db: AsyncSession,
        *,
        obj_id: int = 0,
        where_clauses: Optional[Union[List[Any], Dict[str, Any]]] = None,
    ) -> bool:
        """Delete the rows matching *obj_id* and/or *where_clauses*; close the session.

        A truthy *obj_id* adds an ``id == obj_id`` condition; *where_clauses*
        (list of expressions or ``{column: value}`` dict) adds more. Returns
        ``True`` on success and ``False`` (after printing the error) on
        failure.

        When no condition at all is supplied the call is refused and ``False``
        is returned: an unfiltered ``DELETE`` would wipe the whole table, and
        the default ``obj_id=0`` makes that far too easy to trigger by mistake.
        """
        conditions = [self.model.id == obj_id] if obj_id else []
        if where_clauses:
            conditions.extend(self.parse_filter(where_clauses))

        try:
            if not conditions:
                print("Delete Error: refusing to delete without any condition")
                return False
            stmt = delete(self.model).where(*conditions)
            await db.execute(stmt)
            await db.commit()
        except Exception as ex:
            print(f"Delete Error: {str(ex)}")
            return False
        finally:
            await db.close()
        return True

    async def execute(
        self, db: AsyncSession, stmt
    ) -> Union[ModelType, List[Any], bool]:
        """Run an arbitrary statement, commit, and leave the session open.

        *stmt* may be a raw SQL string or a SQLAlchemy statement object. For a
        SELECT the scalar results (first column of each row) are returned as a
        list; for anything else ``True`` is returned.
        """
        if isinstance(stmt, str):
            # Tolerate leading whitespace and lower-case keywords, and wrap the
            # string in text(): raw strings must be declared as text() before
            # being handed to the session.
            is_select = stmt.lstrip().upper().startswith("SELECT")
            stmt = text(stmt)
        else:
            # Statement objects expose ``is_select``; a text() construct
            # reports False here, so it takes the non-SELECT path.
            is_select = bool(getattr(stmt, "is_select", False))

        if is_select:
            items = (await db.execute(stmt)).scalars().all()
        else:
            await db.execute(stmt)
            items = True

        await db.commit()
        return items

    async def increase(self, db: AsyncSession, stmt) -> bool:
        """Execute *stmt* and commit; the session is left open.

        Returns ``True`` on success. On failure the error is printed, the
        transaction is rolled back so the session stays usable, and ``False``
        is returned.
        """
        try:
            await db.execute(stmt)
            await db.commit()
        except Exception as ex:
            print(f"[Increase Error] {str(ex)}")
            await db.rollback()
            return False
        return True

    async def count(
        self,
        db: AsyncSession,
        filters: Optional[Union[List[Any], Dict[str, Any]]] = None,
    ) -> int:
        """Return the number of rows matching *filters* (all rows when ``None``).

        The session is left open.
        """
        where_clauses = self.parse_filter(filters)
        stmt = select(func.count(self.model.id)).where(*where_clauses)
        total = (await db.execute(stmt)).first()[0]
        return total