base.py 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
import logging
from typing import Any, Dict, Generic, List, Optional, Tuple, Type, TypeVar, Union

from fastapi.encoders import jsonable_encoder
from pydantic import BaseModel
from sqlalchemy import select, func, delete
from sqlalchemy import asc, desc, text
from sqlalchemy import inspect as sa_inspect
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Bundle, Session

from db.base import Base
  12. ModelType = TypeVar("ModelType", bound=Base)
  13. CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel)
  14. UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel)
  15. class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
  16. def __init__(self, model: Type[ModelType]):
  17. """
  18. CRUD object with default methods to Create, Read, Update, Delete (CRUD).
  19. **Parameters**
  20. * `model`: A SQLAlchemy model class
  21. * `schemas`: A Pydantic model (schemas) class
  22. """
  23. self.model = model
  24. def get(self, db: Session, id: Any) -> Optional[ModelType]:
  25. return db.query(self.model).filter(self.model.id == id).first()
  26. def get_multi(self,
  27. db: Session,
  28. *,
  29. skip: int = 0,
  30. limit: int = 100) -> List[ModelType]:
  31. return db.query(self.model).offset(skip).limit(limit).all()
  32. def create(self, db: Session, *, obj_in: CreateSchemaType) -> ModelType:
  33. obj_in_data = jsonable_encoder(obj_in)
  34. db_obj = self.model(**obj_in_data) # type: ignore
  35. db.add(db_obj)
  36. db.commit()
  37. db.refresh(db_obj)
  38. return db_obj
  39. def update(self, db: Session, *, db_obj: ModelType,
  40. obj_in: Union[UpdateSchemaType, Dict[str, Any]]) -> ModelType:
  41. obj_data = jsonable_encoder(db_obj)
  42. if isinstance(obj_in, dict):
  43. update_data = obj_in
  44. else:
  45. update_data = obj_in.dict(exclude_unset=True)
  46. for field in obj_data:
  47. if field in update_data:
  48. setattr(db_obj, field, update_data[field])
  49. db.add(db_obj)
  50. db.commit()
  51. db.refresh(db_obj)
  52. return db_obj
  53. def remove(self, db: Session, *, id: int) -> ModelType:
  54. obj = db.query(self.model).get(id)
  55. db.delete(obj)
  56. db.commit()
  57. return obj
  58. class CrudManager(object):
  59. def __init__(self, model: Type[ModelType]):
  60. self.model = model
  61. def parse_filter(self, filters: Union[List[Any], Dict[str,
  62. Any]]) -> List[Any]:
  63. if isinstance(filters, dict):
  64. where_clauses = [
  65. getattr(self.model, k) == v for k, v in filters.items()
  66. ]
  67. else:
  68. where_clauses = filters
  69. return where_clauses
  70. async def find_one(self,
  71. db: AsyncSession,
  72. filters: Union[List[Any], Dict[str, Any]],
  73. return_fields: List[str] = None) -> Optional[ModelType]:
  74. where_clauses = self.parse_filter(filters)
  75. if return_fields:
  76. fields = [getattr(self.model, k) for k in return_fields]
  77. stmt = select(Bundle("data",
  78. *fields)).where(*where_clauses).limit(1)
  79. else:
  80. stmt = select(self.model).where(*where_clauses).limit(1)
  81. try:
  82. item = await db.execute(stmt)
  83. finally:
  84. await db.close()
  85. return item.scalar()
  86. async def fetch_all(self,
  87. db: AsyncSession,
  88. *,
  89. filters: Union[List[Any], Dict[str, Any]],
  90. order_by: List[Tuple[str, int]] = None,
  91. return_fields: List[Any] = None) -> List[ModelType]:
  92. where_clauses = self.parse_filter(filters)
  93. try:
  94. if return_fields:
  95. fields = [getattr(self.model, k) for k in return_fields]
  96. #stmt = select(Bundle("data", *fields)).where(*where_clauses)
  97. stmt = select(*fields).where(*where_clauses)
  98. else:
  99. stmt = select(self.model).where(*where_clauses)
  100. if order_by:
  101. stmt = stmt.order_by(*order_by)
  102. items = (await db.execute(stmt)).scalars().all()
  103. finally:
  104. await db.close()
  105. return items
  106. async def find_all(
  107. self,
  108. db: AsyncSession,
  109. *,
  110. filters: Union[List[Any], Dict[str, Any]] = None,
  111. offset: int = None,
  112. limit: int = None,
  113. order_by: List[Tuple[str, int]] = None,
  114. return_fields: List[Any] = None) -> Tuple[int, List[ModelType]]:
  115. where_clauses = self.parse_filter(filters)
  116. # 统计总数
  117. stmt = select(func.count(self.model.id)).where(*where_clauses)
  118. total = (await db.execute(stmt)).first()[0]
  119. # 指定字段,以List[Tuple[str...]]形式返回
  120. if return_fields:
  121. fields = [getattr(self.model, k) for k in return_fields]
  122. stmt = select(Bundle("data", *fields)).where(*where_clauses)
  123. else:
  124. stmt = select(self.model).where(*where_clauses)
  125. # 是否有排序
  126. if order_by:
  127. stmt = stmt.order_by(*order_by)
  128. else:
  129. stmt = stmt.order_by(*[asc("id")])
  130. # 翻页和数量限制
  131. if offset:
  132. stmt = stmt.offset(offset)
  133. if limit:
  134. stmt = stmt.limit(limit)
  135. # 执行查询返回数据
  136. try:
  137. items = (await db.execute(stmt)).scalars().all()
  138. finally:
  139. await db.close()
  140. return total, items
  141. async def insert_one(self, db: AsyncSession,
  142. obj_in: CreateSchemaType) -> Union[None, ModelType]:
  143. db_obj = self.model(
  144. **jsonable_encoder(obj_in, by_alias=False, exclude_none=True))
  145. db.add(db_obj)
  146. try:
  147. await db.commit()
  148. await db.refresh(db_obj)
  149. except Exception as ex:
  150. print(f"[InsertOne Error] {str(ex)}")
  151. return None
  152. else:
  153. return db_obj
  154. finally:
  155. await db.close()
  156. async def insert_many(self, db: AsyncSession,
  157. obj_in: List[CreateSchemaType]) -> Union[None, int]:
  158. objs = [
  159. self.model(**jsonable_encoder(obj, by_alias=False))
  160. for obj in obj_in
  161. ]
  162. db.add_all(objs)
  163. try:
  164. await db.commit()
  165. except IntegrityError as ex:
  166. print(f"[InsertMany Error] {str(ex)}")
  167. return None
  168. else:
  169. return len(objs)
  170. finally:
  171. await db.close()
  172. async def update(
  173. self, db: AsyncSession, db_obj: ModelType,
  174. obj_in: Union[UpdateSchemaType, Dict[str, Any]]) -> ModelType:
  175. if isinstance(obj_in, dict):
  176. update_data = obj_in
  177. else:
  178. update_data = obj_in.dict(exclude_none=True)
  179. # 更新字段
  180. for field in update_data:
  181. setattr(db_obj, field, update_data[field])
  182. # 提交入库
  183. try:
  184. db.add(db_obj)
  185. await db.commit()
  186. await db.refresh(db_obj)
  187. finally:
  188. await db.close()
  189. return db_obj
  190. async def delete(
  191. self,
  192. db: AsyncSession,
  193. *,
  194. obj_id: int = 0,
  195. where_clauses: Union[List[Any], Dict[str, Any]] = None) -> bool:
  196. # where clauses
  197. _where_clauses = [self.model.id == obj_id] if obj_id else []
  198. if where_clauses:
  199. where_clauses = self.parse_filter(where_clauses)
  200. _where_clauses.extend(where_clauses)
  201. try:
  202. stmt = delete(self.model).where(*_where_clauses)
  203. await db.execute(stmt)
  204. await db.commit()
  205. except Exception as ex:
  206. print(f"Delete Error: {str(ex)}")
  207. return False
  208. finally:
  209. await db.close()
  210. return True
  211. async def execute(self, db: AsyncSession,
  212. stmt) -> Union[ModelType, List[Any], bool]:
  213. if stmt.startswith("SELECT"):
  214. items = (await db.execute(stmt)).scalars().all()
  215. else:
  216. await db.execute(stmt)
  217. items = True
  218. await db.commit()
  219. return items
  220. async def increase(self, db: AsyncSession, stmt) -> bool:
  221. try:
  222. await db.execute(stmt)
  223. await db.commit()
  224. except Exception as ex:
  225. print(f"[Increase Error] {str(ex)}")
  226. return False
  227. return True
  228. async def count(self,
  229. db: AsyncSession,
  230. filters: Union[List[Any], Dict[str, Any]] = None) -> int:
  231. where_clauses = self.parse_filter(filters)
  232. stmt = select(func.count(self.model.id)).where(*where_clauses)
  233. total = (await db.execute(stmt)).first()[0]
  234. return total