base.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import logging
from typing import Any, Dict, List, Optional, TypeVar, Tuple, Union, Type, Generic

from fastapi.encoders import jsonable_encoder
from pydantic import BaseModel
from sqlalchemy import desc
from sqlalchemy import select, func, delete
from sqlalchemy import text
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Bundle, Session
from sqlalchemy.sql import Select, Update, Delete

from db import BaseORMModel
  13. ModelType = TypeVar("ModelType", bound=BaseORMModel)
  14. CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel)
  15. UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel)
  16. class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
  17. def __init__(self, model: Type[ModelType]):
  18. """
  19. CRUD object with default methods to Create, Read, Update, Delete (CRUD).
  20. **Parameters**
  21. * `model`: A SQLAlchemy model class
  22. * `schemas`: A Pydantic model (schemas) class
  23. """
  24. self.model = model
  25. def get(self, db: Session, id: Any) -> Optional[ModelType]:
  26. return db.query(self.model).filter(self.model.id == id).first()
  27. def get_multi(self, db: Session, *, skip: int = 0, limit: int = 100) -> List[ModelType]:
  28. return db.query(self.model).offset(skip).limit(limit).all()
  29. def create(self, db: Session, *, obj_in: CreateSchemaType) -> ModelType:
  30. obj_in_data = jsonable_encoder(obj_in)
  31. db_obj = self.model(**obj_in_data) # type: ignore
  32. db.add(db_obj)
  33. db.commit()
  34. db.refresh(db_obj)
  35. return db_obj
  36. def update(self, db: Session, *, db_obj: ModelType, obj_in: Union[UpdateSchemaType,
  37. Dict[str, Any]]) -> ModelType:
  38. obj_data = jsonable_encoder(db_obj)
  39. if isinstance(obj_in, dict):
  40. update_data = obj_in
  41. else:
  42. update_data = obj_in.dict(exclude_unset=True)
  43. for field in obj_data:
  44. if field in update_data:
  45. setattr(db_obj, field, update_data[field])
  46. db.add(db_obj)
  47. db.commit()
  48. db.refresh(db_obj)
  49. return db_obj
  50. def remove(self, db: Session, *, id: int) -> ModelType:
  51. obj = db.query(self.model).get(id)
  52. db.delete(obj)
  53. db.commit()
  54. return obj
  55. class CrudManager(object):
  56. def __init__(self, model: Type[ModelType]):
  57. self.model = model
  58. def parse_filter(self, filters: Union[List[Any], Dict[str, Any]] = None) -> List[Any]:
  59. if isinstance(filters, dict):
  60. #where_clauses = [getattr(self.model, k) == v for k, v in filters.items()]
  61. where_clauses = []
  62. print(999999999999999)
  63. for k,v in filters.items():
  64. if k.endswith("__in"):
  65. k = k.rstrip("__in")
  66. where_clauses.append(getattr(self.model,k).in_(v))
  67. print(where_clauses,2222222222222222222)
  68. else:
  69. where_clauses.append(getattr(self.model, k) == v)
  70. else:
  71. where_clauses = filters
  72. return where_clauses
  73. async def find_one(self,
  74. db: AsyncSession,
  75. *,
  76. filters: Union[List[Any], Dict[str, Any]],
  77. return_fields: List[str] = None) -> Optional[ModelType]:
  78. where_clauses = self.parse_filter(filters)
  79. if return_fields:
  80. fields = [getattr(self.model, k) for k in return_fields]
  81. stmt = select(Bundle("data", *fields)).where(*where_clauses).limit(1)
  82. else:
  83. stmt = select(self.model).where(*where_clauses).limit(1)
  84. try:
  85. item = await db.execute(stmt)
  86. finally:
  87. await db.close()
  88. return item.scalar()
  89. async def fetch_all(self,
  90. db: AsyncSession,
  91. *,
  92. filters: Union[List[Any], Dict[str, Any]],
  93. order_by: List[Tuple[str, int]] = None,
  94. return_fields: List[Any] = None) -> List[ModelType]:
  95. where_clauses = self.parse_filter(filters)
  96. try:
  97. if return_fields:
  98. fields = [getattr(self.model, k) for k in return_fields]
  99. # stmt = select(Bundle("data", *fields)).where(*where_clauses)
  100. stmt = select(*fields).where(*where_clauses)
  101. else:
  102. stmt = select(self.model).where(*where_clauses)
  103. if order_by:
  104. stmt = stmt.order_by(*order_by)
  105. items = (await db.execute(stmt)).scalars().all()
  106. finally:
  107. await db.close()
  108. return items
  109. async def find_all(self,
  110. db: AsyncSession,
  111. *,
  112. filters: Union[List[Any], Dict[str, Any]] = None,
  113. offset: int = None,
  114. limit: int = None,
  115. order_by: List[Tuple[str, int]] = None,
  116. return_fields: List[Any] = None) -> Tuple[int, List[ModelType]]:
  117. """
  118. 批量查询所有数据
  119. """
  120. where_clauses = self.parse_filter(filters)
  121. # 统计总数
  122. count_stmt = select(func.count(self.model.id))
  123. # 指定字段,以List[Tuple[str...]]形式返回
  124. if return_fields:
  125. fetch_stmt = select(Bundle("data", *[getattr(self.model, k) for k in return_fields]))
  126. else:
  127. fetch_stmt = select(self.model)
  128. if where_clauses:
  129. count_stmt = count_stmt.where(*where_clauses)
  130. fetch_stmt = fetch_stmt.where(*where_clauses)
  131. total = (await db.execute(count_stmt)).first()[0]
  132. # 是否有排序
  133. if order_by:
  134. stmt = fetch_stmt.order_by(*order_by)
  135. else:
  136. stmt = fetch_stmt.order_by(*[desc("id")])
  137. # 翻页和数量限制
  138. if offset:
  139. stmt = stmt.offset(offset)
  140. if limit:
  141. stmt = stmt.limit(limit)
  142. # 执行查询返回数据
  143. try:
  144. items = (await db.execute(stmt)).scalars().all()
  145. finally:
  146. await db.close()
  147. return total, items
  148. async def insert_one(self, db: AsyncSession,
  149. obj_in: CreateSchemaType) -> Union[None, ModelType]:
  150. """
  151. 插入单条数据
  152. """
  153. db_obj = self.model(**jsonable_encoder(obj_in, by_alias=False, exclude_none=True))
  154. db.add(db_obj)
  155. try:
  156. await db.commit()
  157. await db.refresh(db_obj)
  158. except Exception as ex:
  159. print(f"[InsertOne Error] {str(ex)}")
  160. return None
  161. else:
  162. return db_obj
  163. finally:
  164. await db.close()
  165. async def insert_many(self, db: AsyncSession,
  166. obj_in: List[CreateSchemaType]) -> Union[None, int]:
  167. """
  168. 批量插入多条数据
  169. """
  170. objs = [self.model(**jsonable_encoder(obj, by_alias=False)) for obj in obj_in]
  171. try:
  172. db.add_all(objs)
  173. await db.commit()
  174. except IntegrityError as ex:
  175. print(f"[InsertMany Error] {str(ex)}")
  176. return None
  177. else:
  178. return len(objs)
  179. finally:
  180. await db.close()
  181. async def update(self, db: AsyncSession, db_obj: ModelType,
  182. obj_in: Union[UpdateSchemaType, Dict[str, Any]]) -> ModelType:
  183. """
  184. 更新数据
  185. """
  186. if isinstance(obj_in, dict):
  187. update_data = obj_in
  188. else:
  189. update_data = obj_in.dict(exclude_none=True)
  190. # 更新字段
  191. for field in update_data:
  192. setattr(db_obj, field, update_data[field])
  193. # 提交入库
  194. try:
  195. db.add(db_obj)
  196. await db.commit()
  197. await db.refresh(db_obj)
  198. finally:
  199. await db.close()
  200. return db_obj
  201. async def delete(self,
  202. db: AsyncSession,
  203. *,
  204. obj_id: int = 0,
  205. where_clauses: Union[List[Any], Dict[str, Any]] = None) -> bool:
  206. """
  207. 删除数据
  208. """
  209. # where clauses
  210. _where_clauses = [self.model.id == obj_id] if obj_id else []
  211. if where_clauses:
  212. where_clauses = self.parse_filter(where_clauses)
  213. _where_clauses.extend(where_clauses)
  214. try:
  215. stmt = delete(self.model).where(*_where_clauses)
  216. await db.execute(stmt)
  217. await db.commit()
  218. except Exception as ex:
  219. print(f"Delete Error: {str(ex)}")
  220. return False
  221. finally:
  222. await db.close()
  223. return True
  224. async def execute(self, db: AsyncSession, stmt) -> Union[ModelType, List[Any], bool]:
  225. """
  226. 执行单条sql语句
  227. """
  228. try:
  229. if stmt.startswith("SELECT"):
  230. items = await db.execute(stmt)
  231. else:
  232. await db.execute(stmt)
  233. await db.commit()
  234. items = True
  235. finally:
  236. await db.close()
  237. return items
  238. async def execute_v2(self, db: AsyncSession, stmt: Union[Select, Update, Delete]):
  239. try:
  240. if isinstance(stmt, Select):
  241. data = (await db.execute(stmt)).fetchall()
  242. else:
  243. await db.execute(stmt)
  244. await db.commit()
  245. data = True
  246. finally:
  247. await db.close()
  248. return data
  249. async def executemany(self, db: AsyncSession, stmts: List[str]) -> List[Union[ModelType, bool]]:
  250. """
  251. 批量执行sql语句
  252. """
  253. data = []
  254. try:
  255. data.extend([await self.execute(db, stmt) for stmt in stmts])
  256. finally:
  257. await db.close()
  258. return data
  259. async def increase(self, db: AsyncSession, stmt) -> bool:
  260. """
  261. 更新数量,类似于 redis.incr() 方法
  262. """
  263. data = True
  264. try:
  265. await db.execute(stmt)
  266. await db.commit()
  267. except Exception as ex:
  268. print(f"[Increase Error] {str(ex)}")
  269. data = False
  270. finally:
  271. await db.close()
  272. return data
  273. async def count(self,
  274. db: AsyncSession,
  275. filters: Union[List[Any], Dict[str, Any]] = None) -> int:
  276. """
  277. 统计数量
  278. """
  279. total = 0
  280. try:
  281. where_clauses = self.parse_filter(filters)
  282. stmt = select(func.count()).select_from(self.model)
  283. if where_clauses:
  284. stmt = stmt.where(*where_clauses)
  285. total += (await db.execute(stmt)).first()[0]
  286. finally:
  287. await db.close()
  288. return total