|
- #!/usr/bin/env python
- # -*- coding: utf-8 -*-
- from typing import Any, Dict, List, Optional, TypeVar, Tuple, Union, Type, Generic
- from fastapi.encoders import jsonable_encoder
- from pydantic import BaseModel
- from sqlalchemy import desc
- from sqlalchemy import select, func, delete
- from sqlalchemy.exc import IntegrityError
- from sqlalchemy.ext.asyncio import AsyncSession
- from sqlalchemy.orm import Bundle, Session
- from sqlalchemy.sql import Select, Update, Delete
- from db import BaseORMModel
# Generic type parameters shared by the CRUD helpers in this module:
#   ModelType         - an ORM model class deriving from BaseORMModel
#   CreateSchemaType  - the Pydantic schema used as input when creating rows
#   UpdateSchemaType  - the Pydantic schema used as input when updating rows
ModelType = TypeVar("ModelType", bound=BaseORMModel)
CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel)
UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel)
class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
    """Synchronous CRUD helper bound to one SQLAlchemy model class.

    Every method receives a synchronous ``Session``. The write methods
    (``create``, ``update``, ``remove``) commit their own changes.
    """

    def __init__(self, model: Type[ModelType]):
        """
        CRUD object with default methods to Create, Read, Update, Delete (CRUD).

        **Parameters**

        * `model`: A SQLAlchemy model class
        """
        self.model = model

    def get(self, db: Session, id: Any) -> Optional[ModelType]:
        """Return the first row whose ``id`` column equals *id*, or ``None``."""
        return db.query(self.model).filter(self.model.id == id).first()

    def get_multi(self, db: Session, *, skip: int = 0, limit: int = 100) -> List[ModelType]:
        """Return at most *limit* rows, skipping the first *skip* rows."""
        return db.query(self.model).offset(skip).limit(limit).all()

    def create(self, db: Session, *, obj_in: CreateSchemaType) -> ModelType:
        """Insert a row built from *obj_in*, commit, and return the refreshed object."""
        obj_in_data = jsonable_encoder(obj_in)
        db_obj = self.model(**obj_in_data)  # type: ignore
        db.add(db_obj)
        db.commit()
        db.refresh(db_obj)
        return db_obj

    def update(self, db: Session, *, db_obj: ModelType, obj_in: Union[UpdateSchemaType,
               Dict[str, Any]]) -> ModelType:
        """Copy values from *obj_in* onto *db_obj*, commit, and return it refreshed.

        *obj_in* may be a plain dict or a Pydantic model. For a Pydantic model
        only explicitly set fields are used (``exclude_unset=True``). A value is
        written only if its key is also a field of the JSON-encoded *db_obj*.
        """
        obj_data = jsonable_encoder(db_obj)
        if isinstance(obj_in, dict):
            update_data = obj_in
        else:
            update_data = obj_in.dict(exclude_unset=True)
        for field in obj_data:
            if field in update_data:
                setattr(db_obj, field, update_data[field])
        db.add(db_obj)
        db.commit()
        db.refresh(db_obj)
        return db_obj

    def remove(self, db: Session, *, id: int) -> Optional[ModelType]:
        """Delete the row with primary key *id* and return the deleted object.

        Returns ``None`` (and commits nothing) when no such row exists.
        Previously the missing object was passed to ``Session.delete``,
        which raises for ``None``.
        """
        obj = db.get(self.model, id)
        if obj is None:
            return None
        db.delete(obj)
        db.commit()
        return obj
