# -*- coding: utf-8 -*-
##############################
# Precise template matching of answer-sheet choice boxes, seeded by OCR.
##############################
import cv2
import os
import pytesseract
from pytesseract import Output
from PIL import Image
import numpy as np
import pprint
from functools import reduce
import math
from utils.bdocr import rec_general_num


def calc_overlap(rec1, rec2):
    """Return the intersection-over-union (IoU) of two rectangles.

    Each rectangle is a dict with keys ``x``, ``y`` (top-left corner) and
    ``w``, ``h`` (size). A falsy rectangle (e.g. ``None``) yields 0.0, as does
    a pair of rectangles that do not overlap.
    """
    if not (rec1 and rec2):
        return 0.0
    x1, y1, w1, h1 = rec1["x"], rec1["y"], rec1["w"], rec1["h"]
    x2, y2, w2, h2 = rec2["x"], rec2["y"], rec2["w"], rec2["h"]
    # Intersection size = sum of both sizes minus the span of their union.
    width = w1 + w2 - (max(x1 + w1, x2 + w2) - min(x1, x2))
    height = h1 + h2 - (max(y1 + h1, y2 + h2) - min(y1, y2))
    if width <= 0 or height <= 0:
        return 0.0
    area = width * height
    return float(area) / (w1 * h1 + w2 * h2 - area)


def crop_img_local(orgimg, new_img, point=()):
    """Crop a region of an image file and save it as PNG.

    Args:
        orgimg: path of the source image.
        new_img: path the cropped PNG is written to.
        point: ``(left, upper, right, lower)`` box passed to ``Image.crop``.

    Returns:
        ``new_img``.
    """
    # The context manager closes the source file handle once we are done.
    with Image.open(orgimg) as img:
        region = img.crop(point)
        region.save(new_img, format="png")
    return new_img


def write_img(imgpath, data):
    """Draw every box in ``data`` onto the image at ``imgpath`` (debug aid).

    The annotated copy is written to /tmp/templateDetectResult.png; the
    source file itself is not modified.
    """
    img = cv2.imread(imgpath)
    for item in data:
        x, y, w, h = item["x"], item["y"], item["w"], item["h"]
        # Red (BGR order) rectangle, 2 px thick.
        cv2.rectangle(img, (x, y), (x + w, y + h), (0, 0, 255), 2)
    cv2.imwrite("/tmp/templateDetectResult.png", img)


def recText(imgpath, choices):
    """OCR the choice letters (e.g. A/B/C/D) and cut a template image for each.

    For every OCR'd character equal to one of ``choices``, a region around the
    character box (widened by one character width plus 2 px on each side and
    by 4 px vertically) is cropped from ``imgpath`` and saved as
    ``/tmp/<choice>.png``. If a letter is recognized several times, the last
    occurrence overwrites earlier ones.

    Debug images are written to /tmp/src_gray.png, /tmp/src_thresh.png and
    /tmp/ocrResult.png.
    """
    img = cv2.imread(imgpath)
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    cv2.imwrite("/tmp/src_gray.png", gray)
    # Debug output only; the thresholded image is not used below.
    # NOTE(review): maxval 250 (rather than 255) looks unintentional.
    _, thresh = cv2.threshold(gray, 0, 250, cv2.THRESH_BINARY_INV | cv2.THRESH_OTSU)
    cv2.imwrite("/tmp/src_thresh.png", thresh)
    hImg, wImg, _ = img.shape
    boxes = pytesseract.image_to_boxes(img)
    for b in boxes.splitlines():
        # Fields used: char, x1, y1, x2, y2 -- y is measured from the bottom
        # edge of the image, hence the ``hImg - y`` flips below.
        b = b.split(" ")
        for ans in choices:
            if b[0] == ans:
                x1, y1, x2, y2 = int(b[1]), int(b[2]), int(b[3]), int(b[4])
                w = x2 - x1
                h = y2 - y1
                crop_point = (x1 - w - 2, hImg - y1 - (h + 4), x2 + w + 2, hImg - y2 + h + 4)
                crop_img_local(imgpath, "/tmp/{}.png".format(ans), crop_point)
                break
    cv2.imwrite("/tmp/ocrResult.png", img)


def round_int(x):
    """Round ``x`` to a multiple of ten based on its units digit.

    NOTE(review): the intent was described as half-up rounding, but the
    cut-off is 8, not 5: units digits 0-7 round down and 8-9 round up.
    Confirm whether 8 is deliberate.
    """
    if x % 100 % 10 < 8:
        return int(x / 10) * 10
    else:
        return (int(x / 10) + 1) * 10


