|
- #!/usr/bin/env python
- # -*- coding: utf-8 -*-
- from typing import Any, Dict, List, Optional, TypeVar, Tuple, Union, Type, Generic
- from fastapi.encoders import jsonable_encoder
- from pydantic import BaseModel
- from sqlalchemy import select, func, delete
- from sqlalchemy.ext.asyncio import AsyncSession
- from sqlalchemy.orm import Bundle, Session
- from sqlalchemy.exc import IntegrityError
- from sqlalchemy import asc, desc, text
- from db.base import Base
# Generic type parameters shared by the CRUD helpers in this module.
# ModelType: a SQLAlchemy model class, bound to ``db.base.Base``.
ModelType = TypeVar("ModelType", bound=Base)
# CreateSchemaType: pydantic schema accepted by the create/insert helpers.
CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel)
# UpdateSchemaType: pydantic schema accepted by the update helpers.
UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel)
class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
    """Synchronous CRUD helper bound to one SQLAlchemy model class.

    Every method receives an explicit ``Session``; the write methods commit
    immediately. The model is expected to expose an ``id`` attribute, which
    ``get`` filters on.
    """

    def __init__(self, model: Type[ModelType]):
        """Store *model*, the SQLAlchemy model class every query targets."""
        self.model = model

    def get(self, db: Session, id: Any) -> Optional[ModelType]:
        """Return the first row whose ``id`` equals *id*, or ``None``."""
        by_id = db.query(self.model).filter(self.model.id == id)
        return by_id.first()

    def get_multi(self,
                  db: Session,
                  *,
                  skip: int = 0,
                  limit: int = 100) -> List[ModelType]:
        """Return up to *limit* rows after skipping the first *skip*.

        No ORDER BY is applied, so which rows make up a "page" is left to the
        database.
        """
        page = db.query(self.model).offset(skip).limit(limit)
        return page.all()

    def create(self, db: Session, *, obj_in: CreateSchemaType) -> ModelType:
        """Insert a row built from *obj_in* and return it refreshed."""
        # jsonable_encoder flattens the schema into JSON-compatible values.
        values = jsonable_encoder(obj_in)
        new_row = self.model(**values)  # type: ignore
        return self._commit_and_refresh(db, new_row)

    def update(self, db: Session, *, db_obj: ModelType,
               obj_in: Union[UpdateSchemaType, Dict[str, Any]]) -> ModelType:
        """Copy values from *obj_in* onto *db_obj*, commit, and return it.

        Only names that appear as keys of ``jsonable_encoder(db_obj)`` are
        considered. A dict is applied as given; a pydantic schema contributes
        only the fields that were explicitly set (``exclude_unset=True``).
        """
        known_fields = jsonable_encoder(db_obj)
        changes = (obj_in if isinstance(obj_in, dict)
                   else obj_in.dict(exclude_unset=True))
        for name in known_fields:
            if name in changes:
                setattr(db_obj, name, changes[name])
        return self._commit_and_refresh(db, db_obj)

    def remove(self, db: Session, *, id: int) -> ModelType:
        """Delete the row with primary key *id*, commit, and return it.

        ``Query.get`` yields ``None`` for a missing id, in which case
        ``Session.delete`` raises.
        """
        doomed = db.query(self.model).get(id)
        db.delete(doomed)
        db.commit()
        return doomed

    def _commit_and_refresh(self, db: Session, obj: ModelType) -> ModelType:
        """Add *obj* to *db*, commit, reload it from the database, return it."""
        db.add(obj)
        db.commit()
        db.refresh(obj)
        return obj
