brain_api.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382
  1. # @Author : lightXu
  2. # @File : brain_api.py
  3. # @Time : 2018/11/21 0021 下午 16:20
import base64
import os
import shutil
import time
from urllib import parse, request

try:
    import xml.etree.cElementTree as ET
except ImportError:  # cElementTree was removed in Python 3.9
    import xml.etree.ElementTree as ET

import cv2
import numpy as np
import pytesseract
import requests

from segment.server import ocr_login
from segment.sheet_resolve.tools import utils
  15. # access_token = '24.82b09618f94abe2a35113177f4eec593.2592000.1546765941.282335-14614857'
  16. access_token = ocr_login()
  17. OCR_BOX_URL = 'https://aip.baidubce.com/rest/2.0/ocr/v1/'
  18. OCR_URL = 'https://aip.baidubce.com/rest/2.0/ocr/v1/'
  19. OCR_HAND_URL = 'https://aip.baidubce.com/rest/2.0/ocr/v1/handwriting'
  20. # OCR_ACCURACY = 'general'
  21. OCR_ACCURACY = 'accurate'
  22. OCR_CLIENT_ID = 'AVH7VGKG8QxoSotp6wG9LyZq'
  23. OCR_CLIENT_SECRET = 'gG7VYvBWLU8Rusnin8cS8Ta4dOckGFl6'
  24. OCR_TOKEN_UPDATE_DATE = 10
  25. def preprocess(img):
  26. scale = 0
  27. dilate = 1
  28. blur = 3
  29. # rescale the image
  30. if scale != 0:
  31. img = cv2.resize(img, None, fx=scale, fy=scale, interpolation=cv2.INTER_CUBIC)
  32. # Convert to gray
  33. img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
  34. # Apply dilation and erosion to remove some noise
  35. if dilate != 0:
  36. kernel = np.ones((dilate, dilate), np.uint8)
  37. img = cv2.dilate(img, kernel, iterations=1)
  38. img = cv2.erode(img, kernel, iterations=1)
  39. # Apply blur to smooth out the edges
  40. if blur != 0:
  41. img = cv2.GaussianBlur(img, (blur, blur), 0)
  42. # Apply threshold to get image with only b&w (binarization)
  43. img = cv2.threshold(img, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)[1]
  44. return img
  45. def opecv2base64(img):
  46. image = cv2.imencode('.jpg', img)[1]
  47. base64_data = str(base64.b64encode(image))[2:-1]
  48. return base64_data
  49. def get_ocr_raw_result(img, ocr_accuracy=OCR_ACCURACY, language_type='CHN_ENG'):
  50. textmod = {'access_token': access_token}
  51. textmod = parse.urlencode(textmod)
  52. url = '{}{}{}{}'.format(OCR_BOX_URL, ocr_accuracy, '?', textmod)
  53. url_general = '{}{}{}{}'.format(OCR_BOX_URL, 'general', '?', textmod)
  54. headers = {'Content-Type': 'application/x-www-form-urlencoded'}
  55. image_type = 'base64'
  56. group_id = 'group001'
  57. user_id = 'usr001'
  58. image = opecv2base64(img)
  59. data = {
  60. 'image_type': image_type,
  61. 'group_id': group_id,
  62. 'user_id': user_id,
  63. 'image': image,
  64. 'detect_direction': 'true',
  65. 'recognize_granularity': 'small',
  66. 'language_type': language_type,
  67. # 'vertexes_location': 'true',
  68. # 'probability': 'true'
  69. }
  70. resp = requests.post(url, data=data, headers=headers, timeout=15).json()
  71. if resp.get('error_msg'):
  72. if 'internal error' in resp.get('error_msg'):
  73. resp = requests.post(url_general, data=data, headers=headers).json()
  74. if resp.get('error_msg'):
  75. raise Exception("ocr {}!".format(resp.get('error_msg')))
  76. else:
  77. raise Exception("ocr {}!".format(resp.get('error_msg')))
  78. return resp
  79. def get_ocr_text_and_coordinate(img, ocr_accuracy=OCR_ACCURACY, language_type='CHN_ENG'):
  80. textmod = {'access_token': access_token}
  81. textmod = parse.urlencode(textmod)
  82. url = '{}{}{}{}'.format(OCR_BOX_URL, ocr_accuracy, '?', textmod)
  83. url_general = '{}{}{}{}'.format(OCR_BOX_URL, 'general', '?', textmod)
  84. headers = {'Content-Type': 'application/x-www-form-urlencoded'}
  85. image_type = 'base64'
  86. group_id = 'group001'
  87. user_id = 'usr001'
  88. image = opecv2base64(img)
  89. data = {
  90. 'image_type': image_type,
  91. 'group_id': group_id,
  92. 'user_id': user_id,
  93. 'image': image,
  94. # 'detect_direction': 'true',
  95. 'recognize_granularity': 'small',
  96. 'language_type': language_type,
  97. # 'vertexes_location': 'true',
  98. # 'probability': 'true'
  99. }
  100. # resp = requests.post(url, data=data, headers=headers, timeout=15).json()
  101. resp = requests.post(url, data=data, headers=headers).json()
  102. if resp.get('error_msg'):
  103. if 'internal error' in resp.get('error_msg'):
  104. resp = requests.post(url_general, data=data, headers=headers).json()
  105. if resp.get('error_msg'):
  106. raise Exception("ocr {}!".format(resp.get('error_msg')))
  107. else:
  108. raise Exception("ocr {}!".format(resp.get('error_msg')))
  109. words_result = resp.get('words_result')
  110. return words_result
  111. def get_ocr_text_and_coordinate0(img, ocr_accuracy=OCR_ACCURACY, language_type='CHN_ENG'):
  112. textmod = {'access_token': access_token}
  113. textmod = parse.urlencode(textmod)
  114. url = '{}{}{}{}'.format(OCR_BOX_URL, ocr_accuracy, '?', textmod)
  115. url_general = '{}{}{}{}'.format(OCR_BOX_URL, 'general', '?', textmod)
  116. headers = {'Content-Type': 'application/x-www-form-urlencoded'}
  117. image_type = 'base64'
  118. group_id = 'group001'
  119. user_id = 'usr001'
  120. image = opecv2base64(img)
  121. data = {
  122. 'image_type': image_type,
  123. 'group_id': group_id,
  124. 'user_id': user_id,
  125. 'image': image,
  126. 'detect_direction': 'false',
  127. 'recognize_granularity': 'small',
  128. 'language_type': language_type,
  129. # 'vertexes_location': 'true',
  130. # 'probability': 'true'
  131. }
  132. # resp = requests.post(url, data=data, headers=headers, timeout=15).json()
  133. resp = requests.post(url, data=data, headers=headers).json()
  134. if resp.get('error_msg'):
  135. if 'internal error' in resp.get('error_msg'):
  136. resp = requests.post(url_general, data=data, headers=headers).json()
  137. if resp.get('error_msg'):
  138. raise Exception("ocr {}!".format(resp.get('error_msg')))
  139. else:
  140. raise Exception("ocr {}!".format(resp.get('error_msg')))
  141. words_result = resp.get('words_result')
  142. return words_result
  143. def get_ocr_text_and_coordinate_direction(img, ocr_accuracy=OCR_ACCURACY, language_type='CHN_ENG'):
  144. textmod = {'access_token': access_token}
  145. textmod = parse.urlencode(textmod)
  146. url = '{}{}{}{}'.format(OCR_BOX_URL, ocr_accuracy, '?', textmod)
  147. url_general = '{}{}{}{}'.format(OCR_BOX_URL, 'general', '?', textmod)
  148. headers = {'Content-Type': 'application/x-www-form-urlencoded'}
  149. image_type = 'base64'
  150. group_id = 'group001'
  151. user_id = 'usr001'
  152. image = opecv2base64(img)
  153. data = {
  154. 'image_type': image_type,
  155. 'group_id': group_id,
  156. 'user_id': user_id,
  157. 'image': image,
  158. 'detect_direction': 'true',
  159. 'recognize_granularity': 'small',
  160. 'language_type': language_type,
  161. # 'vertexes_location': 'true',
  162. # 'probability': 'true'
  163. }
  164. resp = requests.post(url, data=data, headers=headers, timeout=15).json()
  165. if resp.get('error_msg'):
  166. if 'internal error' in resp.get('error_msg'):
  167. resp = requests.post(url_general, data=data, headers=headers).json()
  168. if resp.get('error_msg'):
  169. raise Exception("ocr {}!".format(resp.get('error_msg')))
  170. else:
  171. raise Exception("ocr {}!".format(resp.get('error_msg')))
  172. words_result = resp.get('words_result')
  173. direction = resp.get('direction')
  174. # d_map = {0: 180,
  175. # - 1: 90,
  176. # - 2: -180,
  177. # - 3: -270}
  178. d_map = {0: 180,
  179. -1: 90,
  180. -2: 180,
  181. -3: 90}
  182. return words_result, d_map[direction]
  183. def get_ocr_text_and_coordinate_in_google_format(img, ocr_accuracy=OCR_ACCURACY, language_type='CHN_ENG'):
  184. textmod = {'access_token': access_token}
  185. textmod = parse.urlencode(textmod)
  186. url = '{}{}{}{}'.format(OCR_BOX_URL, ocr_accuracy, '?', textmod)
  187. url_general = '{}{}{}{}'.format(OCR_BOX_URL, 'general', '?', textmod)
  188. headers = {'Content-Type': 'application/x-www-form-urlencoded'}
  189. image_type = 'base64'
  190. group_id = 'group001'
  191. user_id = 'usr001'
  192. image = opecv2base64(img)
  193. data = {
  194. 'image_type': image_type,
  195. 'group_id': group_id,
  196. 'user_id': user_id,
  197. 'image': image,
  198. 'detect_direction': 'true',
  199. 'recognize_granularity': 'small',
  200. 'language_type': language_type,
  201. # 'vertexes_location': 'true',
  202. # 'probability': 'true'
  203. }
  204. resp = requests.post(url, data=data, headers=headers).json()
  205. if resp.get('error_msg'):
  206. if 'internal error' in resp.get('error_msg'):
  207. resp = requests.post(url_general, data=data, headers=headers).json()
  208. if resp.get('error_msg'):
  209. raise Exception("ocr {}!".format(resp.get('error_msg')))
  210. else:
  211. raise Exception("ocr {}!".format(resp.get('error_msg')))
  212. words_result = resp.get('words_result')
  213. dict_list = [item2.get('location') for item in words_result for item2 in item['chars']]
  214. char_list = [item2.get('char') for item in words_result for item2 in item['chars']]
  215. words = [item.get('words') for item in words_result]
  216. matrix = []
  217. for adict in dict_list:
  218. xmin = adict['left']
  219. ymin = adict['top']
  220. xmax = adict['width'] + adict['left']
  221. ymax = adict['top'] + adict['height']
  222. item0 = (xmin, ymin, xmax, ymax)
  223. matrix.append(item0)
  224. res_dict = {'chars': char_list, 'coordinates': matrix, 'words': words}
  225. return res_dict
  226. def change_format_baidu_to_google(words_result):
  227. dict_list = [item2.get('location') for item in words_result for item2 in item['chars']]
  228. char_list = [item2.get('char') for item in words_result for item2 in item['chars']]
  229. words = [item.get('words') for item in words_result]
  230. matrix = []
  231. for adict in dict_list:
  232. xmin = adict['left']
  233. ymin = adict['top']
  234. xmax = adict['width'] + adict['left']
  235. ymax = adict['top'] + adict['height']
  236. item0 = (xmin, ymin, xmax, ymax)
  237. matrix.append(item0)
  238. res_dict = {'chars': char_list, 'coordinates': matrix, 'words': words}
  239. return res_dict
  240. def get_handwriting_ocr_text_and_coordinate_in_google_format(img, words_type='words'):
  241. textmod = {'access_token': access_token}
  242. textmod = parse.urlencode(textmod)
  243. url = '{}{}{}'.format(OCR_HAND_URL, '?', textmod)
  244. headers = {'Content-Type': 'application/x-www-form-urlencoded'}
  245. image = opecv2base64(img)
  246. data = {
  247. 'image': image,
  248. 'recognize_granularity': 'small',
  249. 'words_type': words_type,
  250. }
  251. resp = requests.post(url, data=data, headers=headers).json()
  252. if resp.get('error_msg'):
  253. raise Exception("ocr {}!".format(resp.get('error_msg')))
  254. words_result = resp.get('words_result')
  255. dict_list = [item2.get('location') for item in words_result for item2 in item['chars']]
  256. char_list = [item2.get('char') for item in words_result for item2 in item['chars']]
  257. words = [item.get('words') for item in words_result]
  258. matrix = []
  259. for adict in dict_list:
  260. xmin = adict['left']
  261. ymin = adict['top']
  262. xmax = adict['width'] + adict['left']
  263. ymax = adict['top'] + adict['height']
  264. item0 = (xmin, ymin, xmax, ymax)
  265. matrix.append(item0)
  266. res_dict = {'chars': char_list, 'coordinates': matrix, 'words': words}
  267. return res_dict
  268. def tesseract_boxes_by_py(image, ocr_lang='chi_sim+eng'):
  269. img = preprocess(image)
  270. txt = pytesseract.image_to_boxes(img, lang=ocr_lang, output_type='dict')
  271. h, w = img.shape
  272. char_list = txt['char']
  273. left = txt['left']
  274. bottom = [(h - top) for top in txt['top']]
  275. right = txt['right']
  276. top = [(h - bottom) for bottom in txt['bottom']]
  277. matrix = []
  278. for i, ele in enumerate(left):
  279. matrix.append((ele, top[i], right[i], bottom[i]))
  280. res_dict = {'chars': char_list, 'coordinates': matrix}
  281. return res_dict
  282. def gen_xml_of_per_char(img_path):
  283. img = utils.read_single_img(img_path)
  284. res_dict = get_ocr_text_and_coordinate_in_google_format(img, 'accurate', 'CHN_ENG')
  285. box_list = res_dict['coordinates']
  286. tree = ET.parse(r'./000000-template.xml') # xml tree
  287. for index_num, exam_bbox in enumerate(box_list):
  288. tree = utils.create_xml('{}'.format(res_dict['chars'][index_num]), tree,
  289. exam_bbox[0], exam_bbox[1], exam_bbox[2], exam_bbox[3])
  290. # print(exam_items_bbox)
  291. tree.write(img_path.replace('.jpg', '.xml'))
  292. res_dict_google = tesseract_boxes_by_py(img, ocr_lang='chi_sim+equ+eng')
  293. box_list_g = res_dict_google['coordinates']
  294. tree_g = ET.parse(r'./000000-template.xml') # xml tree
  295. for index_num, exam_bbox in enumerate(box_list_g):
  296. tree_g = utils.create_xml('{}'.format(res_dict_google['chars'][index_num]), tree_g,
  297. exam_bbox[0], exam_bbox[1], exam_bbox[2], exam_bbox[3])
  298. # print(exam_items_bbox)
  299. tree_g.write(img_path.replace('.jpg', '_g.xml'))
  300. shutil.copy(img_path, img_path.replace('.jpg', '_g.jpg'))
  301. if __name__ == '__main__':
  302. img_path0 = r'C:\Users\Administrator\Desktop\sheet\mark-test\002_mark.jpg'
  303. image0 = cv2.imread(img_path0)
  304. t1 = time.time()
  305. res = get_ocr_text_and_coordinate(image0)
  306. t2 = time.time()
  307. print(t2 - t1)
  308. print(res)