sheet_server.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399
  1. # @Author : lightXu
  2. # @File : sheet_server.py
  3. # @Time : 2018/12/19 0019 下午 14:33
import json
import os
import shutil
import time
import traceback
import uuid

try:
    import xml.etree.cElementTree as ET
except ImportError:  # cElementTree was removed in Python 3.9; ElementTree is the drop-in replacement.
    import xml.etree.ElementTree as ET

import cv2
import numpy as np
from PIL import Image
from django.conf import settings

from segment.sheet_resolve.tools.tf_settings import xml_template_path
import segment.logging_config as logging
from segment.sheet_resolve.analysis.anchor.marker_detection import find_anchor
from segment.sheet_resolve.analysis.resolve import choice, choice_m_row_col
from segment.sheet_resolve.analysis.resolve import cloze
from segment.sheet_resolve.analysis.resolve import exam_number_row_col
from segment.sheet_resolve.analysis.resolve import sheet
from segment.sheet_resolve.analysis.resolve import solve, solve_with_number, cloze_with_number
from segment.sheet_resolve.analysis.sheet.analysis_sheet import box_region_format, question_number_format
from segment.sheet_resolve.analysis.sheet.sheet_points import get_sheet_points
from segment.sheet_resolve.analysis.sheet.sheet_points_total import get_sheet_number_total
from segment.sheet_resolve.tools import utils
from segment.sheet_resolve.tools.brain_api import get_ocr_text_and_coordinate, change_format_baidu_to_google
from segment.sheet_resolve.analysis.sheet.sheet_points_by_nlp import get_sheet_points_by_nlp
# Module-wide logger. Note that `logging` in this module is the project's
# segment.logging_config (imported under that alias), and the logger name is
# read from the Django setting LOGGING_TYPE.
logger = logging.getLogger(settings.LOGGING_TYPE)
  30. def decide_blank_sheet(image):
  31. if len(image.shape) <= 2:
  32. gray_image = image
  33. else:
  34. gray_image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
  35. height = gray_image.shape[0]
  36. width = gray_image.shape[1]
  37. if max(height, width) > 800:
  38. percent = max(height, width) / 800
  39. new_x = int(width * percent)
  40. new_y = int(height * percent)
  41. gray_image = cv2.resize(gray_image, (new_x, new_y), interpolation=cv2.INTER_AREA)
  42. if height > width: # 纵向
  43. image = gray_image[height//2:, :]
  44. PIX_VALUE_LOW = 25.0 # 二进制参数
  45. PIX_VALUE_HIGH = 220 # 原始图像参数
  46. else: # 横向
  47. image = gray_image[:, width // 2:]
  48. PIX_VALUE_LOW = 15.0
  49. PIX_VALUE_HIGH = 250
  50. bin_img = cv2.threshold(image, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)[1]
  51. bin_img_mean = np.mean(bin_img)
  52. img_raw_mean = np.mean(image)
  53. print(bin_img_mean, img_raw_mean)
  54. blank_cond = bin_img_mean < PIX_VALUE_LOW or img_raw_mean > PIX_VALUE_HIGH
  55. return blank_cond
  56. def convert_pil_to_jpeg(raw_img):
  57. if raw_img.mode == 'L':
  58. channels = raw_img.split()
  59. img = Image.merge("RGB", (channels[0], channels[0], channels[0]))
  60. elif raw_img.mode == 'RGB':
  61. img = raw_img
  62. elif raw_img.mode == 'RGBA':
  63. img = Image.new("RGB", raw_img.size, (255, 255, 255))
  64. img.paste(raw_img, mask=raw_img.split()[3]) # 3 is the alpha channel
  65. else:
  66. img = raw_img
  67. open_cv_image = np.array(img)
  68. return img, open_cv_image
  69. def handle_uploaded_xml_file(f, save_path):
  70. with open(save_path, 'wb+') as destination:
  71. for chunk in f.chunks():
  72. destination.write(chunk)
  73. def generate_serial_number(time_str, sheet_big_boxes):
  74. if len(sheet_big_boxes.objects.all()) < 1:
  75. last_number_gen = time_str + '000001'
  76. else:
  77. objects = sheet_big_boxes.objects.latest('update_time')
  78. last_number_in_db = objects.series_number
  79. if time_str in last_number_in_db[0:9]:
  80. last_number_gen = str(int(last_number_in_db) + 1)
  81. else:
  82. last_number_gen = time_str + '000001'
  83. return last_number_gen
  84. def save_raw_image_with_paper_id(subject, paper_id, img_file, analysis_type):
  85. time_str = time.strftime('%Y-%m-%d', time.localtime(time.time()))
  86. # 随机生成新的图片名,自定义路径。
  87. ext = img_file.name.split('.')[-1]
  88. # raw_name = img_file.name[0:-len(ext) - 1]
  89. # file_name = '{}_{}.{}'.format(raw_name, uuid.uuid4().hex[:10], 'jpg')
  90. file_name = '{}.{}'.format(paper_id, ext)
  91. raw_img = Image.open(img_file) # 读取上传的网络图像
  92. save_dir = os.path.join(settings.MEDIA_ROOT, analysis_type, subject, time_str)
  93. if not os.path.exists(save_dir):
  94. os.makedirs(save_dir)
  95. save_path = os.path.join(save_dir, file_name)
  96. pil_img, open_cv_image = convert_pil_to_jpeg(raw_img)
  97. try:
  98. pil_img.save(save_path)
  99. except Exception as e:
  100. raise e
  101. url_path = os.path.join(settings.MEDIA_URL, analysis_type, subject, time_str, file_name).replace('\\', '/')
  102. return save_path, open_cv_image, url_path
  103. def save_raw_image_without_segment(subject, datetime, img_file, analysis_type):
  104. # 随机生成新的图片名,自定义路径。
  105. ext = img_file.name.split('.')[-1]
  106. raw_name = img_file.name[0:-len(ext) - 1]
  107. file_name = '{}_{}.{}'.format(raw_name, uuid.uuid4().hex[:10], 'jpg')
  108. raw_img = Image.open(img_file) # 读取上传的网络图像
  109. save_dir = os.path.join(settings.MEDIA_ROOT, analysis_type, subject, datetime)
  110. if not os.path.exists(save_dir):
  111. os.makedirs(save_dir)
  112. save_path = os.path.join(save_dir, file_name)
  113. pil_img, open_cv_image = convert_pil_to_jpeg(raw_img)
  114. try:
  115. pil_img.save(save_path)
  116. shutil.copy(save_path, save_path.replace('.jpg', '_small.jpg'))
  117. except Exception as e:
  118. raise e
  119. url_path = os.path.join(settings.MEDIA_URL, analysis_type, subject, datetime, file_name).replace('\\', '/')
  120. return save_path, open_cv_image, url_path
  121. def sheet_big_boxes_resolve(series_number, image, saved_path, subject, sheet_sess, ocr=''):
  122. status = 1
  123. conf_thresh_0 = 0.7
  124. mns_thresh_0 = 0.3
  125. sheets_dict_0 = ''
  126. xml_save_path = ''
  127. try:
  128. sheets_dict_0, xml_save_path = sheet(series_number, saved_path, image,
  129. conf_thresh_0, mns_thresh_0, subject, sheet_sess, ocr)
  130. except Exception as e:
  131. status = 0
  132. logger.info('试卷:{} 答题卡区域解析失败: {}'.format(saved_path, e))
  133. return status, sheets_dict_0, xml_save_path
  134. def sheet_small_boxes_resolve(raw_img, sheet_dict, choice_sess, cloze_sess, xml_save_path):
  135. conf_thresh_0 = 0.7
  136. mns_thresh_0 = 0.3
  137. regions = sheet_dict['regions']
  138. classes_name = str([ele['class_name'] for ele in regions])
  139. sheet_dict.pop('regions')
  140. json.dumps(sheet_dict, ensure_ascii=False)
  141. if 'choice' in classes_name:
  142. try:
  143. sheet_dict['choice'] = choice(raw_img, regions, xml_save_path, conf_thresh_0, mns_thresh_0, choice_sess)
  144. except Exception as e:
  145. traceback.print_exc()
  146. logger.info('试卷:{} 选择题区域解析失败: {}'.format(xml_save_path.replace('xml', '.jpg'), e))
  147. if 'exam_number' in classes_name:
  148. try:
  149. # sheet_dict['exam_number'] = exam_number(raw_img, regions, xml_save_path)
  150. sheet_dict['exam_number'] = exam_number_row_col(raw_img, regions, xml_save_path)
  151. except Exception as e:
  152. traceback.print_exc()
  153. logger.info('试卷:{} 考号区域解析失败: {}'.format(xml_save_path.replace('xml', '.jpg'), e))
  154. if 'cloze' in classes_name:
  155. try:
  156. sheet_dict['cloze'] = cloze(raw_img, regions, xml_save_path, conf_thresh_0, mns_thresh_0, cloze_sess)
  157. except Exception as e:
  158. traceback.print_exc()
  159. logger.info('试卷:{} 填空题区域解析失败: {}'.format(xml_save_path.replace('xml', '.jpg'), e))
  160. if 'solve' in classes_name:
  161. try:
  162. # solve_list, mark_list = solve(raw_img, regions, xml_save_path,)
  163. # sheet_dict['solve'] = solve_list
  164. # sheet_dict['mark'] = mark_list
  165. sheet_dict['solve'] = solve(raw_img, regions, xml_save_path)
  166. except Exception as e:
  167. traceback.print_exc()
  168. logger.info('试卷:{} 解答题区域解析失败: {}'.format(xml_save_path.replace('xml', '.jpg'), e))
  169. if 'qr_code' in classes_name:
  170. try:
  171. for ele in regions:
  172. if 'qr_code' == ele['class_name']:
  173. sheet_dict['qr_code'] = ele['bounding_box']
  174. except Exception as e:
  175. traceback.print_exc()
  176. logger.info('试卷:{} 二维码区域解析失败: {}'.format(xml_save_path.replace('xml', '.jpg'), e))
  177. if 'bar_code' in classes_name:
  178. try:
  179. for ele in regions:
  180. if 'bar_code' == ele['class_name']:
  181. sheet_dict['bar_code'] = ele['bounding_box']
  182. except Exception as e:
  183. traceback.print_exc()
  184. logger.info('试卷:{} 条形码区域解析失败: {}'.format(xml_save_path.replace('xml', '.jpg'), e))
  185. return sheet_dict
  186. def sheet_row_col_resolve(raw_img, sheet_dict, choice_sess, cloze_sess, xml_save_path):
  187. conf_thresh_0 = 0.7
  188. mns_thresh_0 = 0.3
  189. regions = sheet_dict['regions']
  190. classes_name = str([ele['class_name'] for ele in regions])
  191. region_tmp = regions.copy()
  192. # json.dumps(sheet_dict, ensure_ascii=False)
  193. # if 'choice' in classes_name:
  194. # try:
  195. # # sheet_dict['choice'] = choice(raw_img, regions, xml_save_path, conf_thresh_0, mns_thresh_0, choice_sess)
  196. # choice_dict_list = choice_row_col(raw_img, regions, xml_save_path, conf_thresh_0, mns_thresh_0, choice_sess)
  197. # if len(choice_dict_list) > 0:
  198. # region_tmp.extend(choice_dict_list)
  199. # except Exception as e:
  200. # traceback.print_exc()
  201. # logger.info('试卷:{} 选择题区域解析失败: {}'.format(xml_save_path.replace('xml', '.jpg'), e))
  202. if 'choice_m' in classes_name:
  203. try:
  204. choice_dict_list = choice_m_row_col(raw_img, regions, xml_save_path)
  205. if len(choice_dict_list) > 0:
  206. region_tmp.extend(choice_dict_list)
  207. except Exception as e:
  208. traceback.print_exc()
  209. logger.info('试卷:{} 选择题区域解析失败: {}'.format(xml_save_path.replace('xml', '.jpg'), e))
  210. if 'exam_number' in classes_name:
  211. try:
  212. # sheet_dict['exam_number'] = exam_number(raw_img, regions, xml_save_path)
  213. exam_number_dict_list = exam_number_row_col(raw_img, regions, xml_save_path)
  214. if len(exam_number_dict_list) > 0:
  215. region_tmp.extend(exam_number_dict_list)
  216. except Exception as e:
  217. traceback.print_exc()
  218. logger.info('试卷:{} 考号区域解析失败: {}'.format(xml_save_path.replace('xml', '.jpg'), e))
  219. if 'cloze' in classes_name:
  220. try:
  221. cloze_dict_list = cloze(raw_img, regions, xml_save_path, conf_thresh_0, mns_thresh_0, cloze_sess)
  222. if len(cloze_dict_list) > 0:
  223. region_tmp.extend(cloze_dict_list)
  224. except Exception as e:
  225. traceback.print_exc()
  226. logger.info('试卷:{} 填空题区域解析失败: {}'.format(xml_save_path.replace('xml', '.jpg'), e))
  227. sheet_dict.update({'regions': region_tmp})
  228. return sheet_dict
  229. def sheet_detail_resolve(raw_img, sheet_dict, xml_save_path, shrink=True):
  230. regions = sheet_dict['regions']
  231. classes_names_list = set([ele['class_name'] for ele in regions])
  232. region_tmp = regions.copy()
  233. # json.dumps(sheet_dict, ensure_ascii=False)
  234. if 'choice_m' in classes_names_list:
  235. try:
  236. choice_dict_list = choice_m_row_col(raw_img, regions, xml_save_path)
  237. if shrink:
  238. for ele in choice_dict_list:
  239. if 'all_small_coordinate' in ele.keys():
  240. ele.pop('all_small_coordinate')
  241. region_tmp = [ele for ele in region_tmp if ele['class_name'] != 'choice_m'] # 重名
  242. if len(choice_dict_list) > 0:
  243. region_tmp.extend(choice_dict_list)
  244. except Exception as e:
  245. traceback.print_exc()
  246. logger.info('试卷:{} 选择题区域解析失败: {}'.format(xml_save_path.replace('xml', '.jpg'), e))
  247. if 'exam_number' in classes_names_list:
  248. try:
  249. exam_number_dict_list = exam_number_row_col(raw_img, regions, xml_save_path)
  250. for ele in exam_number_dict_list:
  251. ele.pop('all_small_coordinate')
  252. if len(exam_number_dict_list) > 0:
  253. region_tmp.extend(exam_number_dict_list)
  254. except Exception as e:
  255. traceback.print_exc()
  256. logger.info('试卷:{} 考号区域解析失败: {}'.format(xml_save_path.replace('xml', '.jpg'), e))
  257. if 'solve' or 'solve0' or 'composition' or 'composition0' in classes_names_list:
  258. try:
  259. solve_number = solve_with_number(region_tmp, xml_save_path)
  260. region_tmp = [ele for ele in region_tmp if 'solve' not in ele['class_name']] # 重名
  261. region_tmp = [ele for ele in region_tmp if 'composition' not in ele['class_name']]
  262. if len(solve_number) > 0:
  263. region_tmp.extend(solve_number)
  264. except Exception as e:
  265. traceback.print_exc()
  266. logger.info('试卷:{} 解答题区域解析失败: {}'.format(xml_save_path.replace('xml', '.jpg'), e))
  267. if 'cloze' in classes_names_list or 'cloze_s' in classes_names_list:
  268. try:
  269. cloze_number = cloze_with_number(region_tmp, xml_save_path)
  270. region_tmp = [ele for ele in region_tmp if 'cloze' not in ele['class_name']] # 重名
  271. if len(cloze_number) > 0:
  272. region_tmp.extend(cloze_number)
  273. except Exception as e:
  274. traceback.print_exc()
  275. logger.info('试卷:{} 解答题区域解析失败: {}'.format(xml_save_path.replace('xml', '.jpg'), e))
  276. sheet_dict.update({'regions': region_tmp})
  277. return sheet_dict
  278. def sheet_points(sheet_dict_list, image_list, ocr_list, if_ocr=False):
  279. sheet_list = []
  280. for index, ele in enumerate(sheet_dict_list):
  281. ocr_res = ocr_list[index]
  282. h, w = image_list[index].shape[0], image_list[index].shape[1]
  283. sheet_dict = {'sheet_dict': sheet_dict_list[index], 'ocr': ocr_res, 'shape': (h, w), 'raw_image': image_list[index]}
  284. sheet_list.append(sheet_dict)
  285. try:
  286. res = get_sheet_points(sheet_list)
  287. sheet_dict_list = [ele['sheet_dict'] for ele in res]
  288. except Exception as e:
  289. traceback.print_exc()
  290. sheet_dict_list = [ele['sheet_dict'] for ele in sheet_list]
  291. try:
  292. sheet_total_list = []
  293. for index, ele in enumerate(sheet_dict_list):
  294. ocr_res = change_format_baidu_to_google(ocr_list[index])
  295. sheet_dict = get_sheet_number_total(ele, ocr_res, image_list[index])
  296. regions_list = sheet_dict['regions']
  297. type_score_ocr = [ele for ele in regions_list if 'type_score_ocr' in ele]
  298. if len(type_score_ocr) == 0:
  299. sheet_total_list.append(sheet_dict)
  300. else:
  301. sheet_dict0 = get_sheet_points_by_nlp(sheet_dict)
  302. sheet_total_list.append(sheet_dict0)
  303. except Exception as e:
  304. traceback.print_exc()
  305. sheet_total_list = sheet_dict_list
  306. if if_ocr:
  307. for index, ele in enumerate(sheet_total_list):
  308. ele.update({'sheet_ocr': ocr_list[index]})
  309. return sheet_total_list
  310. def sheet_format_output(init_number, crt_numbers, sheet_dict, image, subject, shrink):
  311. sheet_dict = box_region_format(sheet_dict, image, subject, shrink)
  312. sheet_dict, init_number, crt_numbers = question_number_format(init_number, crt_numbers, sheet_dict)
  313. return sheet_dict, init_number, crt_numbers
  314. def sheet_anchor(image):
  315. anchor_list = find_anchor(image)
  316. return anchor_list
  317. def gen_xml(sheet_region_dict, xml_path):
  318. tree = ET.parse(xml_template_path)
  319. for index_num, box in enumerate(sheet_region_dict):
  320. if len(box['bounding_box']) > 0:
  321. abcd = box['bounding_box']
  322. name = box["class_name"]
  323. box_tmp = box.copy()
  324. box_tmp.pop('bounding_box')
  325. box_tmp.pop('class_name')
  326. info = str(box_tmp)
  327. name = '{}_{}'.format(name, info)
  328. tree = utils.create_xml(name, tree,
  329. abcd['xmin'], abcd['ymin'],
  330. abcd['xmax'], abcd['ymax'])
  331. tree.write(xml_path)