sheet_server.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430
  1. # @Author : lightXu
  2. # @File : sheet_server.py
  3. # @Time : 2018/12/19 0019 下午 14:33
  4. import itertools
  5. import json
  6. import os
  7. import shutil
  8. import time
  9. import traceback
  10. import uuid
  11. import xml.etree.cElementTree as ET
  12. import cv2
  13. import numpy as np
  14. from PIL import Image
  15. from django.conf import settings
  16. from segment.sheet_resolve.tools.tf_settings import xml_template_path
  17. import segment.logging_config as logging
  18. from segment.sheet_resolve.analysis.anchor.marker_detection import find_anchor
  19. from segment.sheet_resolve.analysis.resolve import choice, choice_m_row_col
  20. from segment.sheet_resolve.analysis.resolve import cloze
  21. from segment.sheet_resolve.analysis.resolve import exam_number_row_col
  22. from segment.sheet_resolve.analysis.resolve import sheet
  23. from segment.sheet_resolve.analysis.resolve import solve, solve_with_number, cloze_with_number
  24. from segment.sheet_resolve.analysis.sheet.analysis_sheet import box_region_format, question_number_format, merge_span_boxes
  25. from segment.sheet_resolve.analysis.sheet.sheet_points import get_sheet_points
  26. from segment.sheet_resolve.analysis.sheet.sheet_points_total import get_sheet_number_total
  27. from segment.sheet_resolve.tools import utils
  28. from segment.sheet_resolve.tools.brain_api import get_ocr_text_and_coordinate, change_format_baidu_to_google
  29. from segment.sheet_resolve.analysis.sheet.sheet_points_by_nlp import get_sheet_points_by_nlp
  30. from segment.sheet_resolve.analysis.sheet.ocr_sheet import sheet_sorted
  31. from segment.sheet_resolve.analysis.sheet.decide_blank import svm_predict
