proposal_layer.py 1.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344
  1. # --------------------------------------------------------
  2. # Faster R-CNN
  3. # Licensed under The MIT License [see LICENSE for details]
  4. # Written by Ross Girshick and Xinlei Chen
  5. # --------------------------------------------------------
  6. from __future__ import absolute_import
  7. from __future__ import division
  8. from __future__ import print_function
  9. import tensorflow as tf
  10. import numpy as np
  11. from segment.sheet_resolve.lib.model.config import cfg
  12. from segment.sheet_resolve.lib.model.bbox_transform import bbox_transform_inv, clip_boxes, bbox_transform_inv_tf, clip_boxes_tf
  13. from segment.sheet_resolve.lib.model.nms_wrapper import nms
  14. def proposal_layer(rpn_cls_prob, rpn_bbox_pred, im_info, cfg_key, _feat_stride, anchors, num_anchors):
  15. if type(cfg_key) == bytes:
  16. cfg_key = cfg_key.decode('utf-8')
  17. pre_nms_topN = cfg[cfg_key].RPN_PRE_NMS_TOP_N
  18. post_nms_topN = cfg[cfg_key].RPN_POST_NMS_TOP_N
  19. nms_thresh = cfg[cfg_key].RPN_NMS_THRESH
  20. # Get the scores and bounding boxes
  21. scores = rpn_cls_prob[:, :, :, num_anchors:]
  22. scores = tf.reshape(scores, shape=(-1,))
  23. rpn_bbox_pred = tf.reshape(rpn_bbox_pred, shape=(-1, 4))
  24. proposals = bbox_transform_inv_tf(anchors, rpn_bbox_pred)
  25. proposals = clip_boxes_tf(proposals, im_info[:2])
  26. # Non-maximal suppression
  27. indices = tf.image.non_max_suppression(proposals, scores, max_output_size=post_nms_topN, iou_threshold=nms_thresh)
  28. boxes = tf.gather(proposals, indices)
  29. boxes = tf.to_float(boxes)
  30. scores = tf.gather(scores, indices)
  31. scores = tf.reshape(scores, shape=(-1, 1))
  32. # Only support single image as input
  33. batch_inds = tf.zeros((tf.shape(indices)[0], 1), dtype=tf.float32)
  34. blob = tf.concat([batch_inds, boxes], 1)
  35. return blob, scores