exam_number_row_column.py 9.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234
  1. import numpy as np
  2. import tensorflow as tf
  3. from segment.sheet_resolve.lib.ssd_model.utils import label_map_util, ops as utils_ops
  4. from segment.sheet_resolve.tools import tf_settings
  5. from segment.sheet_resolve.tools.tf_sess import SsdSess
  6. from PIL import Image
  7. import math
  # Detector sessions keyed by model name. SsdSess('exam_number_ssd') is
  # constructed at import time, so importing this module loads the model.
  tf_sess_dict = {'exam_number_ssd': SsdSess('exam_number_ssd')}
  exam_number_sess = tf_sess_dict['exam_number_ssd']
  # Module-level session and graph handles read by run_inference_for_single_image.
  sess, detection_graph = exam_number_sess.sess, exam_number_sess.graph
  def load_image_into_numpy_array(image):
      """Convert a PIL image into a ``(height, width, 3)`` uint8 RGB array.

      Args:
          image: PIL image in any mode; it is converted to RGB first so the
              result always has exactly three channels.

      Returns:
          A new ``np.uint8`` array of shape ``(height, width, 3)``.
      """
      # Go through the image's array interface instead of
      # ``np.array(image.getdata()).reshape(...)``: getdata() materialises one
      # Python tuple per pixel, which is much slower than copying the pixel
      # buffer directly. np.array (rather than np.asarray) always copies, so
      # the caller gets a writable array as before.
      return np.array(image.convert('RGB'), dtype=np.uint8)
  def run_inference_for_single_image(image):
      """Run the exam-number detector on one image and return its outputs.

      Args:
          image: numpy image array; a leading batch axis is added before it is
              fed to ``image_tensor:0``. Presumably ``(height, width, 3)``
              uint8, as produced by ``load_image_into_numpy_array``.

      Returns:
          dict with the batch axis removed: ``num_detections`` (int),
          ``detection_classes`` (uint8 array), ``detection_boxes``,
          ``detection_scores`` and, only if the graph exposes it,
          ``detection_masks``.
      """
      wanted_keys = ('num_detections', 'detection_boxes', 'detection_scores',
                     'detection_classes', 'detection_masks')
      graph_tensor_names = {
          tensor.name
          for operation in detection_graph.get_operations()
          for tensor in operation.outputs
      }
      # Fetch only the outputs this particular exported graph provides.
      fetches = {
          key: detection_graph.get_tensor_by_name(key + ':0')
          for key in wanted_keys
          if key + ':0' in graph_tensor_names
      }

      if 'detection_masks' in fetches:
          # Single-image only: drop the batch axis, keep the first
          # num_detections entries, then reframe box-relative masks into
          # full-image masks thresholded at 0.5.
          # NOTE(review): these tf.* calls create new graph ops on every call;
          # if a masked model is ever used, consider building them once.
          boxes = tf.squeeze(fetches['detection_boxes'], [0])
          masks = tf.squeeze(fetches['detection_masks'], [0])
          count = tf.cast(fetches['num_detections'][0], tf.int32)
          boxes = tf.slice(boxes, [0, 0], [count, -1])
          masks = tf.slice(masks, [0, 0, 0], [count, -1, -1])
          full_masks = utils_ops.reframe_box_masks_to_image_masks(
              masks, boxes, image.shape[0], image.shape[1])
          full_masks = tf.cast(tf.greater(full_masks, 0.5), tf.uint8)
          # Restore the batch axis so the post-processing below is uniform.
          fetches['detection_masks'] = tf.expand_dims(full_masks, 0)

      image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
      output_dict = sess.run(
          fetches, feed_dict={image_tensor: np.expand_dims(image, 0)})

      # Strip the size-1 batch axis and fix up the types of the scalar/class outputs.
      output_dict['num_detections'] = int(output_dict['num_detections'][0])
      output_dict['detection_classes'] = (
          output_dict['detection_classes'][0].astype(np.uint8))
      for key in ('detection_boxes', 'detection_scores'):
          output_dict[key] = output_dict[key][0]
      if 'detection_masks' in output_dict:
          output_dict['detection_masks'] = output_dict['detection_masks'][0]
      return output_dict
  def image_detect(image_np, category, score_threshold):
      """Detect objects in a PIL image and return them in pixel coordinates.

      Args:
          image_np: PIL image (despite the name); it is converted with
              ``load_image_into_numpy_array`` before inference.
          category: mapping from class id to a dict with at least ``'name'``.
          score_threshold: a detection is kept only if its score is strictly
              greater than this value.

      Returns:
          list of ``(x0, y0, x1, y1, label_index, score, label_name)`` tuples.
          Coordinates are the model's box values multiplied by the image
          width/height and truncated to int (so the model presumably emits
          normalised ``[ymin, xmin, ymax, xmax]`` boxes).
      """
      pixels = load_image_into_numpy_array(image_np)
      height, width = pixels.shape[0], pixels.shape[1]
      result = run_inference_for_single_image(pixels)

      keep = result['detection_scores'] > score_threshold
      # One row per kept detection: four box values, the score, the class id.
      table = np.concatenate(
          [result['detection_boxes'][keep],
           result['detection_scores'][keep][:, np.newaxis],
           result['detection_classes'][keep][:, np.newaxis]],
          axis=1)

      found = []
      for ymin, xmin, ymax, xmax, score, class_id in table:
          label_index = int(class_id)
          found.append((int(xmin * width), int(ymin * height),
                        int(xmax * width), int(ymax * height),
                        label_index, score, category[label_index]['name']))
      return found
  def _count_large_gaps(gaps):
      """Return how many gaps exceed roughly twice the median gap.

      A row/column the detector skipped entirely shows up as a pitch about
      twice the usual one. Returns 0 for an empty list, which also avoids
      ``np.median([])`` (a RuntimeWarning and nan).
      """
      if not gaps:
          return 0
      limit = 2 * np.median(gaps) - 4
      return sum(1 for gap in gaps if gap > limit)


  def _estimate_rows_and_cols(boxes):
      """Estimate the ``(rows, cols)`` of the exam-number grid from cell boxes.

      Args:
          boxes: non-empty list of ``(xmin, ymin, xmax, ymax)`` pixel boxes,
              one per detected cell (border detections excluded).

      Returns:
          ``(rows, cols)`` as ints with ``1 <= rows <= 10`` and ``cols >= 1``.
      """
      widths = [xmax - xmin for xmin, _, xmax, _ in boxes]
      heights = [ymax - ymin for _, ymin, _, ymax in boxes]
      width_median = np.median(widths)
      height_median = np.median(heights)
      # Largest sizes first; the loop below inspects all but the smallest.
      widths_desc = sorted(widths, reverse=True)
      heights_desc = sorted(heights, reverse=True)
      xmins = sorted(box[0] for box in boxes)
      ymins = sorted(box[1] for box in boxes)

      rows = cols = 1
      x_gaps, y_gaps = [], []
      for i in range(len(boxes) - 1):
          # A step between consecutive sorted xmin/ymin values of about one
          # cell size (with 5 px slack) starts a new column/row.
          x_gap = xmins[i + 1] - xmins[i]
          if x_gap > width_median - 5:
              cols += 1
              x_gaps.append(x_gap)
          y_gap = ymins[i + 1] - ymins[i]
          if y_gap > height_median - 5:
              rows += 1
              y_gaps.append(y_gap)
          # Discount boxes whose size deviates from the median by more than
          # the median itself (e.g. an oversized/merged detection).
          if abs(widths_desc[i] - width_median) > width_median:
              cols -= 1
          if abs(heights_desc[i] - height_median) > height_median:
              rows -= 1

      # Add rows/columns the detector missed entirely. Both axes compare
      # against the median *gap* (pitch); the row side must use the row gaps,
      # not the cell heights.
      cols += _count_large_gaps(x_gaps)
      rows += _count_large_gaps(y_gaps)

      # Outlier discounting can drive cols to 0 or below; clamp it so the
      # division below cannot raise ZeroDivisionError.
      cols = max(cols, 1)
      if rows < 10:
          # n cells laid out in `cols` columns need at least ceil(n / cols) rows.
          rows = max(rows, math.ceil(len(boxes) / cols))
      # At most 10 rows: one per digit 0-9.
      return min(rows, 10), cols


  def get_exam_number_row_and_col(left, top, image):
      """Detect exam-number cells in a crop and describe their grid layout.

      The crop is scaled so its longer side is 512 px, pasted at the top-left
      of a white 512x512 canvas and run through the SSD detector. Detected
      cell boxes are mapped back to the crop's own pixel coordinates and then
      shifted by ``(left, top)``.

      Args:
          left: added to every x coordinate in the result (presumably the
              crop's x offset within the full sheet).
          top: added to every y coordinate in the result (presumably the
              crop's y offset within the full sheet).
          image: numpy array accepted by ``PIL.Image.fromarray``; RGB input
              is converted to grayscale first.

      Returns:
          dict with keys ``bounding_box``, ``single_height``, ``single_width``,
          ``rows``, ``cols``, ``option``, ``direction``, ``class_name`` and
          ``all_small_coordinate``; or ``{}`` when no non-border cell box is
          detected.
      """
      im_resize = 512
      image_src = Image.fromarray(image)
      if image_src.mode == 'RGB':
          image_src = image_src.convert("L")
      w, h = image_src.size
      # Scale the longer side to im_resize, preserving the aspect ratio.
      if h > w:
          image_src = image_src.resize((int(im_resize / h * w), im_resize))
      else:
          image_src = image_src.resize((im_resize, int(im_resize / w * h)))
      w_, h_ = image_src.size
      # Pad to a square white canvas with the resized crop at the top-left.
      image_512 = Image.new(image_src.mode, (im_resize, im_resize), 255)
      image_512.paste(image_src, (0, 0, w_, h_))

      category_index = label_map_util.create_category_index_from_labelmap(
          tf_settings.exam_number_ssd_label, use_display_name=True)
      detections = image_detect(image_512, category_index, 0.5)

      # Map every non-border detection back to the original crop's pixel grid.
      scale_x, scale_y = w / w_, h / h_
      boxes = [
          (round(det[0] * scale_x), round(det[1] * scale_y),
           round(det[2] * scale_x), round(det[3] * scale_y))
          for det in detections
          if det[-1] != 'border'
      ]
      if not boxes:
          # Nothing detected, or only the outer border: there is no grid to
          # describe (previously this case crashed on an empty list).
          return {}

      rows, cols = _estimate_rows_and_cols(boxes)
      return {
          'bounding_box': {'xmin': min(b[0] for b in boxes) + left,
                           'ymin': min(b[1] for b in boxes) + top,
                           'xmax': max(b[2] for b in boxes) + left,
                           'ymax': max(b[3] for b in boxes) + top},
          "single_height": int(np.mean([b[3] - b[1] for b in boxes])),
          "single_width": int(np.mean([b[2] - b[0] for b in boxes])),
          "rows": rows,
          "cols": cols,
          # Row labels are the leading digits, e.g. rows=3 -> "0,1,2".
          "option": ",".join("0123456789"[:rows]),
          "direction": 180,
          'class_name': 'exam_number_col_row',
          'all_small_coordinate': [
              {'xmin': x0 + left, 'ymin': y0 + top,
               'xmax': x1 + left, 'ymax': y1 + top}
              for x0, y0, x1, y1 in boxes
          ],
      }