123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239 |
- #-*-coding:utf-8 -*-
- ##############################
- #python使用ocr完成精确模板匹配
- ##############################
- import cv2
- import os
- import pytesseract
- from pytesseract import Output
- from PIL import Image
- import numpy as np
- import pprint
- from functools import reduce
- import math
- from utils.bdocr import rec_general_num
def calc_overlap(rec1, rec2):
    """Return the intersection-over-union ratio of two rectangles.

    Each rectangle is a dict with ``x``, ``y``, ``w`` and ``h`` keys.

    Returns:
        ``intersection / union`` as a float, or ``0.0`` when either argument
        is falsy or the rectangles do not overlap.
    """
    if not (rec1 and rec2):
        return 0.0

    ax, ay, aw, ah = rec1["x"], rec1["y"], rec1["w"], rec1["h"]
    bx, by, bw, bh = rec2["x"], rec2["y"], rec2["w"], rec2["h"]

    # Along each axis the overlap equals the sum of both extents minus the
    # length of the span that covers both rectangles.
    span_x = max(ax + aw, bx + bw) - min(ax, bx)
    span_y = max(ay + ah, by + bh) - min(ay, by)
    overlap_w = aw + bw - span_x
    overlap_h = ah + bh - span_y
    if overlap_w <= 0 or overlap_h <= 0:
        return 0.0

    shared = overlap_w * overlap_h
    union = aw * ah + bw * bh - shared
    return float(shared) / union
def crop_img_local(orgimg, new_img, point=()):
    """Crop a region out of an image file and save it as a PNG.

    Args:
        orgimg: Source image, passed straight to ``Image.open``.
        new_img: Destination passed to ``Image.save``; the output format is
            always PNG regardless of the file extension.
        point: Crop box forwarded to ``Image.crop``; Pillow documents it as a
            ``(left, upper, right, lower)`` tuple.  An empty box (the default)
            or ``None`` keeps the whole image.

    Returns:
        ``new_img``, unchanged, so the call can be chained.
    """
    # An empty tuple cannot be unpacked by Image.crop; Pillow treats a box of
    # None as "copy the full image".
    box = None if point is None or len(point) == 0 else point

    # The context manager releases the file handle opened by Image.open.  The
    # crop is saved while the source is still open so that lazily loaded
    # pixel data remains readable.
    with Image.open(orgimg) as img:
        region = img.crop(box)
        region.save(new_img, format="png")
    return new_img
def write_img(imgpath, data):
    """Draw every rectangle in ``data`` on the image and save the result.

    Args:
        imgpath: Path of the image to annotate (read with ``cv2.imread``).
        data: Iterable of dicts with ``x``, ``y``, ``w`` and ``h`` keys.

    The annotated image is written to ``/tmp/templateDetectResult.png``.
    """
    canvas = cv2.imread(imgpath)
    for box in data:
        left, top = box["x"], box["y"]
        opposite = (left + box["w"], top + box["h"])
        # Colour (0, 0, 255) with a 2-pixel outline.
        cv2.rectangle(canvas, (left, top), opposite, (0, 0, 255), 2)
    cv2.imwrite("/tmp/templateDetectResult.png", canvas)
def recText(imgpath, choices):
    """Find the option letters with OCR and save one template image per letter.

    A grayscale copy and an inverse-Otsu-thresholded copy of the source are
    written to ``/tmp/src_gray.png`` and ``/tmp/src_thresh.png``.  Character
    boxes are then requested from Tesseract for the colour image; for every
    box whose character equals one of ``choices`` a padded region around it
    is cropped from ``imgpath`` and saved as ``/tmp/<letter>.png`` (a later
    box for the same letter overwrites an earlier one).  Finally the source
    image is written to ``/tmp/ocrResult.png``.
    """
    image = cv2.imread(imgpath)

    # These two intermediates are only written to disk; OCR runs on ``image``.
    grayscale = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    cv2.imwrite("/tmp/src_gray.png", grayscale)
    binary = cv2.threshold(
        grayscale, 0, 250, cv2.THRESH_BINARY_INV | cv2.THRESH_OTSU
    )[1]
    cv2.imwrite("/tmp/src_thresh.png", binary)

    img_height, _, _ = image.shape
    for line in pytesseract.image_to_boxes(image).splitlines():
        fields = line.split(" ")
        symbol = fields[0]
        matched = next((choice for choice in choices if symbol == choice), None)
        if matched is None:
            continue

        x1, y1, x2, y2 = (
            int(fields[1]), int(fields[2]), int(fields[3]), int(fields[4])
        )
        pad_x = (x2 - x1) + 2  # one character width plus 2 px on each side
        pad_y = (y2 - y1) + 4
        # Box y values are subtracted from the image height to obtain the
        # crop coordinates.
        region = (
            x1 - pad_x,
            img_height - y1 - pad_y,
            x2 + pad_x,
            img_height - y2 + pad_y,
        )
        crop_img_local(imgpath, "/tmp/{}.png".format(matched), region)

    cv2.imwrite("/tmp/ocrResult.png", image)
def round_int(x):
    """Snap ``x`` to a multiple of ten based on its ones digit.

    A ones digit of 8 or 9 rounds up to the next ten; anything below 8 rounds
    down.  Note the cut-off is 8, not the usual 5.
    """
    tens = int(x / 10)
    return tens * 10 if x % 100 % 10 < 8 else (tens + 1) * 10