def recTemplate(src, choices):
    """Find every occurrence of each choice template in the image ``src``.

    Templates are read from ``/tmp/<choice>.png`` (as written by ``recText``).
    A choice whose template is missing or unreadable, for example because OCR
    never found that letter, is skipped instead of crashing.

    Returns:
        A list of dicts ``{"x", "y", "w", "h", "z"}``. Each holds the top-left
        corner and size of a match scoring >= 0.8 with ``TM_CCOEFF_NORMED``,
        plus ``z``, the corner's distance from the image origin. ``get_std_xy``
        later uses ``z`` to pick a question's first box.
    """
    threshold = 0.8
    data = []
    img = cv2.imread(src)
    # The grayscale source does not depend on the choice; convert it once.
    img_gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    for ans in choices:
        template = cv2.imread("/tmp/{}.png".format(ans), 0)
        if template is None:
            # cv2.imread returns None for a missing/unreadable file.
            continue
        h, w = template.shape[:2]
        res = cv2.matchTemplate(img_gray, template, cv2.TM_CCOEFF_NORMED)
        loc = np.where(res >= threshold)
        # np.where yields (rows, cols); reverse to iterate (x, y) pairs.
        for pt in zip(*loc[::-1]):
            x, y = int(pt[0]), int(pt[1])
            z = math.sqrt(x ** 2 + y ** 2)
            data.append({"x": x, "y": y, "w": w, "h": h, "z": z})
    return data


def format_data(data):
    """Group boxes into columns by x position, each column sorted top-to-bottom.

    Boxes are sorted by ``x``. A new column starts when a box lies at least
    twice its own width to the right of the current column's first box.

    Returns:
        A list of columns (lists of boxes) from left to right; ``[]`` when
        ``data`` is empty.
    """
    if not data:
        return []
    data = sorted(data, key=lambda item: item["x"])
    groups = []
    group = []
    start_x = data[0]["x"]
    for item in data:
        if item["x"] - start_x < item["w"] * 2:
            group.append(item)
        else:
            groups.append(group)
            group = [item]
            start_x = item["x"]
    groups.append(group)
    return [sorted(g, key=lambda item: item["y"]) for g in groups]


def remove_repeat(data):
    """Remove duplicate boxes at (nearly) the same position.

    This is a greedy pass: walking the list in order, each surviving box
    removes every other box whose IoU with it exceeds 0.5. It works on a
    copy, so the caller's list is not modified, and the surviving boxes keep
    their original order.
    """
    data = list(data)
    for i, item in enumerate(data):
        if item is None:
            # Already removed by an earlier box.
            continue
        for j in range(len(data)):
            if i != j and calc_overlap(item, data[j]) > 0.5:
                data[j] = None
    return [x for x in data if x]


def rec_stand_ans_points(srcimg, choices):
    """Locate every choice box on the answer sheet via template matching.

    Pipeline:
    1. Cut one template per choice letter with OCR (``recText``).
    2. Template-match those templates against the sheet (``recTemplate``).
    3. Draw the raw matches for debugging (``write_img``).
    4. Drop duplicate matches (``remove_repeat``).
    5. Group the remaining boxes into columns (``format_data``).

    Returns:
        The list of columns produced by ``format_data``.
    """
    recText(srcimg, choices)
    data = recTemplate(srcimg, choices)
    # Sorting by x fixes the order in which remove_repeat keeps boxes.
    data = sorted(data, key=lambda item: item["x"])
    write_img(srcimg, data)
    data = remove_repeat(data)
    return format_data(data)


def get_std_xy(qnos, data, choices, order=1):
    """Assign each question its group of choice-box coordinates.

    Args:
        qnos: dict mapping question id -> region dict with ``top``, ``left``,
            ``width`` and ``height`` (from ``rec_general_num``; presumably the
            OCR'd position of the question number -- TODO confirm).
        data: columns of boxes as returned by ``rec_stand_ans_points``.
        choices: the choice letters; ``len(choices)`` consecutive boxes in a
            column form one question.
        order: layout, 1 = vertical, 2 = horizontal.

    Returns:
        dict mapping question id -> list of its boxes, sorted by ``z``.

    NOTE(review): only ``order == 1`` is implemented; any other value leaves
    ``newdata`` empty, so the result is ``{}``.
    """
    step = len(choices)
    newdata = []
    if order == 1:
        for column in data:
            chunks = [column[i:i + step] for i in range(0, len(column), step)]
            newdata.extend(chunks)
    pprint.pprint(newdata)
    std_xy_data = {}
    for fpoints in newdata:
        fpoints = sorted(fpoints, key=lambda p: p["z"])
        # The box closest to the image origin anchors the group.
        fpoint = fpoints[0]
        for k, v in qnos.items():
            top, left, width, height = v["top"], v["left"], v["width"], v["height"]
            # The anchor must lie within half a region-height of the region's
            # top edge and strictly inside its horizontal span.
            if (top - height / 2 < fpoint["y"] < top + height / 2
                    and left < fpoint["x"] < left + width):
                std_xy_data[k] = fpoints
    pprint.pprint(std_xy_data)
    return std_xy_data


def rec_std_ans(srcimg):
    """Recognize a standard (answer-key) answer sheet.

    Returns:
        dict mapping question id -> list of that question's choice boxes.
    """
    _qnos, qnos_dct, choices, rank_order = rec_general_num(srcimg)
    data = rec_stand_ans_points(srcimg, choices)
    return get_std_xy(qnos_dct, data, choices, rank_order)


if __name__ == "__main__":
    srcimg = "/tmp/test_0_crop.png"
    std_ans_data = rec_std_ans(srcimg)
    pprint.pprint(std_ans_data)