class CrudManager(object):
    """Async CRUD helper bound to one SQLAlchemy model class.

    Every method takes an ``AsyncSession``. Most methods call ``db.close()``
    in a ``finally`` block, which releases the connection and detaches any
    loaded objects. ``execute``, ``increase`` and ``count`` do not close the
    session.
    """

    def __init__(self, model: Type[ModelType]):
        """
        :param model: SQLAlchemy model class the manager operates on. Several
            methods reference ``model.id``, so it must have an ``id`` column.
        """
        self.model = model

    def parse_filter(
            self,
            filters: Optional[Union[List[Any], Dict[str, Any]]]) -> List[Any]:
        """Normalise *filters* into a list of WHERE criteria.

        :param filters: a ``{attribute_name: value}`` mapping, turned into
            equality tests on the model's attributes; an already-built list
            of SQLAlchemy criteria, copied into a new list; or ``None`` for
            "no filter".
        :return: criteria usable as ``select(...).where(*criteria)``.
        :raises AttributeError: if a dict key is not an attribute of the model.
        """
        if filters is None:
            # find_all/count default ``filters`` to None; ``.where(*None)``
            # would raise TypeError, so None means "no criteria".
            return []
        if isinstance(filters, dict):
            return [getattr(self.model, k) == v for k, v in filters.items()]
        return list(filters)

    async def find_one(
            self,
            db: AsyncSession,
            filters: Union[List[Any], Dict[str, Any]],
            return_fields: Optional[List[str]] = None) -> Optional[ModelType]:
        """Return the first row matching *filters*, or ``None``.

        No ORDER BY is applied, so "first" is whatever the database returns
        first.

        :param return_fields: attribute names to select instead of the whole
            model. When given, the result is the row-like object produced by
            a ``Bundle`` named ``"data"``, not a model instance.
        """
        try:
            where_clauses = self.parse_filter(filters)
            if return_fields:
                fields = [getattr(self.model, k) for k in return_fields]
                stmt = select(Bundle("data", *fields))
            else:
                stmt = select(self.model)
            stmt = stmt.where(*where_clauses).limit(1)
            return (await db.execute(stmt)).scalar()
        finally:
            await db.close()

    async def fetch_all(self,
                        db: AsyncSession,
                        *,
                        filters: Union[List[Any], Dict[str, Any]],
                        order_by: Optional[List[Any]] = None,
                        return_fields: Optional[List[Any]] = None) -> List[Any]:
        """Return every row matching *filters*, without paging or a total.

        :param order_by: ORDER BY clauses passed straight to
            ``Select.order_by`` (e.g. ``[asc("id")]``). No default ordering
            is applied.
        :param return_fields: attribute names to select instead of whole
            model instances. With a single field the result is a flat list of
            values. With several fields it is a list of rows, one tuple-like
            row per record.
        """
        try:
            where_clauses = self.parse_filter(filters)
            if return_fields:
                fields = [getattr(self.model, k) for k in return_fields]
                stmt = select(*fields).where(*where_clauses)
            else:
                stmt = select(self.model).where(*where_clauses)
            if order_by:
                stmt = stmt.order_by(*order_by)
            result = await db.execute(stmt)
            if return_fields and len(return_fields) > 1:
                # ``.scalars()`` keeps only the first column of each row, which
                # would silently drop every requested field after the first.
                return result.all()
            return result.scalars().all()
        finally:
            await db.close()

    async def find_all(
            self,
            db: AsyncSession,
            *,
            filters: Optional[Union[List[Any], Dict[str, Any]]] = None,
            offset: Optional[int] = None,
            limit: Optional[int] = None,
            order_by: Optional[List[Any]] = None,
            return_fields: Optional[List[Any]] = None) -> Tuple[int, List[Any]]:
        """Return ``(total, items)`` for rows matching *filters*.

        :param filters: see ``parse_filter``; ``None`` matches every row.
        :param offset: rows to skip; ``None`` or ``0`` means no OFFSET.
        :param limit: maximum rows to return; ``None`` or ``0`` means no LIMIT.
        :param order_by: ORDER BY clauses; defaults to ``asc("id")``.
        :param return_fields: attribute names to select. Each item is then the
            row-like object of a ``Bundle`` named ``"data"``.
        :return: ``total`` counts all matching rows, ignoring offset/limit;
            ``items`` is the requested page.
        """
        try:
            where_clauses = self.parse_filter(filters)
            # Total number of matching rows, independent of paging.
            count_stmt = select(func.count(self.model.id)).where(*where_clauses)
            total = (await db.execute(count_stmt)).scalar()
            # Selected fields come back bundled as one row-like item per record.
            if return_fields:
                fields = [getattr(self.model, k) for k in return_fields]
                stmt = select(Bundle("data", *fields)).where(*where_clauses)
            else:
                stmt = select(self.model).where(*where_clauses)
            # Fall back to ascending id (presumably the primary key) so that
            # pages are stable across requests.
            if order_by:
                stmt = stmt.order_by(*order_by)
            else:
                stmt = stmt.order_by(asc("id"))
            # Paging: falsy offset/limit means "not applied".
            if offset:
                stmt = stmt.offset(offset)
            if limit:
                stmt = stmt.limit(limit)
            items = (await db.execute(stmt)).scalars().all()
        finally:
            await db.close()
        return total, items

    async def insert_one(self, db: AsyncSession,
                         obj_in: CreateSchemaType) -> Union[None, ModelType]:
        """Insert one row built from *obj_in* and return it refreshed.

        Fields of *obj_in* that are ``None`` are omitted, so any column
        defaults can apply to them. Any exception raised by commit/refresh is
        printed and ``None`` is returned. The session is closed once the
        commit has been attempted.
        """
        # NOTE(review): jsonable_encoder converts values such as datetimes to
        # JSON-friendly strings before they reach the model — confirm the
        # column types accept that.
        db_obj = self.model(
            **jsonable_encoder(obj_in, by_alias=False, exclude_none=True))
        db.add(db_obj)
        try:
            await db.commit()
            await db.refresh(db_obj)
        except Exception as ex:
            print(f"[InsertOne Error] {str(ex)}")
            return None
        else:
            return db_obj
        finally:
            await db.close()

    async def insert_many(self, db: AsyncSession,
                          obj_in: List[CreateSchemaType]) -> Union[None, int]:
        """Insert one row per schema in *obj_in* within a single commit.

        :return: the number of objects added, or ``None`` if the commit
            raised ``IntegrityError`` (which is printed). Other exceptions
            propagate. The session is closed once the commit is attempted.
        """
        # NOTE(review): unlike insert_one, None fields are not excluded here,
        # so they are set explicitly on the model — confirm this is intended.
        objs = [
            self.model(**jsonable_encoder(obj, by_alias=False))
            for obj in obj_in
        ]
        db.add_all(objs)
        try:
            await db.commit()
        except IntegrityError as ex:
            print(f"[InsertMany Error] {str(ex)}")
            return None
        else:
            return len(objs)
        finally:
            await db.close()

    async def update(
            self, db: AsyncSession, db_obj: ModelType,
            obj_in: Union[UpdateSchemaType, Dict[str, Any]]) -> ModelType:
        """Apply *obj_in* to *db_obj*, commit, refresh, and return it.

        A dict is applied verbatim, including ``None`` values. A pydantic
        schema contributes only its non-``None`` fields. ``db_obj`` is
        re-added to the session, which re-attaches an instance that an
        earlier ``close()`` detached.
        """
        if isinstance(obj_in, dict):
            update_data = obj_in
        else:
            update_data = obj_in.dict(exclude_none=True)
        # Copy the new values onto the instance.
        for field, value in update_data.items():
            setattr(db_obj, field, value)
        # Persist the changes.
        try:
            db.add(db_obj)
            await db.commit()
            await db.refresh(db_obj)
        finally:
            await db.close()
        return db_obj

    async def delete(
            self,
            db: AsyncSession,
            *,
            obj_id: int = 0,
            where_clauses: Optional[Union[List[Any], Dict[str, Any]]] = None
    ) -> bool:
        """Delete rows matching ``id == obj_id`` and/or *where_clauses*.

        :param obj_id: id of the row to delete; ``0`` means "not given".
        :param where_clauses: extra criteria; see ``parse_filter``.
        :return: ``True`` once the DELETE is committed, even if no row
            matched. ``False`` if no criteria were supplied or the statement
            failed; the error is printed.
        """
        criteria = [self.model.id == obj_id] if obj_id else []
        if where_clauses:
            criteria.extend(self.parse_filter(where_clauses))
        try:
            if not criteria:
                # With no criteria the statement would be an unconditional
                # DELETE that empties the whole table; refuse instead.
                print("Delete Error: no obj_id or where_clauses given")
                return False
            stmt = delete(self.model).where(*criteria)
            await db.execute(stmt)
            await db.commit()
        except Exception as ex:
            print(f"Delete Error: {str(ex)}")
            return False
        finally:
            await db.close()
        return True

    async def execute(self, db: AsyncSession,
                      stmt) -> Union[ModelType, List[Any], bool]:
        """Execute *stmt* and commit.

        :param stmt: a raw SQL string, which is wrapped in ``text()``, or a
            SQLAlchemy executable. Raw SQL runs as written: never build it
            from untrusted input; use bind parameters instead.
        :return: for SELECT statements, the first-column values of every row
            (``.scalars().all()``); otherwise ``True``.
        """
        if isinstance(stmt, str):
            # Tolerate leading whitespace and lowercase keywords.
            is_select = stmt.lstrip().upper().startswith("SELECT")
            # Plain-string SQL is deprecated in SQLAlchemy 1.4 and rejected in
            # 2.0; text() is the explicit, supported form.
            stmt = text(stmt)
        else:
            is_select = bool(getattr(stmt, "is_select", False))
        if is_select:
            items = (await db.execute(stmt)).scalars().all()
        else:
            await db.execute(stmt)
            items = True
        await db.commit()
        return items

    async def increase(self, db: AsyncSession, stmt) -> bool:
        """Execute *stmt* and commit, reporting success as a boolean.

        Going by the name, *stmt* is presumably a counter-incrementing UPDATE
        — TODO confirm with callers. On failure the error is printed, the
        transaction is rolled back and ``False`` is returned.
        """
        try:
            await db.execute(stmt)
            await db.commit()
        except Exception as ex:
            print(f"[Increase Error] {str(ex)}")
            # A failed execute/commit leaves the session's transaction
            # unusable until it is rolled back.
            await db.rollback()
            return False
        return True

    async def count(
            self,
            db: AsyncSession,
            filters: Optional[Union[List[Any], Dict[str, Any]]] = None) -> int:
        """Return the number of rows matching *filters* (all rows if None)."""
        where_clauses = self.parse_filter(filters)
        stmt = select(func.count(self.model.id)).where(*where_clauses)
        return (await db.execute(stmt)).scalar()
|