# recpaper.py
  1. # -*-coding:utf-8 -*-
  2. import os
  3. import cv2
  4. import numpy as np
  5. from imutils.perspective import four_point_transform
  6. class RecPaper:
  7. """识别出试卷上的类答题卡区域
  8. """
  9. def __init__(self, image_path="test.jpg", size="A4", resize=False):
  10. self.image_path = image_path
  11. self.image = cv2.imread(self.image_path)
  12. # 图片的宽高度
  13. self.width = self.image.shape[1]
  14. self.height = self.image.shape[0]
  15. # 图片转成灰度图
  16. self.gray = cv2.cvtColor(self.image, cv2.COLOR_BGR2GRAY)
  17. # 边缘检测后的图
  18. self.edged = cv2.Canny(self.gray, 75, 200)
  19. # 对灰度图进行二值化处理
  20. ret, thresh = cv2.threshold(self.gray, 0, 250, cv2.THRESH_BINARY_INV | cv2.THRESH_OTSU)
  21. self.thresh = thresh
  22. if resize:
  23. if size == "A4":
  24. self.thresh = cv2.resize(thresh, (2480, 3480), interpolation=cv2.INTER_LANCZOS4)
  25. else:
  26. self.thresh = cv2.resize(thresh, (4870, 3480), interpolation=cv2.INTER_LANCZOS4)
  27. # 输出二值化操作后的图像
  28. spt_lst = os.path.splitext(self.image_path)
  29. close_path = spt_lst[0] + '_thresh' + spt_lst[1]
  30. cv2.imwrite(close_path, self.thresh)
  31. def transform(self, point4):
  32. """透视变换
  33. """
  34. four = [[point4[0]["x"], point4[0]["y"]], [point4[1]["x"], point4[1]["y"]],
  35. [point4[3]["x"], point4[3]["y"]], [point4[2]["x"], point4[2]["y"]]]
  36. docCnt = np.array(four)
  37. warped = four_point_transform(self.image, docCnt.reshape(4, 2))
  38. warped = cv2.resize(warped, (2480, 3480), interpolation=cv2.INTER_LANCZOS4)
  39. spt_lst = os.path.splitext(self.image_path)
  40. trpath = spt_lst[0] + '_transform' + spt_lst[1]
  41. cv2.imwrite(trpath, warped)
  42. return trpath
  43. def _erode(self, w=None, h=None):
  44. '''
  45. 图像腐蚀操作
  46. '''
  47. # 腐蚀核大小
  48. edsize = cv2.getStructuringElement(cv2.MORPH_RECT, (w, h))
  49. self.erode_image = cv2.erode(self.thresh, edsize)
  50. # 输出腐蚀操作后的图像
  51. spt_lst = os.path.splitext(self.image_path)
  52. close_path = spt_lst[0] + '_erode' + spt_lst[1]
  53. cv2.imwrite(close_path, self.erode_image)
  54. return self.erode_image
  55. def _erode_dilate(self, w=None, h=None):
  56. '''
  57. 图像腐蚀操作
  58. '''
  59. # 腐蚀核大小
  60. edsize = cv2.getStructuringElement(cv2.MORPH_RECT, (w, h))
  61. self.erode_dilate_image = cv2.dilate(self.erode_image, edsize)
  62. # 输出腐蚀操作后的图像
  63. spt_lst = os.path.splitext(self.image_path)
  64. close_path = spt_lst[0] + '_erode_dilate' + spt_lst[1]
  65. cv2.imwrite(close_path, self.erode_dilate_image)
  66. return self.erode_dilate_image
  67. def _dilate(self, w=None, h=None):
  68. '''
  69. 图像膨胀操作
  70. '''
  71. dilsize = cv2.getStructuringElement(cv2.MORPH_RECT, (w, h))
  72. self.dilate_image = cv2.dilate(self.thresh, dilsize)
  73. # 输出膨胀操作后的图像
  74. spt_lst = os.path.splitext(self.image_path)
  75. close_path = spt_lst[0] + '_dilate' + spt_lst[1]
  76. cv2.imwrite(close_path, self.dilate_image)
  77. return self.dilate_image
  78. def _dilate_erode(self, w=None, h=None):
  79. '''
  80. 图像膨胀操作
  81. '''
  82. dilsize = cv2.getStructuringElement(cv2.MORPH_RECT, (w, h))
  83. self.dilate_erode_image = cv2.erode(self.dilate_image, dilsize)
  84. # 输出膨胀操作后的图像
  85. spt_lst = os.path.splitext(self.image_path)
  86. close_path = spt_lst[0] + '_dilate_erode' + spt_lst[1]
  87. cv2.imwrite(close_path, self.dilate_erode_image)
  88. return self.dilate_erode_image
  89. def _open(self, w=None, h=None):
  90. """开操作先腐蚀后膨胀
  91. """
  92. self._erode(w, h)
  93. # 自定义开操作
  94. self.open_image = self._erode_dilate(w, h)
  95. # 输出闭操作后的图像
  96. spt_lst = os.path.splitext(self.image_path)
  97. close_path = spt_lst[0] + '_open' + spt_lst[1]
  98. cv2.imwrite(close_path, self.open_image)
  99. return self.open_image
  100. def _close(self, w=20, h=40):
  101. """闭操作先膨胀后腐蚀
  102. """
  103. self._dilate()
  104. #自定义闭操作
  105. self.close_image = self._dilate_erode(w, h)
  106. # 输出闭操作后的图像
  107. spt_lst = os.path.splitext(self.image_path)
  108. close_path = spt_lst[0] + '_close' + spt_lst[1]
  109. cv2.imwrite(close_path, self.close_image)
  110. return self.close_image
  111. def get_cnts(self, std_w=None, std_h=None):
  112. """识别出答案轮廓
  113. """
  114. self._open(std_w, std_h)
  115. cnts = cv2.findContours(self.open_image, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[0]
  116. all_c = []
  117. if len(cnts) > 0:
  118. cnts = sorted(cnts, key=cv2.contourArea, reverse=True)
  119. all_c = []
  120. for c in cnts:
  121. peri = std_h / 2
  122. approx = cv2.approxPolyDP(c, peri, True)
  123. if len(approx) == 4 or True:
  124. (x, y, w, h) = cv2.boundingRect(c)
  125. # 标识出识别出的机型轮廓并输出
  126. cv2.drawContours(self.image, [c], -1, (255, 0, 255), 3)
  127. spt_lst = os.path.splitext(self.image_path)
  128. draw_path = spt_lst[0] + '_draw' + spt_lst[1]
  129. cv2.imwrite(draw_path, self.image)
  130. all_c.append({'x': x, 'y': y, 'w': w, 'h': h})
  131. return all_c
  132. def get_std_points(self, ow=10, oh=10, std_w=(10, 10), std_h=(10, 10)):
  133. """
  134. """
  135. std_points = []
  136. cnts = self.get_cnts(ow, oh)
  137. for c in cnts:
  138. if c["w"] > std_w[0] and c["w"] < std_w[1]:
  139. std_points.append(c)
  140. std_points = sorted(std_points, key=lambda x: x["x"] + x["y"])
  141. return std_points
  142. def get_ans_points(self, ow=20, oh=10):
  143. """识别考号和客观题答案
  144. """
  145. std_w = (ow,ow*2)
  146. std_h = (oh,oh*2)
  147. std_points = []
  148. cnts = self.get_cnts(ow, oh)
  149. for c in cnts:
  150. #if c["w"] > std_w[0] and c["w"] < std_w[1]:
  151. if c["w"] > std_w[0]:
  152. std_points.append(c)
  153. #std_points = sorted(std_points, key=lambda x: x["x"])
  154. ans_x_list = sorted(std_points, key=lambda x: x["x"])
  155. ans_y_list = sorted(std_points, key=lambda x: x["y"])
  156. return ans_x_list, ans_y_list
  157. def rec_kh(self,kh_y_list,std_x):
  158. """识别和计算考号
  159. """
  160. kh_list = []
  161. #计算考号区域的总长度
  162. kh_w_list = [x["w"] for x in kh_y_list]
  163. kh_w_list.sort()
  164. if kh_w_list[1:-2]:
  165. std_w = sum(kh_w_list[1:-2])/len(kh_w_list[1:-2])
  166. kh_length = std_w * 1.5 * 10
  167. for kh in kh_y_list:
  168. kh_x = kh["x"] + kh_length
  169. dis = kh_x-std_x
  170. kh = round(dis/(std_w*1.5))
  171. kh_list.append(str(kh))
  172. return "".join(kh_list)
  173. def get_kh_ans(self,ow,oh,std_x_list,std_y_list,qno_list):
  174. """
  175. """
  176. ans_x_list,ans_y_list = self.get_ans_points(ow,oh)
  177. std_x = std_x_list[0]["x"]
  178. std_w = std_x_list[0]["w"]
  179. #识别考号
  180. kh_x_list = filter(lambda x:x["x"]<std_x-std_w/2,ans_x_list)
  181. kh_y_list = sorted(kh_x_list,key=lambda x: x["y"])
  182. kh = self.rec_kh(kh_y_list,std_x-std_w)
  183. #识别客观题答案
  184. ans_dct = {}
  185. #遍历题号
  186. for i,qno in enumerate(qno_list):
  187. std_x = std_x_list[i]["x"]
  188. std_w = std_x_list[i]["w"]
  189. #遍历识别出的轮廓
  190. ans_list = []
  191. for ans in ans_x_list:
  192. #判断当前题目是否有填涂
  193. if ans["x"] > std_x - std_w and ans["x"] < std_x + std_w:
  194. #计算填图的答案是ABCD
  195. for j,stdy in enumerate(std_y_list):
  196. std_y = stdy["y"]
  197. std_h = stdy["h"]
  198. if ans["y"] > std_y - std_h and ans["y"] < std_y + std_h:
  199. std_choices = ["A","B","C","D","E","F","G","H"]
  200. ans_list.append(std_choices[j])
  201. ans_dct[qno] = ",".join(ans_list)
  202. return kh,ans_dct
  203. def main():
  204. stdans_config = {"std_x_list": [{"h": 19, "w": 38, "x": 968, "y": 318}, {"h": 20, "w": 38, "x": 1027, "y": 318}, {"h": 20, "w": 39, "x": 1086, "y": 318}, {"h": 20, "w": 38, "x": 1145, "y": 319}, {"h": 20, "w": 38, "x": 1204, "y": 319}, {"h": 20, "w": 38, "x": 1263, "y": 319}, {"h": 20, "w": 39, "x": 1322, "y": 319}, {"h": 20, "w": 37, "x": 1382, "y": 319}, {"h": 21, "w": 38, "x": 1440, "y": 319}, {"h": 20, "w": 38, "x": 1499, "y": 319}, {"h": 21, "w": 38, "x": 1558, "y": 319}, {"h": 21, "w": 38, "x": 1617, "y": 319}, {"h": 21, "w": 38, "x": 1676, "y": 319}, {"h": 20, "w": 38, "x": 1735, "y": 320}, {"h": 20, "w": 38, "x": 1794, "y": 320}], "std_y_list": [{"h": 20, "w": 37, "x": 1869, "y": 128}, {"h": 21, "w": 38, "x": 1868, "y": 177}, {"h": 21, "w": 38, "x": 1868, "y": 227}, {"h": 21, "w": 37, "x": 1868, "y": 277}]}
  205. qno_list = [1,2,3,5,6,9,10,13,14,15,16,17,18,19,20]
  206. std_x_list = stdans_config.get("std_x_list")
  207. std_y_list = stdans_config.get("std_y_list")
  208. img_dect = RecPaper("/tmp/1_0011_crop.jpg")
  209. img_dect.get_kh_ans(20,10,std_x_list,std_y_list,qno_list)
  210. if __name__ == "__main__":
  211. main()