# Module-wide logger. `logging` here is segment.logging_config (imported above);
# the logger name is taken from settings.LOGGING_TYPE.
logger = logging.getLogger(settings.LOGGING_TYPE)
  33. def decide_blank_sheet(image, subject):
  34. """
  35. :param image:
  36. :param subject:
  37. :return: true:blank, false:unblank
  38. """
  39. if len(image.shape) <= 2:
  40. gray_image = image
  41. else:
  42. gray_image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
  43. if subject == 'math':
  44. subject_id = 3
  45. blank_cond = svm_predict(gray_image, subject_id)
  46. else:
  47. height = gray_image.shape[0]
  48. width = gray_image.shape[1]
  49. if max(height, width) > 800:
  50. percent = max(height, width) / 800
  51. new_x = int(width * percent)
  52. new_y = int(height * percent)
  53. gray_image = cv2.resize(gray_image, (new_x, new_y), interpolation=cv2.INTER_AREA)
  54. if height > width: # 纵向
  55. image = gray_image[height//2:, :]
  56. PIX_VALUE_LOW = 25.0 # 二进制参数
  57. PIX_VALUE_HIGH = 220 # 原始图像参数
  58. else: # 横向
  59. image = gray_image[:, width // 2:]
  60. PIX_VALUE_LOW = 15.0
  61. PIX_VALUE_HIGH = 250
  62. bin_img = cv2.threshold(image, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)[1]
  63. bin_img_mean = np.mean(bin_img)
  64. img_raw_mean = np.mean(image)
  65. print(bin_img_mean, img_raw_mean)
  66. blank_cond = bin_img_mean < PIX_VALUE_LOW or img_raw_mean > PIX_VALUE_HIGH
  67. return blank_cond
  68. def convert_pil_to_jpeg(raw_img):
  69. if raw_img.mode == 'L':
  70. channels = raw_img.split()
  71. img = Image.merge("RGB", (channels[0], channels[0], channels[0]))
  72. elif raw_img.mode == 'RGB':
  73. img = raw_img
  74. elif raw_img.mode == 'RGBA':
  75. img = Image.new("RGB", raw_img.size, (255, 255, 255))
  76. img.paste(raw_img, mask=raw_img.split()[3]) # 3 is the alpha channel
  77. else:
  78. img = raw_img
  79. open_cv_image = np.array(img)
  80. return img, open_cv_image
  81. def handle_uploaded_xml_file(f, save_path):
  82. with open(save_path, 'wb+') as destination:
  83. for chunk in f.chunks():
  84. destination.write(chunk)
  85. def generate_serial_number(time_str, sheet_big_boxes):
  86. if len(sheet_big_boxes.objects.all()) < 1:
  87. last_number_gen = time_str + '000001'
  88. else:
  89. objects = sheet_big_boxes.objects.latest('update_time')
  90. last_number_in_db = objects.series_number
  91. if time_str in last_number_in_db[0:9]:
  92. last_number_gen = str(int(last_number_in_db) + 1)
  93. else:
  94. last_number_gen = time_str + '000001'
  95. return last_number_gen
  96. def save_raw_image_with_paper_id(subject, paper_id, img_file, analysis_type):
  97. time_str = time.strftime('%Y-%m-%d', time.localtime(time.time()))
  98. # 随机生成新的图片名,自定义路径。
  99. ext = img_file.name.split('.')[-1]
  100. # raw_name = img_file.name[0:-len(ext) - 1]
  101. # file_name = '{}_{}.{}'.format(raw_name, uuid.uuid4().hex[:10], 'jpg')
  102. file_name = '{}.{}'.format(paper_id, ext)
  103. raw_img = Image.open(img_file) # 读取上传的网络图像
  104. save_dir = os.path.join(settings.MEDIA_ROOT, analysis_type, subject, time_str)
  105. if not os.path.exists(save_dir):
  106. os.makedirs(save_dir)
  107. save_path = os.path.join(save_dir, file_name)
  108. pil_img, open_cv_image = convert_pil_to_jpeg(raw_img)
  109. try:
  110. pil_img.save(save_path)
  111. except Exception as e:
  112. raise e
  113. url_path = os.path.join(settings.MEDIA_URL, analysis_type, subject, time_str, file_name).replace('\\', '/')
  114. return save_path, open_cv_image, url_path
  115. def save_raw_image_without_segment(subject, datetime, img_file, analysis_type):
  116. # 随机生成新的图片名,自定义路径。
  117. ext = img_file.name.split('.')[-1]
  118. raw_name = img_file.name[0:-len(ext) - 1]
  119. file_name = '{}_{}.{}'.format(raw_name, uuid.uuid4().hex[:10], 'jpg')
  120. raw_img = Image.open(img_file) # 读取上传的网络图像
  121. save_dir = os.path.join(settings.MEDIA_ROOT, analysis_type, subject, datetime)
  122. if not os.path.exists(save_dir):
  123. os.makedirs(save_dir)
  124. save_path = os.path.join(save_dir, file_name)
  125. pil_img, open_cv_image = convert_pil_to_jpeg(raw_img)
  126. try:
  127. pil_img.save(save_path)
  128. shutil.copy(save_path, save_path.replace('.jpg', '_small.jpg'))
  129. except Exception as e:
  130. raise e
  131. url_path = os.path.join(settings.MEDIA_URL, analysis_type, subject, datetime, file_name).replace('\\', '/')
  132. return save_path, open_cv_image, url_path
  133. def sheet_big_boxes_resolve(series_number, image, saved_path, subject, sheet_sess, ocr=''):
  134. status = 1
  135. conf_thresh_0 = 0.7
  136. mns_thresh_0 = 0.3
  137. sheets_dict_0 = ''
  138. xml_save_path = ''
  139. try:
  140. sheets_dict_0, xml_save_path = sheet(series_number, saved_path, image,
  141. conf_thresh_0, mns_thresh_0, subject, sheet_sess, ocr)
  142. except Exception as e:
  143. status = 0
  144. logger.info('试卷:{} 答题卡区域解析失败: {}'.format(saved_path, e))
  145. return status, sheets_dict_0, xml_save_path
  146. def sheet_small_boxes_resolve(raw_img, sheet_dict, choice_sess, cloze_sess, xml_save_path):
  147. conf_thresh_0 = 0.7
  148. mns_thresh_0 = 0.3
  149. regions = sheet_dict['regions']
  150. classes_name = str([ele['class_name'] for ele in regions])
  151. sheet_dict.pop('regions')
  152. json.dumps(sheet_dict, ensure_ascii=False)
  153. if 'choice' in classes_name:
  154. try:
  155. sheet_dict['choice'] = choice(raw_img, regions, xml_save_path, conf_thresh_0, mns_thresh_0, choice_sess)
  156. except Exception as e:
  157. traceback.print_exc()
  158. logger.info('试卷:{} 选择题区域解析失败: {}'.format(xml_save_path.replace('xml', '.jpg'), e))
  159. if 'exam_number' in classes_name:
  160. try:
  161. # sheet_dict['exam_number'] = exam_number(raw_img, regions, xml_save_path)
  162. sheet_dict['exam_number'] = exam_number_row_col(raw_img, regions, xml_save_path)
  163. except Exception as e:
  164. traceback.print_exc()
  165. logger.info('试卷:{} 考号区域解析失败: {}'.format(xml_save_path.replace('xml', '.jpg'), e))
  166. if 'cloze' in classes_name:
  167. try:
  168. sheet_dict['cloze'] = cloze(raw_img, regions, xml_save_path, conf_thresh_0, mns_thresh_0, cloze_sess)
  169. except Exception as e:
  170. traceback.print_exc()
  171. logger.info('试卷:{} 填空题区域解析失败: {}'.format(xml_save_path.replace('xml', '.jpg'), e))
  172. if 'solve' in classes_name:
  173. try:
  174. # solve_list, mark_list = solve(raw_img, regions, xml_save_path,)
  175. # sheet_dict['solve'] = solve_list
  176. # sheet_dict['mark'] = mark_list
  177. sheet_dict['solve'] = solve(raw_img, regions, xml_save_path)
  178. except Exception as e:
  179. traceback.print_exc()
  180. logger.info('试卷:{} 解答题区域解析失败: {}'.format(xml_save_path.replace('xml', '.jpg'), e))
  181. if 'qr_code' in classes_name:
  182. try:
  183. for ele in regions:
  184. if 'qr_code' == ele['class_name']:
  185. sheet_dict['qr_code'] = ele['bounding_box']
  186. except Exception as e:
  187. traceback.print_exc()
  188. logger.info('试卷:{} 二维码区域解析失败: {}'.format(xml_save_path.replace('xml', '.jpg'), e))
  189. if 'bar_code' in classes_name:
  190. try:
  191. for ele in regions:
  192. if 'bar_code' == ele['class_name']:
  193. sheet_dict['bar_code'] = ele['bounding_box']
  194. except Exception as e:
  195. traceback.print_exc()
  196. logger.info('试卷:{} 条形码区域解析失败: {}'.format(xml_save_path.replace('xml', '.jpg'), e))
  197. return sheet_dict
  198. def sheet_row_col_resolve(raw_img, sheet_dict, choice_sess, cloze_sess, xml_save_path):
  199. conf_thresh_0 = 0.7
  200. mns_thresh_0 = 0.3
  201. regions = sheet_dict['regions']
  202. classes_name = str([ele['class_name'] for ele in regions])
  203. region_tmp = regions.copy()
  204. # json.dumps(sheet_dict, ensure_ascii=False)
  205. # if 'choice' in classes_name:
  206. # try:
  207. # # sheet_dict['choice'] = choice(raw_img, regions, xml_save_path, conf_thresh_0, mns_thresh_0, choice_sess)
  208. # choice_dict_list = choice_row_col(raw_img, regions, xml_save_path, conf_thresh_0, mns_thresh_0, choice_sess)
  209. # if len(choice_dict_list) > 0:
  210. # region_tmp.extend(choice_dict_list)
  211. # except Exception as e:
  212. # traceback.print_exc()
  213. # logger.info('试卷:{} 选择题区域解析失败: {}'.format(xml_save_path.replace('xml', '.jpg'), e))
  214. if 'choice_m' in classes_name:
  215. try:
  216. choice_dict_list = choice_m_row_col(raw_img, regions, xml_save_path)
  217. if len(choice_dict_list) > 0:
  218. region_tmp.extend(choice_dict_list)
  219. except Exception as e:
  220. traceback.print_exc()
  221. logger.info('试卷:{} 选择题区域解析失败: {}'.format(xml_save_path.replace('xml', '.jpg'), e))
  222. if 'exam_number' in classes_name:
  223. try:
  224. # sheet_dict['exam_number'] = exam_number(raw_img, regions, xml_save_path)
  225. exam_number_dict_list = exam_number_row_col(raw_img, regions, xml_save_path)
  226. if len(exam_number_dict_list) > 0:
  227. region_tmp.extend(exam_number_dict_list)
  228. except Exception as e:
  229. traceback.print_exc()
  230. logger.info('试卷:{} 考号区域解析失败: {}'.format(xml_save_path.replace('xml', '.jpg'), e))
  231. if 'cloze' in classes_name:
  232. try:
  233. cloze_dict_list = cloze(raw_img, regions, xml_save_path, conf_thresh_0, mns_thresh_0, cloze_sess)
  234. if len(cloze_dict_list) > 0:
  235. region_tmp.extend(cloze_dict_list)
  236. except Exception as e:
  237. traceback.print_exc()
  238. logger.info('试卷:{} 填空题区域解析失败: {}'.format(xml_save_path.replace('xml', '.jpg'), e))
  239. sheet_dict.update({'regions': region_tmp})
  240. return sheet_dict
  241. def sheet_detail_resolve(raw_img, sheet_dict, xml_save_path, shrink=True):
  242. regions = sheet_dict['regions']
  243. classes_names_list = set([ele['class_name'] for ele in regions])
  244. region_tmp = regions.copy()
  245. # json.dumps(sheet_dict, ensure_ascii=False)
  246. if 'choice_m' in classes_names_list:
  247. try:
  248. choice_dict_list = choice_m_row_col(raw_img, regions, xml_save_path)
  249. if shrink:
  250. for ele in choice_dict_list:
  251. if 'all_small_coordinate' in ele.keys():
  252. ele.pop('all_small_coordinate')
  253. region_tmp = [ele for ele in region_tmp if ele['class_name'] != 'choice_m'] # 重名
  254. if len(choice_dict_list) > 0:
  255. region_tmp.extend(choice_dict_list)
  256. except Exception as e:
  257. traceback.print_exc()
  258. logger.info('试卷:{} 选择题区域解析失败: {}'.format(xml_save_path.replace('xml', '.jpg'), e))
  259. if 'exam_number' in classes_names_list:
  260. try:
  261. exam_number_dict_list = exam_number_row_col(raw_img, regions, xml_save_path)
  262. for ele in exam_number_dict_list:
  263. ele.pop('all_small_coordinate')
  264. if len(exam_number_dict_list) > 0:
  265. region_tmp.extend(exam_number_dict_list)
  266. except Exception as e:
  267. traceback.print_exc()
  268. logger.info('试卷:{} 考号区域解析失败: {}'.format(xml_save_path.replace('xml', '.jpg'), e))
  269. if 'solve' or 'solve0' or 'composition' or 'composition0' or 'correction' in classes_names_list:
  270. try:
  271. solve_number = solve_with_number(region_tmp, xml_save_path)
  272. region_tmp = [ele for ele in region_tmp if 'solve' not in ele['class_name']] # 重名
  273. region_tmp = [ele for ele in region_tmp if 'composition' not in ele['class_name']]
  274. region_tmp = [ele for ele in region_tmp if 'correction' not in ele['class_name']]
  275. if len(solve_number) > 0:
  276. region_tmp.extend(solve_number)
  277. except Exception as e:
  278. traceback.print_exc()
  279. logger.info('试卷:{} 解答题区域解析失败: {}'.format(xml_save_path.replace('xml', '.jpg'), e))
  280. if 'cloze' in classes_names_list or 'cloze_s' in classes_names_list:
  281. try:
  282. cloze_number = cloze_with_number(region_tmp, xml_save_path)
  283. region_tmp = [ele for ele in region_tmp if 'cloze' not in ele['class_name']] # 重名
  284. if len(cloze_number) > 0:
  285. region_tmp.extend(cloze_number)
  286. except Exception as e:
  287. traceback.print_exc()
  288. logger.info('试卷:{} 解答题区域解析失败: {}'.format(xml_save_path.replace('xml', '.jpg'), e))
  289. sheet_dict.update({'regions': region_tmp})
  290. return sheet_dict
  291. def sheet_points(sheet_dict_list, image_list, ocr_list, if_ocr=False):
  292. sheet_list = []
  293. for index, ele in enumerate(sheet_dict_list):
  294. ocr_res = ocr_list[index]
  295. h, w = image_list[index].shape[0], image_list[index].shape[1]
  296. sheet_dict = {'sheet_dict': sheet_dict_list[index], 'ocr': ocr_res, 'shape': (h, w), 'raw_image': image_list[index]}
  297. sheet_list.append(sheet_dict)
  298. try:
  299. res = get_sheet_points(sheet_list)
  300. sheet_dict_list = [ele['sheet_dict'] for ele in res]
  301. except Exception as e:
  302. traceback.print_exc()
  303. sheet_dict_list = [ele['sheet_dict'] for ele in sheet_list]
  304. try:
  305. sheet_total_list = []
  306. for index, ele in enumerate(sheet_dict_list):
  307. ocr_res = change_format_baidu_to_google(ocr_list[index])
  308. sheet_dict = get_sheet_number_total(ele, ocr_res, image_list[index])
  309. regions_list = sheet_dict['regions']
  310. type_score_ocr = [ele for ele in regions_list if 'type_score_ocr' in ele]
  311. if len(type_score_ocr) == 0:
  312. sheet_total_list.append(sheet_dict)
  313. else:
  314. sheet_dict0 = get_sheet_points_by_nlp(sheet_dict)
  315. sheet_total_list.append(sheet_dict0)
  316. except Exception as e:
  317. traceback.print_exc()
  318. sheet_total_list = sheet_dict_list
  319. if if_ocr:
  320. for index, ele in enumerate(sheet_total_list):
  321. ele.update({'sheet_ocr': ocr_list[index]})
  322. return sheet_total_list
  323. def sheet_format_output(init_number, crt_numbers, sheet_dict, image, subject, shrink):
  324. # 去除无用的class、改名、加选做
  325. sheet_dict = box_region_format(sheet_dict, image, subject, shrink)
  326. # 排序
  327. col_regions_list = sheet_sorted(sheet_dict["regions"], sheet_dict["col_split"])
  328. # 改题号
  329. for col_regions in col_regions_list:
  330. _, init_number, crt_numbers = question_number_format(init_number, crt_numbers, col_regions)
  331. merge_span_boxes(col_regions_list)
  332. regions = list(itertools.chain(*col_regions_list))
  333. for i, box in enumerate(regions, 1):
  334. box.update({'sort_id': i})
  335. sheet_dict.update({"regions": regions})
  336. return sheet_dict, init_number, crt_numbers
  337. def sheet_anchor(image):
  338. anchor_list = find_anchor(image)
  339. return anchor_list
  340. def gen_xml(sheet_region_dict, xml_path):
  341. tree = ET.parse(xml_template_path)
  342. for index_num, box in enumerate(sheet_region_dict):
  343. if len(box['bounding_box']) > 0:
  344. abcd = box['bounding_box']
  345. name = box["class_name"]
  346. box_tmp = box.copy()
  347. box_tmp.pop('bounding_box')
  348. box_tmp.pop('class_name')
  349. info = str(box_tmp)
  350. name = '{}_{}'.format(name, info)
  351. tree = utils.create_xml(name, tree,
  352. abcd['xmin'], abcd['ymin'],
  353. abcd['xmax'], abcd['ymax'])
  354. tree.write(xml_path)