class CrudManager(object):
    """Async CRUD helper bound to one SQLAlchemy model class.

    Every coroutine takes an ``AsyncSession`` and calls ``await db.close()``
    (in a ``finally`` block) before returning.
    """

    def __init__(self, model: Type[ModelType]):
        self.model = model

    def parse_filter(self, filters: Union[List[Any], Dict[str, Any]] = None) -> List[Any]:
        """Convert *filters* into a list of SQLAlchemy WHERE clauses.

        * ``None`` -> ``[]`` (no conditions).
        * ``dict`` -> one clause per item: ``Model.key == value``. A key ending
          in ``__in`` becomes ``Model.<key without the suffix>.in_(value)``.
        * anything else (e.g. a list of clause expressions) is returned as-is.
        """
        if filters is None:
            return []
        if not isinstance(filters, dict):
            return filters
        suffix = "__in"
        where_clauses = []
        for key, value in filters.items():
            if key.endswith(suffix):
                # Slice the suffix off. str.rstrip("__in") would strip ANY trailing
                # '_', 'i' or 'n' characters and turn e.g. "domain__in" into "doma".
                column = getattr(self.model, key[:-len(suffix)])
                where_clauses.append(column.in_(value))
            else:
                where_clauses.append(getattr(self.model, key) == value)
        return where_clauses

    async def find_one(self,
                       db: AsyncSession,
                       *,
                       filters: Union[List[Any], Dict[str, Any]],
                       return_fields: List[str] = None) -> Optional[ModelType]:
        """Return the first row matching *filters*, or ``None``.

        With *return_fields*, the selected columns are wrapped in a
        ``Bundle`` named ``"data"`` and that bundle is returned instead of
        the model instance.
        """
        where_clauses = self.parse_filter(filters)
        if return_fields:
            fields = [getattr(self.model, k) for k in return_fields]
            stmt = select(Bundle("data", *fields)).where(*where_clauses).limit(1)
        else:
            stmt = select(self.model).where(*where_clauses).limit(1)
        try:
            # Read the value before the session is closed.
            item = (await db.execute(stmt)).scalar()
        finally:
            await db.close()
        return item

    async def fetch_all(self,
                        db: AsyncSession,
                        *,
                        filters: Union[List[Any], Dict[str, Any]],
                        order_by: List[Tuple[str, int]] = None,
                        return_fields: List[Any] = None) -> List[ModelType]:
        """Return every row matching *filters*.

        * No *return_fields*: model instances.
        * Exactly one field: a list of that column's values.
        * Several fields: a list of result rows, one element per requested
          column.

        *order_by* entries are passed straight to ``Select.order_by``.
        """
        where_clauses = self.parse_filter(filters)
        try:
            if return_fields:
                fields = [getattr(self.model, k) for k in return_fields]
                stmt = select(*fields).where(*where_clauses)
            else:
                stmt = select(self.model).where(*where_clauses)
            if order_by:
                stmt = stmt.order_by(*order_by)
            result = await db.execute(stmt)
            if return_fields and len(return_fields) > 1:
                # .scalars() keeps only the first column; return whole rows
                # so the other requested columns are not silently dropped.
                items = result.all()
            else:
                items = result.scalars().all()
        finally:
            await db.close()
        return items

    async def find_all(self,
                       db: AsyncSession,
                       *,
                       filters: Union[List[Any], Dict[str, Any]] = None,
                       offset: int = None,
                       limit: int = None,
                       order_by: List[Tuple[str, int]] = None,
                       return_fields: List[Any] = None) -> Tuple[int, List[ModelType]]:
        """Query all matching rows with pagination; return ``(total, items)``.

        *total* counts every row matching *filters*, ignoring *offset* and
        *limit*. Without *order_by*, rows are sorted by ``id`` descending.
        With *return_fields*, each item is a ``Bundle`` named ``"data"``
        holding those columns. Falsy *offset* / *limit* values are not applied.
        """
        where_clauses = self.parse_filter(filters)
        # Total count of matching rows.
        count_stmt = select(func.count(self.model.id))
        # Specific fields are returned as a Bundle instead of model instances.
        if return_fields:
            fetch_stmt = select(Bundle("data", *[getattr(self.model, k) for k in return_fields]))
        else:
            fetch_stmt = select(self.model)
        if where_clauses:
            count_stmt = count_stmt.where(*where_clauses)
            fetch_stmt = fetch_stmt.where(*where_clauses)
        # Sorting: caller-supplied, otherwise id descending.
        if order_by:
            stmt = fetch_stmt.order_by(*order_by)
        else:
            stmt = fetch_stmt.order_by(desc("id"))
        # Pagination and row limit.
        if offset:
            stmt = stmt.offset(offset)
        if limit:
            stmt = stmt.limit(limit)
        # Both queries run inside try/finally so the session is closed even
        # if the count query fails.
        try:
            total = (await db.execute(count_stmt)).first()[0]
            items = (await db.execute(stmt)).scalars().all()
        finally:
            await db.close()
        return total, items

    async def insert_one(self, db: AsyncSession,
                         obj_in: CreateSchemaType) -> Union[None, ModelType]:
        """Insert a single row built from *obj_in* (``None`` fields omitted).

        Returns the refreshed object. If any exception occurs, it is printed
        and ``None`` is returned.
        """
        db_obj = self.model(**jsonable_encoder(obj_in, by_alias=False, exclude_none=True))
        db.add(db_obj)
        try:
            await db.commit()
            await db.refresh(db_obj)
        except Exception as ex:
            print(f"[InsertOne Error] {str(ex)}")
            return None
        else:
            return db_obj
        finally:
            await db.close()

    async def insert_many(self, db: AsyncSession,
                          obj_in: List[CreateSchemaType]) -> Union[None, int]:
        """Insert several rows in one commit.

        Returns the number of objects added. If an ``IntegrityError``
        occurs, it is printed and ``None`` is returned.
        """
        objs = [self.model(**jsonable_encoder(obj, by_alias=False)) for obj in obj_in]
        try:
            db.add_all(objs)
            await db.commit()
        except IntegrityError as ex:
            print(f"[InsertMany Error] {str(ex)}")
            return None
        else:
            return len(objs)
        finally:
            await db.close()

    async def update(self, db: AsyncSession, db_obj: ModelType,
                     obj_in: Union[UpdateSchemaType, Dict[str, Any]]) -> ModelType:
        """Set every key of *obj_in* on *db_obj*, commit, and return it refreshed.

        For a Pydantic *obj_in*, fields whose value is ``None`` are skipped
        (``exclude_none=True``).
        """
        if isinstance(obj_in, dict):
            update_data = obj_in
        else:
            update_data = obj_in.dict(exclude_none=True)
        # Apply the new field values.
        for field in update_data:
            setattr(db_obj, field, update_data[field])
        # Persist the changes.
        try:
            db.add(db_obj)
            await db.commit()
            await db.refresh(db_obj)
        finally:
            await db.close()
        return db_obj

    async def delete(self,
                     db: AsyncSession,
                     *,
                     obj_id: int = 0,
                     where_clauses: Union[List[Any], Dict[str, Any]] = None) -> bool:
        """Delete the rows selected by *obj_id* and/or *where_clauses*.

        Returns ``True`` on success and ``False`` on failure. It also returns
        ``False``, without touching the database, when neither a non-zero
        *obj_id* nor any clause is supplied, because that DELETE would
        remove every row in the table.
        """
        clauses = [self.model.id == obj_id] if obj_id else []
        if where_clauses:
            clauses.extend(self.parse_filter(where_clauses))
        try:
            if not clauses:
                print("Delete Error: refusing to delete without obj_id or where_clauses")
                return False
            stmt = delete(self.model).where(*clauses)
            await db.execute(stmt)
            await db.commit()
        except Exception as ex:
            print(f"Delete Error: {str(ex)}")
            return False
        finally:
            await db.close()
        return True

    async def execute(self, db: AsyncSession, stmt) -> Union[ModelType, List[Any], bool]:
        """Execute a single SQL statement.

        *stmt* may be a raw SQL string or an already-built SQLAlchemy
        statement. If its SQL text starts with ``SELECT`` (case-insensitive,
        leading whitespace ignored), the execution ``Result`` is returned.
        Otherwise the statement is committed and ``True`` is returned.
        """
        from sqlalchemy import text  # only this method needs it

        sql = stmt if isinstance(stmt, str) else str(stmt)
        is_select = sql.lstrip().upper().startswith("SELECT")
        # 2.0-style execution (used by AsyncSession) requires textual SQL to be
        # wrapped in text(). NOTE: text() treats ":name" as a bind parameter.
        statement = text(stmt) if isinstance(stmt, str) else stmt
        try:
            if is_select:
                items = await db.execute(statement)
            else:
                await db.execute(statement)
                await db.commit()
                items = True
        finally:
            await db.close()
        return items

    async def execute_v2(self, db: AsyncSession, stmt: Union[Select, Update, Delete]):
        """Execute a Core statement.

        For a ``Select``, all fetched rows are returned. For any other
        statement, the change is committed and ``True`` is returned.
        """
        try:
            if isinstance(stmt, Select):
                data = (await db.execute(stmt)).fetchall()
            else:
                await db.execute(stmt)
                await db.commit()
                data = True
        finally:
            await db.close()
        return data

    async def executemany(self, db: AsyncSession, stmts: List[str]) -> List[Union[ModelType, bool]]:
        """Run each statement through ``execute`` in order and collect the results."""
        data = []
        try:
            data.extend([await self.execute(db, stmt) for stmt in stmts])
        finally:
            await db.close()
        return data

    async def increase(self, db: AsyncSession, stmt) -> bool:
        """Execute and commit a counter-style update (similar to ``redis.incr()``).

        Returns ``False`` (after printing the error) if any exception occurs.
        """
        data = True
        try:
            await db.execute(stmt)
            await db.commit()
        except Exception as ex:
            print(f"[Increase Error] {str(ex)}")
            data = False
        finally:
            await db.close()
        return data

    async def count(self,
                    db: AsyncSession,
                    filters: Union[List[Any], Dict[str, Any]] = None) -> int:
        """Return the number of rows matching *filters* (all rows when none are given)."""
        total = 0
        try:
            where_clauses = self.parse_filter(filters)
            stmt = select(func.count()).select_from(self.model)
            if where_clauses:
                stmt = stmt.where(*where_clauses)
            total += (await db.execute(stmt)).first()[0]
        finally:
            await db.close()
        return total
|