ansDetect.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239
  1. #-*-coding:utf-8 -*-
  2. ##############################
  3. #python使用ocr完成精确模板匹配
  4. ##############################
  5. import cv2
  6. import os
  7. import pytesseract
  8. from pytesseract import Output
  9. from PIL import Image
  10. import numpy as np
  11. import pprint
  12. from functools import reduce
  13. import math
  14. from utils.bdocr import rec_general_num
  15. def calc_overlap(rec1,rec2):
  16. """
  17. """
  18. if rec1 and rec2:
  19. x1,y1,w1,h1 = rec1["x"],rec1["y"],rec1["w"],rec1["h"]
  20. x2,y2,w2,h2 = rec2["x"],rec2["y"],rec2["w"],rec2["h"]
  21. #重叠区域
  22. startx = min(x1,x2)
  23. endx = max(x1+w1,x2+w2)
  24. width = w1+w2-(endx-startx)
  25. stary = min(y1,y2)
  26. endy = max(y1+h1,y2+h2)
  27. height = h1+h2-(endy-stary)
  28. if width <= 0 or height <= 0:
  29. return 0.0
  30. else:
  31. area = width * height
  32. area1 = w1*h1
  33. area2 = w2*h2
  34. ratio = float(area)/(area1+area2-area)
  35. return ratio
  36. return 0.0
  37. def crop_img_local(orgimg,new_img,point=()):
  38. """
  39. 切割图片
  40. """
  41. img = Image.open(orgimg)
  42. region = img.crop(point)
  43. region.save(new_img, format="png")
  44. return new_img
  45. def write_img(imgpath,data):
  46. """
  47. """
  48. img = cv2.imread(imgpath)
  49. for item in data:
  50. x = item["x"]
  51. y = item["y"]
  52. w = item["w"]
  53. h = item["h"]
  54. pt = (item["x"],item["y"])
  55. right_bottom = (x + w, y + h)
  56. cv2.rectangle(img, pt, right_bottom, (0, 0, 255), 2)
  57. cv2.imwrite("/tmp/templateDetectResult.png",img)
  58. def recText(imgpath,choices):
  59. """
  60. 识别出ABCD选项并分隔出模板图像
  61. """
  62. img = cv2.imread(imgpath)
  63. gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
  64. cv2.imwrite("/tmp/src_gray.png",gray)
  65. _,thresh = cv2.threshold(gray,0,250,cv2.THRESH_BINARY_INV | cv2.THRESH_OTSU)
  66. cv2.imwrite("/tmp/src_thresh.png",thresh)
  67. hImg,wImg,_ = img.shape
  68. boxes = pytesseract.image_to_boxes(img)
  69. for b in boxes.splitlines():
  70. b = b.split(" ")
  71. for ans in choices:
  72. if b[0] == ans:
  73. x1,y1,x2,y2 = int(b[1]),int(b[2]),int(b[3]),int(b[4])
  74. w = x2-x1
  75. h = y2-y1
  76. #cv2.rectangle(img,(x1-w,hImg-y1+2),(x2+w,hImg-y2-2),(0,0,255),2)
  77. crop_point = (x1-w-2,hImg-y1-(h+4),x2+w+2,hImg-y2+h+4)
  78. crop_img_local(imgpath,"/tmp/{}.png".format(ans),crop_point)
  79. break
  80. cv2.imwrite("/tmp/ocrResult.png",img)
  81. def round_int(x):
  82. """整数个位四舍五入
  83. """
  84. if x%100%10 < 8:
  85. return int(x/10)*10
  86. else:
  87. return (int(x/10)+1)*10
  88. def recTemplate(src,choices):
  89. """模板匹配
  90. """
  91. data = []
  92. img = cv2.imread(src)
  93. for ans in choices:
  94. tplpath = "/tmp/{}.png".format(ans)
  95. template = cv2.imread(tplpath, 0)
  96. h, w = template.shape[:2]
  97. img_gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
  98. res = cv2.matchTemplate(img_gray, template, cv2.TM_CCOEFF_NORMED)
  99. threshold = 0.8
  100. loc = np.where(res >= threshold)
  101. for pt in zip(*loc[::-1]):
  102. x,y = int(pt[0]),int(pt[1])
  103. z = math.sqrt(x**2+y**2)
  104. data.append({"x":x,"y":y,"w":w,"h":h,"z":z})
  105. return data
  106. def format_data(data):
  107. """通过xy相对位置分组
  108. """
  109. groups = []
  110. data = sorted(data,key=lambda x:x["x"])
  111. #pprint.pprint(data)
  112. max_x = data[0]["x"]
  113. group = []
  114. for i,item in enumerate(data):
  115. cur_w = item["w"]
  116. cur_x = item["x"]
  117. dis = cur_x - max_x
  118. if dis < cur_w*2:
  119. group.append(item)
  120. else:
  121. groups.append(group)
  122. group = [item]
  123. max_x = cur_x
  124. groups.append(group)
  125. #
  126. for i,item in enumerate(groups):
  127. groups[i] = list(sorted(item,key=lambda x:x["y"]))
  128. return groups
  129. def remove_repeat(data):
  130. """相同位置的矩形去重
  131. """
  132. for i,item in enumerate(data):
  133. for j,jtem in enumerate(data):
  134. if i == j:
  135. continue
  136. else:
  137. ratio = calc_overlap(item,jtem)
  138. if ratio >0.5:
  139. data[j] = None
  140. data = [x for x in data if x]
  141. return data
  142. def rec_stand_ans_points(srcimg,choices):
  143. """
  144. 通过模板匹配识别出每一个选项框的位置
  145. """
  146. recText(srcimg,choices)
  147. data = recTemplate(srcimg,choices)
  148. data = sorted(data,key=lambda x:x["x"])
  149. write_img(srcimg,data)
  150. #去重
  151. data = remove_repeat(data)
  152. #先分组
  153. data = format_data(data)
  154. return data
  155. def get_std_xy(qnos,data,choices,order=1):
  156. """
  157. order:1/2,1竖排2横排
  158. 获取每一道选择题的标准坐标
  159. """
  160. step = len(choices)
  161. newdata = []
  162. if order == 1:
  163. for item in data:
  164. choices = [item[i:i+step] for i in range(0,len(item),step)]
  165. newdata.extend(choices)
  166. pprint.pprint(newdata)
  167. #将分好组的数据分配到每一道题上面
  168. std_xy_data = {}
  169. for fpoints in newdata:
  170. fpoints = list(sorted(fpoints,key=lambda x:x["z"]))
  171. fpoint = fpoints[0]
  172. for k,v in qnos.items():
  173. top,left,width,height = v["top"],v["left"],v["width"],v["height"]
  174. if fpoint["y"] > top-height/2 and fpoint["y"] < top+height/2 and left < fpoint["x"] and fpoint["x"] < left + width:
  175. std_xy_data[k] = fpoints
  176. #newdata = list(sorted(newdata,key=lambda x:x["x"]))
  177. #std_xy_data = dict(zip(qnos,newdata))
  178. pprint.pprint(std_xy_data)
  179. return std_xy_data
  180. def rec_std_ans(srcimg):
  181. """识别标准答题卡
  182. """
  183. qnos,qnos_dct,choices,rank_order = rec_general_num(srcimg)
  184. data = rec_stand_ans_points(srcimg,choices)
  185. std_xy_data = get_std_xy(qnos_dct,data,choices,rank_order)
  186. return std_xy_data
  187. if __name__ == "__main__":
  188. srcimg = "/tmp/test6.png"
  189. srcimg = "/tmp/test_0_crop.png"
  190. std_ans_data = rec_std_ans(srcimg)
  191. pprint.pprint(std_ans_data)