sheet_adjust.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486
  1. # @Author : mbq
  2. # @File : sheet_adjust.py
  3. # @Time : 2019/9/26 0026 上午 10:12
  4. import copy
  5. import json
  6. import os
  7. import cv2
  8. import numpy as np
  9. ''' 根据CV检测矩形框 调整模型输出框'''
  10. ''' LSD直线检测 暂时改用 霍夫曼检测'''
  11. ADJUST_CLASS = ['solve', 'solve0', 'composition', 'composition0', 'choice', 'cloze', 'correction']
  12. # 用户自己计算阈值
  13. def custom_threshold(gray, type_inv=cv2.THRESH_BINARY):
  14. # gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY) #把输入图像灰度化
  15. h, w = gray.shape[:2]
  16. m = np.reshape(gray, [1, w * h])
  17. mean = m.sum() / (w * h)
  18. ret, binary = cv2.threshold(gray, min(230, mean), 255, type_inv)
  19. return binary
  20. # 开运算
  21. def open_img(image_bin, kera=(5, 5)):
  22. kernel = cv2.getStructuringElement(cv2.MORPH_RECT, kera)
  23. opening = cv2.morphologyEx(image_bin, cv2.MORPH_OPEN, kernel)
  24. return opening
  25. # 闭运算
  26. def close_img(image_bin, kera=(5, 5)):
  27. kernel = cv2.getStructuringElement(cv2.MORPH_RECT, kera)
  28. closing = cv2.morphologyEx(image_bin, cv2.MORPH_CLOSE, kernel)
  29. return closing
  30. # 腐蚀
  31. def erode_img(image, kernel_size):
  32. kernel = np.ones((kernel_size, kernel_size), np.uint8)
  33. erosion = cv2.erode(image, kernel)
  34. return erosion
  35. # 膨胀
  36. def dilation_img(image, kernel_size):
  37. kernel = np.ones((kernel_size, kernel_size), np.uint8)
  38. dilaion = cv2.dilate(image, kernel)
  39. return dilaion
  40. # 图像padding
  41. def image_padding(image, padding_w, padding_h):
  42. h, w = image.shape[:2]
  43. if 3 == len(image.shape):
  44. image_new = np.zeros((h + padding_h, w + padding_w, 3), np.uint8)
  45. else:
  46. image_new = np.zeros((h + padding_h, w + padding_w), np.uint8)
  47. image_new[int(padding_h / 2):int(padding_h / 2) + h, int(padding_w / 2):int(padding_w / 2) + w] = image
  48. return image_new
  49. def horizontal_projection(img_bin, mut=0):
  50. """水平方向投影"""
  51. h, w = img_bin.shape[:2]
  52. hist = [0 for i in range(w)]
  53. for x in range(w):
  54. tmp = 0
  55. for y in range(h):
  56. if img_bin[y][x]:
  57. tmp += 1
  58. if tmp > mut:
  59. hist[x] = tmp
  60. return hist
  61. def vertical_projection(img_bin, mut=0):
  62. """垂直方向投影"""
  63. h, w = img_bin.shape[:2]
  64. hist = [0 for i in range(h)]
  65. for y in range(h):
  66. tmp = 0
  67. for x in range(w):
  68. if img_bin[y][x]:
  69. tmp += 1
  70. if tmp > mut:
  71. hist[y] = tmp
  72. return hist
  73. def get_white_blok_pos(arry, blok_w=0):
  74. """获取投影结果中的白色块"""
  75. pos = []
  76. start = 1
  77. x0 = 0
  78. x1 = 0
  79. for idx, val in enumerate(arry):
  80. if start:
  81. if val:
  82. x0 = idx
  83. start = 0
  84. else:
  85. if 0 == val:
  86. x1 = idx
  87. start = 1
  88. if x1 - x0 > blok_w:
  89. pos.append((x0, x1))
  90. if 0 == start:
  91. x1 = len(arry) - 1
  92. if x1 - x0 > blok_w:
  93. pos.append((x0, x1))
  94. return pos
  95. def get_decide_boberLpa(itemRe, itemGT):
  96. """
  97. IOU 计算
  98. """
  99. x1 = int(itemRe[0])
  100. y1 = int(itemRe[1])
  101. x1_ = int(itemRe[2])
  102. y1_ = int(itemRe[3])
  103. width1 = x1_ - x1
  104. height1 = y1_ - y1
  105. x2 = int(float(itemGT[0]))
  106. y2 = int(float(itemGT[1]))
  107. x2_ = int(float(itemGT[2]))
  108. y2_ = int(float(itemGT[3]))
  109. width2 = x2_ - x2
  110. height2 = y2_ - y2
  111. endx = max(x1_, x2_)
  112. startx = min(x1, x2)
  113. width = width1 + width2 - (endx - startx)
  114. endy = max(y1_, y2_)
  115. starty = min(y1, y2)
  116. height = height1 + height2 - (endy - starty)
  117. AreaJc = 0
  118. ratio = 0.0
  119. if width <= 0 or height <= 0:
  120. res = 0
  121. else:
  122. AreaJc = width * height
  123. AreaRe = width1 * height1
  124. AreaGT = width2 * height2
  125. ratio = float(AreaJc) / float((AreaGT + AreaRe - AreaJc))
  126. return ratio
  127. # 查找连通区域 微调专用 不通用
  128. def get_contours(image):
  129. # image = cv2.imread(img_path,0)
  130. # if debug: plt_imshow(image)
  131. image_binary = custom_threshold(image)
  132. # if debug: plt_imshow(image_binary)
  133. # if debug: cv2.imwrite(os.path.join(file_dir,"bin.jpg"),image_binary)
  134. image_dilation = open_img(image_binary, kera=(5, 1))
  135. image_dilation = open_img(image_dilation, kera=(1, 5))
  136. # if debug: plt_imshow(image_dilation)
  137. # if debug: cv2.imwrite(os.path.join(file_dir,"dia.jpg"),image_dilation)
  138. _, labels, stats, centers = cv2.connectedComponentsWithStats(image_dilation)
  139. rects = []
  140. img_h, img_w = image.shape[:2]
  141. for box in stats:
  142. x0 = int(box[0])
  143. y0 = int(box[1])
  144. w = int(box[2])
  145. h = int(box[3])
  146. area = int(box[4])
  147. if w < img_w / 5 or w > img_w - 10 or h < 50 or h > img_h - 10: # 常见框大小限定
  148. continue
  149. if img_w > img_h: # 多栏答题卡 w大于宽度的一般肯定是错误的框
  150. if w > img_w / 2:
  151. continue
  152. if area < w * h / 3: # 大框套小框 中空白色区域形成的面积 排除
  153. continue
  154. rects.append((x0, y0, x0 + w, y0 + h))
  155. return rects
  156. def adjust_alarm_info(image, box):
  157. """
  158. 调整上下坐标 排除内部含有了边框线情况
  159. 左右调整只有100%确认的 从边界开始遇到的第一个非0列就终止 误伤情况太多
  160. LSD算法转不过来 霍夫曼检测不靠谱 连通区域测试后排除误伤情况太多 改用投影
  161. image: 灰度 非 二值图
  162. box : 坐标信息
  163. """
  164. # debug
  165. # debug = 0
  166. if image is None:
  167. print("error image")
  168. return box
  169. img_box = image[box[1]:box[3], box[0]:box[2]]
  170. h, w = img_box.shape[:2]
  171. # debug
  172. # if debug: ia.imshow(img_box)
  173. img_bin = custom_threshold(img_box, type_inv=cv2.THRESH_BINARY_INV)
  174. img_padding = image_padding(img_bin, 100, 100)
  175. img_close = close_img(img_padding, kera=(30, 3))
  176. img_back = img_close[50:50 + h, 50:50 + w]
  177. # debug
  178. # if debug: ia.imshow(img_back)
  179. # 垂直投影 找 left top
  180. hist_vert = vertical_projection(img_back, mut=h / 4)
  181. # debug
  182. # if debug:
  183. # print(hist_vert)
  184. # black_img_h = np.zeros_like(img_back)
  185. # for idx, val in enumerate(hist_vert):
  186. # if (val == 0):
  187. # continue
  188. # for x in range(val):
  189. # black_img_h[idx][x] = 255
  190. # ia.imshow(black_img_h)
  191. y_pos = get_white_blok_pos(hist_vert, 2)
  192. if (len(y_pos) == 0):
  193. return box
  194. # 获取最大的作为alarm_info的区域
  195. max_id = 0
  196. max_len = 0
  197. for idx, pos_tmp in enumerate(y_pos):
  198. pos_len = abs(pos_tmp[1] - pos_tmp[0])
  199. if (pos_len > max_len):
  200. max_id = idx
  201. max_len = pos_len
  202. # debug to show
  203. # if debug:
  204. # img_show = cv2.cvtColor(img_box, cv2.COLOR_GRAY2BGR)
  205. # cv2.line(img_show, (0, y_pos[max_id][0]), (w - 1, y_pos[max_id][0]), (0, 0, 255), 2)
  206. # cv2.line(img_show, (0, y_pos[max_id][1]), (w - 1, y_pos[max_id][1]), (0, 0, 255), 2)
  207. # ia.imshow(img_show)
  208. # 左右 的微调
  209. img_next = img_bin[y_pos[max_id][0]:y_pos[max_id][1], 0:w - 1]
  210. img_lr_close = open_img(img_next, kera=(1, 1))
  211. img_lr_close = close_img(img_lr_close, kera=(3, 1))
  212. # debug
  213. # if debug: ia.imshow(img_lr_close)
  214. hist_proj = horizontal_projection(img_lr_close, mut=1)
  215. w_len = len(hist_proj)
  216. new_left = 0
  217. new_right = w_len - 1
  218. b_flag = [0, 0]
  219. for idx, val in enumerate(hist_proj):
  220. if 0 == b_flag[0]:
  221. if val != 0:
  222. new_left = idx
  223. b_flag[0] = 1
  224. if 0 == b_flag[1]:
  225. if hist_proj[w_len - 1 - idx] != 0:
  226. new_right = w_len - idx - 1
  227. b_flag[1] = 1
  228. if b_flag[0] and b_flag[1]:
  229. break
  230. new_top = box[1] + y_pos[max_id][0]
  231. new_bottom = box[1] + y_pos[max_id][1]
  232. new_left += box[0]
  233. new_right += box[0]
  234. box[1] = new_top
  235. box[3] = new_bottom
  236. box[0] = new_left
  237. box[2] = new_right
  238. return box
  239. def adjust_zg_info(image, box, cv_boxes):
  240. """
  241. 调整大区域的box
  242. 1、cvbox要与box纵坐标有交叉
  243. 2、IOU值大于0。8时 默认相等拷贝区域坐标
  244. """
  245. if image is None:
  246. return box
  247. min_rotio = 0.5
  248. img_box = image[box[1]:box[3], box[0]:box[2]]
  249. h, w = img_box.shape[:2]
  250. jc_boxes = [] # 记录与box存在交叉的 cv_boxes
  251. tmp_rotio = 0
  252. rc_mz = box
  253. for idx, cv_box in enumerate(cv_boxes):
  254. if (box[1] - 10) > (cv_box[3]): # 首先要保证纵坐标有交叉
  255. continue
  256. if (box[3] + 10) < cv_box[1]:
  257. continue
  258. jc_x = max(box[0], cv_box[0])
  259. jc_y = min(box[2], cv_box[2])
  260. bj_x = min(box[0], cv_box[0])
  261. bj_y = max(box[2], cv_box[2])
  262. rt = abs(jc_y - jc_x) * 1.0 / abs(bj_y - bj_x) * 1.0
  263. if rt < min_rotio:
  264. continue
  265. jc_boxes.append(cv_box)
  266. if rt > tmp_rotio:
  267. rc_mz = cv_box
  268. tmp_rotio = rt
  269. # 判断 调整
  270. if len(jc_boxes) != 0:
  271. box[0] = rc_mz[0]
  272. box[2] = rc_mz[2]
  273. b_find = 0
  274. frotio = 0.0
  275. rc_biggst = rc_mz
  276. for mz_box in jc_boxes:
  277. iou = get_decide_boberLpa(mz_box, box)
  278. if iou > 0.8:
  279. b_find = 1
  280. frotio = iou
  281. rc_biggst = mz_box
  282. if b_find:
  283. box[1] = rc_biggst[1]
  284. box[3] = rc_biggst[3]
  285. return box
  286. def adjust_item_edge(img_path, reback_json):
  287. """
  288. 根据图像的CV分析结果和 模型直接输出结果 对模型输出的边框做微调
  289. 1、外接矩形查找
  290. 2、LSD直线检测 替换方法 霍夫曼直线检测
  291. 3、只处理有把握的情况 任何含有不确定因素的一律不作任何处理
  292. img_path: 待处理图像绝对路径
  293. re_json : 模型输出结果
  294. """
  295. debug = 1
  296. # 存放新的结果
  297. re_json = copy.deepcopy(reback_json)
  298. if not os.path.exists(img_path) or 0 == len(re_json):
  299. return
  300. image = cv2.imread(img_path, 0)
  301. # 获取CV连通区域结果
  302. cv_boxes = get_contours(image)
  303. if debug:
  304. print(len(cv_boxes))
  305. image_draw = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
  306. # for item in cv_boxes:
  307. # cv2.rectangle(image_draw, (item[0], item[1]), (item[2], item[3]), (0, 0, 250), 2)
  308. # cv2.imwrite(os.path.join(file_dir, "show.jpg"), image_draw)
  309. # 循环处理指定的box
  310. for idx, item in enumerate(re_json):
  311. name = item["class_name"]
  312. box = [item["bounding_box"]["xmin"], item["bounding_box"]["ymin"], item["bounding_box"]["xmax"],
  313. item["bounding_box"]["ymax"]]
  314. # print(name ,box)
  315. if name == "alarm_info" or name == "page" or name == "type_score":
  316. if debug:
  317. cv2.rectangle(image_draw, (box[0], box[1]), (box[2], box[3]), (0, 255, 0), 2)
  318. new_box = adjust_alarm_info(image, box)
  319. if debug:
  320. cv2.rectangle(image_draw, (box[0], box[1]), (box[2], box[3]), (255, 0, 0), 2)
  321. item["bounding_box"]["xmin"] = box[0]
  322. item["bounding_box"]["xmax"] = box[2]
  323. item["bounding_box"]["ymin"] = box[1]
  324. item["bounding_box"]["ymax"] = box[3]
  325. elif (name == "solve" or name == "solve0"
  326. or name == "cloze" or name == "choice"
  327. or name == "composition" or name == "composition0"):
  328. if debug:
  329. cv2.rectangle(image_draw, (box[0], box[1]), (box[2], box[3]), (0, 255, 0), 2)
  330. new_box = adjust_zg_info(image, box, cv_boxes)
  331. if debug:
  332. cv2.rectangle(image_draw, (box[0], box[1]), (box[2], box[3]), (255, 0, 0), 2)
  333. item["bounding_box"]["xmin"] = box[0]
  334. item["bounding_box"]["xmax"] = box[2]
  335. item["bounding_box"]["ymin"] = box[1]
  336. item["bounding_box"]["ymax"] = box[3]
  337. else:
  338. pass
  339. if debug:
  340. cv2.imwrite(os.path.join(r"E:\data\aug_img\adjust", "show.jpg"), image_draw)
  341. return re_json
  342. def adjust_item_edge_by_gray_image(image, reback_json):
  343. '''
  344. 根据图像的CV分析结果和 模型直接输出结果 对模型输出的边框做微调
  345. 1、外接矩形查找
  346. 2、LSD直线检测 替换方法 霍夫曼直线检测
  347. 3、只处理有把握的情况 任何含有不确定因素的一律不作任何处理
  348. img_path: 待处理图像绝对路径
  349. re_json : 模型输出结果
  350. '''
  351. debug = 0
  352. re_json = copy.deepcopy(reback_json)
  353. # 存放新的结果
  354. # 获取CV连通区域结果
  355. if len(image.shape) > 2:
  356. image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
  357. cv_boxes = get_contours(image)
  358. if debug:
  359. print(len(cv_boxes))
  360. image_draw = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
  361. # for item in cv_boxes:
  362. # cv2.rectangle(image_draw, (item[0], item[1]), (item[2], item[3]), (0, 0, 250), 2)
  363. # cv2.imwrite(os.path.join(file_dir, "show.jpg"), image_draw)
  364. # 循环处理指定的box
  365. for idx, item in enumerate(re_json):
  366. name = item["class_name"]
  367. box = [item["bounding_box"]["xmin"], item["bounding_box"]["ymin"], item["bounding_box"]["xmax"],
  368. item["bounding_box"]["ymax"]]
  369. # print(name ,box)
  370. if name == "alarm_info" or name == "page" or name == "type_score":
  371. if debug:
  372. cv2.rectangle(image_draw, (box[0], box[1]), (box[2], box[3]), (0, 255, 0), 2)
  373. new_box = adjust_alarm_info(image, box)
  374. if debug:
  375. cv2.rectangle(image_draw, (box[0], box[1]), (box[2], box[3]), (255, 0, 0), 2)
  376. item["bounding_box"]["xmin"] = box[0]
  377. item["bounding_box"]["xmax"] = box[2]
  378. item["bounding_box"]["ymin"] = box[1]
  379. item["bounding_box"]["ymax"] = box[3]
  380. elif name in ADJUST_CLASS:
  381. if debug:
  382. cv2.rectangle(image_draw, (box[0], box[1]), (box[2], box[3]), (0, 255, 0), 2)
  383. new_box = adjust_zg_info(image, box, cv_boxes)
  384. if debug:
  385. cv2.rectangle(image_draw, (box[0], box[1]), (box[2], box[3]), (255, 0, 0), 2)
  386. item["bounding_box"]["xmin"] = box[0]
  387. item["bounding_box"]["xmax"] = box[2]
  388. item["bounding_box"]["ymin"] = box[1]
  389. item["bounding_box"]["ymax"] = box[3]
  390. else:
  391. pass
  392. if debug:
  393. cv2.imwrite(os.path.join(r"E:\data\aug_img\adjust", "show.jpg"), image_draw)
  394. return re_json
  395. # if __name__ == '__main__':
  396. # '''服务端传入数据为json内数据 和图像
  397. # 使用方法:
  398. # new_json = adjust_item_edge(img_path, key_json)
  399. # key_json : regions 数组
  400. # new_json : 调整后的结果 size == key_json.size
  401. # '''
  402. #
  403. # print("前置解析")
  404. # file_dir = r"E:\data\aug_img\adjust"
  405. # img_path = os.path.join(file_dir, "7642572.jpg")
  406. # json_path = os.path.join(file_dir, "7642572.json")
  407. # print(img_path, json_path)
  408. # # 读取json
  409. # output_ios = open(json_path).read()
  410. # output_json = json.loads(output_ios)
  411. # for item in output_json:
  412. # # print(item,output_json[item])
  413. # if (item == "regions"):
  414. # key_json = output_json[item]
  415. # # print(len(key_json))
  416. # for idx, item in enumerate(key_json):
  417. # # print(key_json[idx])
  418. # if (item["class_name"] == "alarm_info"):
  419. # key_json[idx]["bounding_box"]["ymin"] -= 10
  420. # key_json[idx]["bounding_box"]["ymax"] += 10
  421. # # print(key_json[idx])
  422. #
  423. # new_json = adjust_item_edge(img_path, key_json)
  424. # for idx, val in enumerate(key_json):
  425. # print(key_json[idx])
  426. # print(new_json[idx])