def recTemplate(src, choices):
    """Locate every occurrence of each option template in the source image.

    For each entry of ``choices`` the template ``/tmp/<choice>.png`` is
    matched against the grayscale source with ``cv2.TM_CCOEFF_NORMED``; every
    position scoring at least 0.8 is reported.

    Args:
        src: Path of the image to search.
        choices: Iterable of option labels whose templates exist under /tmp.

    Returns:
        A list of dicts with keys ``x``/``y`` (match position), ``w``/``h``
        (template size) and ``z`` (distance of the position from the origin).
    """
    img = cv2.imread(src)
    # The grayscale source is the same for every template, so convert once
    # instead of once per choice.
    img_gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    threshold = 0.8

    data = []
    for ans in choices:
        # NOTE(review): cv2.imread presumably returns None for a missing or
        # unreadable template, which would surface as an AttributeError on
        # ``.shape`` below -- confirm templates always exist at this point.
        template = cv2.imread("/tmp/{}.png".format(ans), 0)
        h, w = template.shape[:2]
        res = cv2.matchTemplate(img_gray, template, cv2.TM_CCOEFF_NORMED)
        # np.where yields (row indices, column indices); pair them as (x, y).
        ys, xs = np.where(res >= threshold)
        for x, y in zip(xs, ys):
            z = math.sqrt(x**2 + y**2)
            data.append({"x": x, "y": y, "w": w, "h": h, "z": z})
    return data
def format_data(data):
    """Group boxes into columns by x position and sort each column by y.

    Boxes are visited in ascending ``x`` order.  A box joins the current
    group while its ``x`` is less than twice its own width away from the
    previous box; a larger gap starts a new group.

    Args:
        data: List of dicts with at least ``x``, ``y`` and ``w`` keys.

    Returns:
        A list of groups (left to right), each a list of boxes sorted by
        ``y``.  An empty input yields an empty list.
    """
    # Guard: indexing the first element below would fail on empty input.
    if not data:
        return []

    ordered = sorted(data, key=lambda box: box["x"])
    groups = []
    group = []
    prev_x = ordered[0]["x"]
    for item in ordered:
        if item["x"] - prev_x < item["w"] * 2:
            group.append(item)
        else:
            groups.append(group)
            group = [item]
        prev_x = item["x"]
    groups.append(group)

    return [sorted(column, key=lambda box: box["y"]) for column in groups]
def remove_repeat(data):
    """Drop rectangles that duplicate an earlier one at nearly the same spot.

    Items are visited in order; an item is kept unless its overlap ratio
    (``calc_overlap``) with an already kept item exceeds 0.5.  Falsy entries
    such as ``None`` are dropped as well.

    Args:
        data: List of rectangle dicts.  The list is not modified.

    Returns:
        A new list containing the surviving rectangles in their input order.
    """
    kept = []
    for candidate in data:
        if not candidate:
            continue
        # Only kept rectangles can suppress later ones, so comparing against
        # ``kept`` is sufficient and avoids writing None into the input list.
        if any(calc_overlap(prior, candidate) > 0.5 for prior in kept):
            continue
        kept.append(candidate)
    return kept
def rec_stand_ans_points(srcimg, choices):
    """Locate every option box on the sheet by template matching.

    Steps: extract the letter templates with OCR (``recText``), match them
    against the image (``recTemplate``), write all raw matches to the debug
    image (``write_img``), remove duplicate boxes, then group the rest.

    Returns:
        The grouped boxes as produced by ``format_data``.
    """
    recText(srcimg, choices)
    matches = sorted(recTemplate(srcimg, choices), key=lambda box: box["x"])
    # The debug image shows the matches before de-duplication.
    write_img(srcimg, matches)
    return format_data(remove_repeat(matches))
def get_std_xy(qnos, data, choices, order=1):
    """Assign each run of option boxes to the question it belongs to.

    Args:
        qnos: Mapping of question key to a dict with ``top``, ``left``,
            ``width`` and ``height``.
        data: List of box groups; every group is cut into consecutive runs
            of ``len(choices)`` boxes.
        choices: Option labels; only their count is used.
        order: 1 for vertically arranged options, 2 for horizontal.  Only
            ``order == 1`` produces runs; any other value gives ``{}``.

    Returns:
        Dict mapping question key to its run of boxes sorted by ``z``.  A run
        is assigned when its first box lies within half a height of the
        question's ``top`` and strictly between ``left`` and ``left + width``.
    """
    per_question = len(choices)
    runs = []
    if order == 1:
        for column in data:
            for start in range(0, len(column), per_question):
                runs.append(column[start:start + per_question])
    pprint.pprint(runs)

    # Hand each run to the question whose area contains the run's first box.
    std_xy_data = {}
    for run in runs:
        ranked = sorted(run, key=lambda box: box["z"])
        anchor = ranked[0]
        for qno, area in qnos.items():
            top, left = area["top"], area["left"]
            width, height = area["width"], area["height"]
            if (
                top - height / 2 < anchor["y"] < top + height / 2
                and left < anchor["x"] < left + width
            ):
                std_xy_data[qno] = ranked
    pprint.pprint(std_xy_data)
    return std_xy_data
def rec_std_ans(srcimg):
    """Recognise a reference answer sheet.

    Obtains the question areas, option labels and layout order from
    ``rec_general_num``, locates the option boxes, and returns the
    per-question box mapping built by ``get_std_xy``.
    """
    # The first value returned by rec_general_num is not needed here.
    _, qnos_dct, choices, rank_order = rec_general_num(srcimg)
    option_boxes = rec_stand_ans_points(srcimg, choices)
    return get_std_xy(qnos_dct, option_boxes, choices, rank_order)
if __name__ == "__main__":
    # Manual run against a sample sheet; the earlier, immediately overwritten
    # path assignment was a dead store and has been removed.
    srcimg = "/tmp/test_0_crop.png"
    std_ans_data = rec_std_ans(srcimg)
    pprint.pprint(std_ans_data)
|