# @Author : lightXu
# @File : brain_api.py
# @Time : 2018/11/21 0021 16:20 PM
# OCR helper wrappers around the Baidu OCR REST API and local tesseract.
import shutil
import requests
import base64
from urllib import parse, request
import cv2
import time
import numpy as np
import pytesseract
from segment.server import ocr_login
from segment.sheet_resolve.tools import utils
import xml.etree.cElementTree as ET

# access_token = '24.82b09618f94abe2a35113177f4eec593.2592000.1546765941.282335-14614857'
access_token = ocr_login()

OCR_BOX_URL = 'https://aip.baidubce.com/rest/2.0/ocr/v1/'
OCR_URL = 'https://aip.baidubce.com/rest/2.0/ocr/v1/'
OCR_HAND_URL = 'https://aip.baidubce.com/rest/2.0/ocr/v1/handwriting'
# OCR_ACCURACY = 'general'
OCR_ACCURACY = 'accurate'
OCR_CLIENT_ID = 'AVH7VGKG8QxoSotp6wG9LyZq'
OCR_CLIENT_SECRET = 'gG7VYvBWLU8Rusnin8cS8Ta4dOckGFl6'
OCR_TOKEN_UPDATE_DATE = 10
def preprocess(img):
    scale = 0
    dilate = 1
    blur = 3

    # rescale the image
    if scale != 0:
        img = cv2.resize(img, None, fx=scale, fy=scale, interpolation=cv2.INTER_CUBIC)

    # convert to gray
    img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

    # apply dilation and erosion to remove some noise
    if dilate != 0:
        kernel = np.ones((dilate, dilate), np.uint8)
        img = cv2.dilate(img, kernel, iterations=1)
        img = cv2.erode(img, kernel, iterations=1)

    # apply blur to smooth out the edges
    if blur != 0:
        img = cv2.GaussianBlur(img, (blur, blur), 0)

    # apply Otsu threshold to get a black & white image (binarization)
    img = cv2.threshold(img, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)[1]

    return img
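
# Example (illustrative sketch): preprocess() is the binarization step used by the
# tesseract path further below; a typical call would look like
#   binary = preprocess(cv2.imread('sheet.jpg'))               # 'sheet.jpg' is a placeholder path
#   text = pytesseract.image_to_string(binary, lang='chi_sim+eng')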
def opecv2base64(img):
    # encode an OpenCV image (numpy array) as JPEG and return it as a base64 string
    image = cv2.imencode('.jpg', img)[1]
    base64_data = base64.b64encode(image).decode('utf-8')
    return base64_data
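
# Example (illustrative sketch): the base64 payload can be round-tripped back into an
# image for a quick sanity check:
#   raw = base64.b64decode(opecv2base64(img))
#   restored = cv2.imdecode(np.frombuffer(raw, np.uint8), cv2.IMREAD_COLOR)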
def get_ocr_raw_result(img, ocr_accuracy=OCR_ACCURACY, language_type='CHN_ENG'):
    """Call the Baidu OCR endpoint and return the raw JSON response."""
    textmod = {'access_token': access_token}
    textmod = parse.urlencode(textmod)
    url = '{}{}{}{}'.format(OCR_BOX_URL, ocr_accuracy, '?', textmod)
    url_general = '{}{}{}{}'.format(OCR_BOX_URL, 'general', '?', textmod)

    headers = {'Content-Type': 'application/x-www-form-urlencoded'}
    image_type = 'base64'
    group_id = 'group001'
    user_id = 'usr001'

    image = opecv2base64(img)

    data = {
        'image_type': image_type,
        'group_id': group_id,
        'user_id': user_id,
        'image': image,
        'detect_direction': 'true',
        'recognize_granularity': 'small',
        'language_type': language_type,
        # 'vertexes_location': 'true',
        # 'probability': 'true'
    }

    resp = requests.post(url, data=data, headers=headers, timeout=15).json()
    if resp.get('error_msg'):
        if 'internal error' in resp.get('error_msg'):
            # fall back to the 'general' endpoint when the accurate endpoint reports an internal error
            resp = requests.post(url_general, data=data, headers=headers).json()
            if resp.get('error_msg'):
                raise Exception("ocr {}!".format(resp.get('error_msg')))
        else:
            raise Exception("ocr {}!".format(resp.get('error_msg')))

    return resp
def get_ocr_text_and_coordinate(img, ocr_accuracy=OCR_ACCURACY, language_type='CHN_ENG'):
    """Run Baidu OCR and return only the words_result list (direction detection left disabled)."""
    textmod = {'access_token': access_token}
    textmod = parse.urlencode(textmod)
    url = '{}{}{}{}'.format(OCR_BOX_URL, ocr_accuracy, '?', textmod)
    url_general = '{}{}{}{}'.format(OCR_BOX_URL, 'general', '?', textmod)

    headers = {'Content-Type': 'application/x-www-form-urlencoded'}
    image_type = 'base64'
    group_id = 'group001'
    user_id = 'usr001'

    image = opecv2base64(img)

    data = {
        'image_type': image_type,
        'group_id': group_id,
        'user_id': user_id,
        'image': image,
        # 'detect_direction': 'true',
        'recognize_granularity': 'small',
        'language_type': language_type,
        # 'vertexes_location': 'true',
        # 'probability': 'true'
    }

    # resp = requests.post(url, data=data, headers=headers, timeout=15).json()
    resp = requests.post(url, data=data, headers=headers).json()
    if resp.get('error_msg'):
        if 'internal error' in resp.get('error_msg'):
            resp = requests.post(url_general, data=data, headers=headers).json()
            if resp.get('error_msg'):
                raise Exception("ocr {}!".format(resp.get('error_msg')))
        else:
            raise Exception("ocr {}!".format(resp.get('error_msg')))

    words_result = resp.get('words_result')
    return words_result
def get_ocr_text_and_coordinate0(img, ocr_accuracy=OCR_ACCURACY, language_type='CHN_ENG'):
    """Variant of get_ocr_text_and_coordinate with detect_direction explicitly set to 'false'."""
    textmod = {'access_token': access_token}
    textmod = parse.urlencode(textmod)
    url = '{}{}{}{}'.format(OCR_BOX_URL, ocr_accuracy, '?', textmod)
    url_general = '{}{}{}{}'.format(OCR_BOX_URL, 'general', '?', textmod)

    headers = {'Content-Type': 'application/x-www-form-urlencoded'}
    image_type = 'base64'
    group_id = 'group001'
    user_id = 'usr001'

    image = opecv2base64(img)

    data = {
        'image_type': image_type,
        'group_id': group_id,
        'user_id': user_id,
        'image': image,
        'detect_direction': 'false',
        'recognize_granularity': 'small',
        'language_type': language_type,
        # 'vertexes_location': 'true',
        # 'probability': 'true'
    }

    # resp = requests.post(url, data=data, headers=headers, timeout=15).json()
    resp = requests.post(url, data=data, headers=headers).json()
    if resp.get('error_msg'):
        if 'internal error' in resp.get('error_msg'):
            resp = requests.post(url_general, data=data, headers=headers).json()
            if resp.get('error_msg'):
                raise Exception("ocr {}!".format(resp.get('error_msg')))
        else:
            raise Exception("ocr {}!".format(resp.get('error_msg')))

    words_result = resp.get('words_result')
    return words_result
def get_ocr_text_and_coordinate_direction(img, ocr_accuracy=OCR_ACCURACY, language_type='CHN_ENG'):
    """Run Baidu OCR with direction detection and return (words_result, rotation angle in degrees)."""
    textmod = {'access_token': access_token}
    textmod = parse.urlencode(textmod)
    url = '{}{}{}{}'.format(OCR_BOX_URL, ocr_accuracy, '?', textmod)
    url_general = '{}{}{}{}'.format(OCR_BOX_URL, 'general', '?', textmod)

    headers = {'Content-Type': 'application/x-www-form-urlencoded'}
    image_type = 'base64'
    group_id = 'group001'
    user_id = 'usr001'

    image = opecv2base64(img)

    data = {
        'image_type': image_type,
        'group_id': group_id,
        'user_id': user_id,
        'image': image,
        'detect_direction': 'true',
        'recognize_granularity': 'small',
        'language_type': language_type,
        # 'vertexes_location': 'true',
        # 'probability': 'true'
    }

    resp = requests.post(url, data=data, headers=headers, timeout=15).json()
    if resp.get('error_msg'):
        if 'internal error' in resp.get('error_msg'):
            resp = requests.post(url_general, data=data, headers=headers).json()
            if resp.get('error_msg'):
                raise Exception("ocr {}!".format(resp.get('error_msg')))
        else:
            raise Exception("ocr {}!".format(resp.get('error_msg')))

    words_result = resp.get('words_result')
    direction = resp.get('direction')

    # Baidu's 'direction' field is 0/1/2/3; map it to a rotation angle in degrees.
    # d_map = {0: 180,
    #          1: 90,
    #          2: -180,
    #          3: -270}
    d_map = {0: 180,
             1: 90,
             2: 180,
             3: 90}
    return words_result, d_map[direction]
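
# Example (assumed usage, not confirmed by this file): the returned angle could be used to
# bring a rotated sheet back upright before further processing, e.g.
#   words, angle = get_ocr_text_and_coordinate_direction(img)
#   upright = np.rot90(img, k=angle // 90)   # rotation direction depends on Baidu's convention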
def get_ocr_text_and_coordinate_in_google_format(img, ocr_accuracy=OCR_ACCURACY, language_type='CHN_ENG'):
    """Run Baidu OCR and flatten the result into a Google-Vision-like dict of chars, boxes and words."""
    textmod = {'access_token': access_token}
    textmod = parse.urlencode(textmod)
    url = '{}{}{}{}'.format(OCR_BOX_URL, ocr_accuracy, '?', textmod)
    url_general = '{}{}{}{}'.format(OCR_BOX_URL, 'general', '?', textmod)

    headers = {'Content-Type': 'application/x-www-form-urlencoded'}
    image_type = 'base64'
    group_id = 'group001'
    user_id = 'usr001'

    image = opecv2base64(img)

    data = {
        'image_type': image_type,
        'group_id': group_id,
        'user_id': user_id,
        'image': image,
        'detect_direction': 'true',
        'recognize_granularity': 'small',
        'language_type': language_type,
        # 'vertexes_location': 'true',
        # 'probability': 'true'
    }

    resp = requests.post(url, data=data, headers=headers).json()
    if resp.get('error_msg'):
        if 'internal error' in resp.get('error_msg'):
            resp = requests.post(url_general, data=data, headers=headers).json()
            if resp.get('error_msg'):
                raise Exception("ocr {}!".format(resp.get('error_msg')))
        else:
            raise Exception("ocr {}!".format(resp.get('error_msg')))

    words_result = resp.get('words_result')

    dict_list = [item2.get('location') for item in words_result for item2 in item['chars']]
    char_list = [item2.get('char') for item in words_result for item2 in item['chars']]
    words = [item.get('words') for item in words_result]

    matrix = []
    for adict in dict_list:
        xmin = adict['left']
        ymin = adict['top']
        xmax = adict['width'] + adict['left']
        ymax = adict['top'] + adict['height']
        matrix.append((xmin, ymin, xmax, ymax))

    res_dict = {'chars': char_list, 'coordinates': matrix, 'words': words}
    return res_dict
def change_format_baidu_to_google(words_result):
    """Flatten an already-fetched Baidu words_result into the same Google-Vision-like dict."""
    dict_list = [item2.get('location') for item in words_result for item2 in item['chars']]
    char_list = [item2.get('char') for item in words_result for item2 in item['chars']]
    words = [item.get('words') for item in words_result]

    matrix = []
    for adict in dict_list:
        xmin = adict['left']
        ymin = adict['top']
        xmax = adict['width'] + adict['left']
        ymax = adict['top'] + adict['height']
        matrix.append((xmin, ymin, xmax, ymax))

    res_dict = {'chars': char_list, 'coordinates': matrix, 'words': words}
    return res_dict
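
# For reference (shapes implied by the fields accessed above): Baidu's words_result is a list like
#   [{'words': '...', 'chars': [{'char': 'x', 'location': {'left': .., 'top': .., 'width': .., 'height': ..}}, ...]}, ...]
# and the two converters above flatten it into
#   {'chars': [...], 'coordinates': [(xmin, ymin, xmax, ymax), ...], 'words': [...]}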
def get_handwriting_ocr_text_and_coordinate_in_google_format(img, words_type='words'):
    """Run Baidu handwriting OCR and flatten the result into the same chars/coordinates/words dict."""
    textmod = {'access_token': access_token}
    textmod = parse.urlencode(textmod)
    url = '{}{}{}'.format(OCR_HAND_URL, '?', textmod)

    headers = {'Content-Type': 'application/x-www-form-urlencoded'}
    image = opecv2base64(img)

    data = {
        'image': image,
        'recognize_granularity': 'small',
        'words_type': words_type,
    }

    resp = requests.post(url, data=data, headers=headers).json()
    if resp.get('error_msg'):
        raise Exception("ocr {}!".format(resp.get('error_msg')))

    words_result = resp.get('words_result')

    dict_list = [item2.get('location') for item in words_result for item2 in item['chars']]
    char_list = [item2.get('char') for item in words_result for item2 in item['chars']]
    words = [item.get('words') for item in words_result]

    matrix = []
    for adict in dict_list:
        xmin = adict['left']
        ymin = adict['top']
        xmax = adict['width'] + adict['left']
        ymax = adict['top'] + adict['height']
        matrix.append((xmin, ymin, xmax, ymax))

    res_dict = {'chars': char_list, 'coordinates': matrix, 'words': words}
    return res_dict
def tesseract_boxes_by_py(image, ocr_lang='chi_sim+eng'):
    """Run tesseract character boxes on a preprocessed image and return chars with (xmin, ymin, xmax, ymax) boxes."""
    img = preprocess(image)
    txt = pytesseract.image_to_boxes(img, lang=ocr_lang, output_type='dict')
    h, w = img.shape

    char_list = txt['char']
    left = txt['left']
    right = txt['right']
    # tesseract boxes use a bottom-left origin, so flip the y axis to match the
    # top-left (xmin, ymin, xmax, ymax) convention used elsewhere in this module
    top = [(h - y) for y in txt['top']]        # image-space ymin
    bottom = [(h - y) for y in txt['bottom']]  # image-space ymax

    matrix = []
    for i, ele in enumerate(left):
        matrix.append((ele, top[i], right[i], bottom[i]))

    res_dict = {'chars': char_list, 'coordinates': matrix}
    return res_dict
def gen_xml_of_per_char(img_path):
    img = utils.read_single_img(img_path)

    # per-character boxes from Baidu OCR, written as a Pascal-VOC style xml next to the image
    res_dict = get_ocr_text_and_coordinate_in_google_format(img, 'accurate', 'CHN_ENG')
    box_list = res_dict['coordinates']
    tree = ET.parse(r'./000000-template.xml')  # xml tree
    for index_num, exam_bbox in enumerate(box_list):
        tree = utils.create_xml('{}'.format(res_dict['chars'][index_num]), tree,
                                exam_bbox[0], exam_bbox[1], exam_bbox[2], exam_bbox[3])
    # print(exam_items_bbox)
    tree.write(img_path.replace('.jpg', '.xml'))

    # per-character boxes from tesseract, written alongside a copy of the image
    res_dict_google = tesseract_boxes_by_py(img, ocr_lang='chi_sim+equ+eng')
    box_list_g = res_dict_google['coordinates']
    tree_g = ET.parse(r'./000000-template.xml')  # xml tree
    for index_num, exam_bbox in enumerate(box_list_g):
        tree_g = utils.create_xml('{}'.format(res_dict_google['chars'][index_num]), tree_g,
                                  exam_bbox[0], exam_bbox[1], exam_bbox[2], exam_bbox[3])
    # print(exam_items_bbox)
    tree_g.write(img_path.replace('.jpg', '_g.xml'))
    shutil.copy(img_path, img_path.replace('.jpg', '_g.jpg'))
if __name__ == '__main__':
    img_path0 = r'C:\Users\Administrator\Desktop\sheet\mark-test\002_mark.jpg'
    image0 = cv2.imread(img_path0)

    t1 = time.time()
    res = get_ocr_text_and_coordinate(image0)
    t2 = time.time()
    print(t2 - t1)
    print(res)
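    # Additional usage sketches (assumptions for illustration, not part of the original test run):
    # res_google = get_ocr_text_and_coordinate_in_google_format(image0)
    # res_hand = get_handwriting_ocr_text_and_coordinate_in_google_format(image0)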