choice_m_row_column.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422
  1. # @Author : liu fan
  2. import numpy as np
  3. import tensorflow as tf
  4. from segment.sheet_resolve.lib.ssd_model.utils import label_map_util, ops as utils_ops
  5. from segment.sheet_resolve.tools import tf_settings
  6. from segment.sheet_resolve.tools.tf_sess import SsdSess
  7. from PIL import Image
  8. tf_sess_dict = {
  9. 'choice_ssd': SsdSess('choice_ssd'),
  10. }
  11. choice_ssd_sess = tf_sess_dict['choice_ssd']
  12. sess = choice_ssd_sess.sess
  13. detection_graph = choice_ssd_sess.graph
  14. def load_image_into_numpy_array(image):
  15. # print(image)
  16. image = image.convert('RGB')
  17. (im_width, im_height) = image.size
  18. return np.array(image.getdata()).reshape((im_height, im_width, 3)).astype(np.uint8)
  19. def run_inference_for_single_image(image):
  20. ops = detection_graph.get_operations()
  21. all_tensor_names = {output.name for op in ops for output in op.outputs}
  22. tensor_dict = {}
  23. for key in [
  24. 'num_detections', 'detection_boxes', 'detection_scores',
  25. 'detection_classes', 'detection_masks'
  26. ]:
  27. tensor_name = key + ':0'
  28. if tensor_name in all_tensor_names:
  29. tensor_dict[key] = detection_graph.get_tensor_by_name(
  30. tensor_name)
  31. if 'detection_masks' in tensor_dict:
  32. # The following processing is only for single image
  33. detection_boxes = tf.squeeze(tensor_dict['detection_boxes'], [0])
  34. detection_masks = tf.squeeze(tensor_dict['detection_masks'], [0])
  35. # Reframe is required to translate mask from box coordinates to image coordinates and fit the image size.
  36. real_num_detection = tf.cast(tensor_dict['num_detections'][0], tf.int32)
  37. detection_boxes = tf.slice(detection_boxes, [0, 0], [real_num_detection, -1])
  38. detection_masks = tf.slice(detection_masks, [0, 0, 0], [real_num_detection, -1, -1])
  39. detection_masks_reframed = utils_ops.reframe_box_masks_to_image_masks(
  40. detection_masks, detection_boxes, image.shape[0], image.shape[1])
  41. detection_masks_reframed = tf.cast(
  42. tf.greater(detection_masks_reframed, 0.5), tf.uint8)
  43. # Follow the convention by adding back the batch dimension
  44. tensor_dict['detection_masks'] = tf.expand_dims(
  45. detection_masks_reframed, 0)
  46. image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
  47. # Run inference
  48. # start = time.time()
  49. output_dict = sess.run(tensor_dict,
  50. feed_dict={image_tensor: np.expand_dims(image, 0)})
  51. # print(time.time()-start)
  52. # all outputs are float32 numpy arrays, so convert types as appropriate
  53. output_dict['num_detections'] = int(output_dict['num_detections'][0])
  54. output_dict['detection_classes'] = output_dict[
  55. 'detection_classes'][0].astype(np.uint8)
  56. output_dict['detection_boxes'] = output_dict['detection_boxes'][0]
  57. output_dict['detection_scores'] = output_dict['detection_scores'][0]
  58. if 'detection_masks' in output_dict:
  59. output_dict['detection_masks'] = output_dict['detection_masks'][0]
  60. return output_dict
  61. def image_detect(image_np, category, score_threshold):
  62. image_np = load_image_into_numpy_array(image_np)
  63. detections = []
  64. w, h = image_np.shape[1], image_np.shape[0]
  65. with tf.device("/device:GPU:{}".format(0)):
  66. output_dict = run_inference_for_single_image(image_np)
  67. boxes = output_dict['detection_boxes']
  68. scores = output_dict['detection_scores']
  69. labels = output_dict['detection_classes']
  70. indices = np.where(scores > score_threshold)
  71. image_scores = scores[indices]
  72. image_boxes = boxes[indices]
  73. image_labels = labels[indices]
  74. image_detections = np.concatenate(
  75. [image_boxes, np.expand_dims(image_scores, axis=1), np.expand_dims(image_labels, axis=1)], axis=1)
  76. for detection in image_detections:
  77. y0 = int(detection[0] * h)
  78. x0 = int(detection[1] * w)
  79. y1 = int(detection[2] * h)
  80. x1 = int(detection[3] * w)
  81. label_index = int(detection[5])
  82. label_name = category[label_index]['name']
  83. detections.append((x0, y0, x1, y1, label_index, detection[4], label_name))
  84. return detections
  85. # def get_choice_m_row_and_col(left, top, image):
  86. # im_resize = 300
  87. # ''' choice_m resize to 300*300'''
  88. # image_src = Image.fromarray(image)
  89. # if image_src.mode == 'RGB':
  90. # image_src = image_src.convert("L")
  91. # w, h = image_src.size
  92. # if h > w:
  93. # image_src = image_src.resize((int(im_resize / h * w), im_resize))
  94. # else:
  95. # image_src = image_src.resize((im_resize, int(im_resize / w * h)))
  96. # w_, h_ = image_src.size
  97. # image_300 = Image.new(image_src.mode, (im_resize, im_resize), (255))
  98. # image_300.paste(image_src, [0, 0, w_, h_])
  99. #
  100. # category_index = label_map_util.create_category_index_from_labelmap(tf_settings.choice_m_ssd_label,
  101. # use_display_name=True)
  102. # detections = image_detect(image_300, category_index, 0.5)
  103. # if len(detections) > 1:
  104. # box_xmin = []
  105. # box_ymin = []
  106. # box_xmax = []
  107. # box_ymax = []
  108. # x_distance_all = []
  109. # y_distance_all = []
  110. # x_width_all = []
  111. # y_height_all = []
  112. # all_small_coordinate = []
  113. # ssd_column = 1
  114. # ssd_row = 1
  115. # count_x = 0
  116. # count_y = 0
  117. # for index, box in enumerate(detections):
  118. # if box[-1] != 'T' and box[2] <= w_ and box[3] <= h_:
  119. # box0 = round(box[0] * (w / w_)) # Map to the original image
  120. # box1 = round(box[1] * (h / h_))
  121. # box2 = round(box[2] * (w / w_))
  122. # box3 = round(box[3] * (h / h_))
  123. # box_xmin.append(box0)
  124. # box_ymin.append(box1)
  125. # box_xmax.append(box2)
  126. # box_ymax.append(box3)
  127. # small_coordinate = {'xmin': box0 + left,
  128. # 'ymin': box1 + top,
  129. # 'xmax': box2 + left,
  130. # 'ymax': box3 + top}
  131. # all_small_coordinate.append(small_coordinate)
  132. # x_width = box2 - box0
  133. # y_height = box3 - box1
  134. # x_width_all.append(x_width)
  135. # y_height_all.append(y_height)
  136. #
  137. # sorted_xmin = sorted(box_xmin)
  138. # sorted_ymin = sorted(box_ymin)
  139. # sorted_xmax = sorted(box_xmax)
  140. # sorted_ymax = sorted(box_ymax)
  141. #
  142. # x_width_all_sorted = sorted(x_width_all, reverse=True)
  143. # y_height_all_sorted = sorted(y_height_all, reverse=True)
  144. # len_x = len(x_width_all)
  145. # len_y = len(y_height_all)
  146. # x_width_median = np.median(x_width_all_sorted)
  147. # y_height_median = np.median(y_height_all_sorted)
  148. #
  149. # for i in range(len(sorted_xmin) - 1):
  150. # x_distance = abs(sorted_xmin[i + 1] - sorted_xmin[i])
  151. # y_distance = abs(sorted_ymin[i + 1] - sorted_ymin[i])
  152. # if x_distance > 20:
  153. # ssd_column = ssd_column + 1
  154. # x_distance_all.append(x_distance)
  155. # if x_distance > 2 * x_width_median + 4:
  156. # count_x = count_x + 1
  157. # if y_distance > 10:
  158. # ssd_row = ssd_row + 1
  159. # y_distance_all.append(y_distance)
  160. # if y_distance > 2 * y_height_median + 3:
  161. # count_y = count_y + 1
  162. # if x_width_all_sorted[i] - x_width_median > 40:
  163. # ssd_column = ssd_column - 1
  164. # elif x_width_median - x_width_all_sorted[i] > 40:
  165. # ssd_column = ssd_column - 1
  166. # if y_height_all_sorted[i] - y_height_median > 20:
  167. # ssd_row = ssd_row - 1
  168. # elif y_height_median - y_height_all_sorted[i] > 20:
  169. # ssd_row = ssd_row - 1
  170. #
  171. # if count_x < len(x_distance_all) / 2 + 1:
  172. # ssd_column = ssd_column + count_x
  173. # elif count_y < len(y_distance_all) / 2 + 1:
  174. # ssd_row = ssd_row + count_y
  175. #
  176. # average_height = int(np.mean(y_height_all))
  177. # average_width = int(np.mean(x_width_all))
  178. #
  179. # # average_height = format(np.mean(y_height_all), '.2f')
  180. # # average_width = format(np.mean(x_width_all), '.2f')
  181. # # average_height = int(np.mean(y_distance_all))
  182. # # average_width = int(np.mean(x_distance_all))
  183. # location_ssd = {'xmin': sorted_xmin[0] + left,
  184. # 'ymin': sorted_ymin[0] + top,
  185. # 'xmax': sorted_xmax[-1] + left,
  186. # 'ymax': sorted_ymax[-1] + top}
  187. #
  188. # choice_m_ssd = {'bounding_box': location_ssd,
  189. # "single_height": average_height,
  190. # "single_width": average_width,
  191. # "rows": ssd_row,
  192. # "cols": ssd_column,
  193. # 'class_name': 'choice_m',
  194. # 'all_small_coordinate': all_small_coordinate
  195. # }
  196. # else:
  197. # choice_m_ssd = {}
  198. # return choice_m_ssd
  199. def get_choice_m_row_and_col(left, top, image):
  200. im_resize = 300
  201. da_number_h = 0
  202. da_number_w = 0
  203. w_ = 300
  204. h_ = 300
  205. ''' choice_m resize to 300*300'''
  206. image_src = Image.fromarray(image)
  207. if image_src.mode == 'RGB':
  208. image_src = image_src.convert("L")
  209. w, h = image_src.size
  210. bounder_w = w
  211. bounder_h = h
  212. image_300 = Image.new(image_src.mode, (im_resize, im_resize), (255))
  213. if h > w:
  214. if h > 300:
  215. # w = int(w/(h/300))
  216. h_1 = 300
  217. image_src_resize = image_src.resize((w, 300))
  218. else:
  219. h_1 = h
  220. image_src_resize = image_src
  221. da_number_h = int((im_resize - w) / (w + 4))
  222. if da_number_h > 0:
  223. w_ = w
  224. h_ = h_1
  225. bounder_w = w
  226. for idx in range(da_number_h):
  227. x0 = idx * w + 4
  228. x1 = idx * w + 4 + w
  229. image_300.paste(image_src_resize, [x0, 0, x1, h_1])
  230. else:
  231. image_src_resize = image_src.resize((int(im_resize / h * w), im_resize))
  232. w_, h_ = image_src_resize.size
  233. image_300.paste(image_src_resize, [0, 0, w_, h_])
  234. else:
  235. if w > 300:
  236. # h = int(h/(w/300))
  237. w_1 = 300
  238. image_src_resize = image_src.resize((300, h))
  239. else:
  240. w_1 = w
  241. image_src_resize = image_src
  242. da_number_w = int((im_resize - h) / (h + 4))
  243. if da_number_w > 0:
  244. h_ = h
  245. w_ = w_1
  246. bounder_h = h
  247. for idx in range(da_number_w):
  248. h0 = idx * h + 4
  249. h1 = idx * h + 4 + h
  250. image_300.paste(image_src_resize, [0, h0, w_1, h1])
  251. else:
  252. image_src_resize = image_src.resize((im_resize, int(im_resize / w * h)))
  253. w_, h_ = image_src_resize.size
  254. image_300.paste(image_src_resize, [0, 0, w_, h_])
  255. w_resize, h_resize = image_src_resize.size
  256. category_index = label_map_util.create_category_index_from_labelmap(tf_settings.choice_m_ssd_label,
  257. use_display_name=True)
  258. detections = image_detect(image_300, category_index, 0.5)
  259. if len(detections) > 1:
  260. box_xmin = []
  261. box_ymin = []
  262. box_xmax = []
  263. box_ymax = []
  264. x_distance_all = []
  265. y_distance_all = []
  266. x_width_all = []
  267. y_height_all = []
  268. all_small_coordinate = []
  269. ssd_column = 1
  270. ssd_row = 1
  271. count_x = 0
  272. count_y = 0
  273. for index, box in enumerate(detections):
  274. if box[-1] != 'T' and box[2] <= 300 and box[3] <= 300:
  275. box0 = round(box[0] * (w / w_)) # Map to the original image
  276. box1 = round(box[1] * (h / h_))
  277. box2 = round(box[2] * (w / w_))
  278. box3 = round(box[3] * (h / h_))
  279. box_xmin.append(box0)
  280. box_ymin.append(box1)
  281. box_xmax.append(box2)
  282. box_ymax.append(box3)
  283. if box2 < bounder_w and box3 < bounder_h:
  284. small_coordinate = {'xmin': box0 + left,
  285. 'ymin': box1 + top,
  286. 'xmax': box2 + left,
  287. 'ymax': box3 + top}
  288. all_small_coordinate.append(small_coordinate)
  289. x_width = box2 - box0
  290. y_height = box3 - box1
  291. x_width_all.append(x_width)
  292. y_height_all.append(y_height)
  293. sorted_xmin = sorted(box_xmin)
  294. sorted_ymin = sorted(box_ymin)
  295. sorted_xmax = sorted(box_xmax)
  296. sorted_ymax = sorted(box_ymax)
  297. x_width_all_sorted = sorted(x_width_all, reverse=True)
  298. y_height_all_sorted = sorted(y_height_all, reverse=True)
  299. len_x = len(x_width_all)
  300. len_y = len(y_height_all)
  301. x_width_median = np.median(x_width_all_sorted)
  302. y_height_median = np.median(y_height_all_sorted)
  303. for i in range(len(sorted_xmin) - 1):
  304. if sorted_xmin[i + 1] < bounder_w:
  305. x_distance = abs(sorted_xmin[i + 1] - sorted_xmin[i])
  306. else:
  307. x_distance = 0
  308. if sorted_ymin[i + 1] < bounder_h:
  309. y_distance = abs(sorted_ymin[i + 1] - sorted_ymin[i])
  310. else:
  311. y_distance = 0
  312. if x_distance > 20:
  313. ssd_column = ssd_column + 1
  314. x_distance_all.append(x_distance)
  315. if x_distance > 2 * x_width_median + 4:
  316. count_x = count_x + 1
  317. if y_distance > 10:
  318. ssd_row = ssd_row + 1
  319. y_distance_all.append(y_distance)
  320. if y_distance > 2 * y_height_median + 3:
  321. count_y = count_y + 1
  322. if x_width_all_sorted[i] - x_width_median > 40:
  323. ssd_column = ssd_column - 1
  324. elif x_width_median - x_width_all_sorted[i] > 40:
  325. ssd_column = ssd_column - 1
  326. if y_height_all_sorted[i] - y_height_median > 20:
  327. ssd_row = ssd_row - 1
  328. elif y_height_median - y_height_all_sorted[i] > 20:
  329. ssd_row = ssd_row - 1
  330. if count_x < len(x_distance_all) / 2 + 1:
  331. ssd_column = ssd_column + count_x
  332. elif count_y < len(y_distance_all) / 2 + 1:
  333. ssd_row = ssd_row + count_y
  334. average_height = int(np.mean(y_height_all))
  335. average_width = int(np.mean(x_width_all))
  336. if da_number_w > 1 and da_number_h < 1:
  337. location_ssd = {'xmin': sorted_xmin[0] + left,
  338. 'ymin': sorted_ymin[0] + top,
  339. 'xmax': sorted_xmax[-1] + left,
  340. 'ymax': h + top}
  341. choice_m_ssd = {'bounding_box': location_ssd,
  342. "single_height": average_height,
  343. "single_width": average_width,
  344. "rows": ssd_row,
  345. "cols": ssd_column,
  346. 'class_name': 'choice_m',
  347. 'all_small_coordinate': all_small_coordinate
  348. }
  349. elif da_number_h > 1 and da_number_w < 1:
  350. location_ssd = {'xmin': sorted_xmin[0] + left,
  351. 'ymin': sorted_ymin[0] + top,
  352. 'xmax': w + left,
  353. 'ymax': sorted_ymax[-1] + top}
  354. choice_m_ssd = {'bounding_box': location_ssd,
  355. "single_height": average_height,
  356. "single_width": average_width,
  357. "rows": ssd_row,
  358. "cols": ssd_column,
  359. 'class_name': 'choice_m',
  360. 'all_small_coordinate': all_small_coordinate
  361. }
  362. elif da_number_h > 1 and da_number_w > 1:
  363. location_ssd = {'xmin': sorted_xmin[0] + left,
  364. 'ymin': sorted_ymin[0] + top,
  365. 'xmax': w + left,
  366. 'ymax': h + top}
  367. choice_m_ssd = {'bounding_box': location_ssd,
  368. "single_height": average_height,
  369. "single_width": average_width,
  370. "rows": ssd_row,
  371. "cols": ssd_column,
  372. 'class_name': 'choice_m',
  373. 'all_small_coordinate': all_small_coordinate
  374. }
  375. else:
  376. location_ssd = {'xmin': sorted_xmin[0] + left,
  377. 'ymin': sorted_ymin[0] + top,
  378. 'xmax': sorted_xmax[-1] + left,
  379. 'ymax': sorted_ymax[-1] + top}
  380. choice_m_ssd = {'bounding_box': location_ssd,
  381. "single_height": average_height,
  382. "single_width": average_width,
  383. "rows": ssd_row,
  384. "cols": ssd_column,
  385. 'class_name': 'choice_m',
  386. 'all_small_coordinate': all_small_coordinate
  387. }
  388. else:
  389. choice_m_ssd = {}
  390. return choice_m_ssd