anchor_target_layer.py 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163
  1. # --------------------------------------------------------
  2. # Faster R-CNN
  3. # Copyright (c) 2015 Microsoft
  4. # Licensed under The MIT License [see LICENSE for details]
  5. # Written by Ross Girshick and Xinlei Chen
  6. # --------------------------------------------------------
  7. from __future__ import absolute_import
  8. from __future__ import division
  9. from __future__ import print_function
  10. import os
  11. from segment.sheet_resolve.lib.model.config import cfg
  12. import numpy as np
  13. import numpy.random as npr
  14. from segment.sheet_resolve.lib.utils.py_bbox import bbox_overlaps
  15. from segment.sheet_resolve.lib.model.bbox_transform import bbox_transform
  16. def anchor_target_layer(rpn_cls_score, gt_boxes, im_info, _feat_stride, all_anchors, num_anchors):
  17. """Same as the anchor target layer in original Fast/er RCNN """
  18. A = num_anchors
  19. total_anchors = all_anchors.shape[0]
  20. K = total_anchors / num_anchors
  21. # allow boxes to sit over the edge by a small amount
  22. _allowed_border = 0
  23. # map of shape (..., H, W)
  24. height, width = rpn_cls_score.shape[1:3]
  25. # only keep anchors inside the image
  26. inds_inside = np.where(
  27. (all_anchors[:, 0] >= -_allowed_border) &
  28. (all_anchors[:, 1] >= -_allowed_border) &
  29. (all_anchors[:, 2] < im_info[1] + _allowed_border) & # width
  30. (all_anchors[:, 3] < im_info[0] + _allowed_border) # height
  31. )[0]
  32. # keep only inside anchors
  33. anchors = all_anchors[inds_inside, :]
  34. # label: 1 is positive, 0 is negative, -1 is dont care
  35. labels = np.empty((len(inds_inside),), dtype=np.float32)
  36. labels.fill(-1)
  37. # overlaps between the anchors and the gt boxes
  38. # overlaps (ex, gt)
  39. overlaps = bbox_overlaps(
  40. np.ascontiguousarray(anchors, dtype=np.float),
  41. np.ascontiguousarray(gt_boxes, dtype=np.float))
  42. argmax_overlaps = overlaps.argmax(axis=1)
  43. max_overlaps = overlaps[np.arange(len(inds_inside)), argmax_overlaps]
  44. gt_argmax_overlaps = overlaps.argmax(axis=0)
  45. gt_max_overlaps = overlaps[gt_argmax_overlaps,
  46. np.arange(overlaps.shape[1])]
  47. gt_argmax_overlaps = np.where(overlaps == gt_max_overlaps)[0]
  48. if not cfg.TRAIN.RPN_CLOBBER_POSITIVES:
  49. # assign bg labels first so that positive labels can clobber them
  50. # first set the negatives
  51. labels[max_overlaps < cfg.TRAIN.RPN_NEGATIVE_OVERLAP] = 0
  52. # fg label: for each gt, anchor with highest overlap
  53. labels[gt_argmax_overlaps] = 1
  54. # fg label: above threshold IOU
  55. labels[max_overlaps >= cfg.TRAIN.RPN_POSITIVE_OVERLAP] = 1
  56. if cfg.TRAIN.RPN_CLOBBER_POSITIVES:
  57. # assign bg labels last so that negative labels can clobber positives
  58. labels[max_overlaps < cfg.TRAIN.RPN_NEGATIVE_OVERLAP] = 0
  59. # subsample positive labels if we have too many
  60. num_fg = int(cfg.TRAIN.RPN_FG_FRACTION * cfg.TRAIN.RPN_BATCHSIZE)
  61. fg_inds = np.where(labels == 1)[0]
  62. if len(fg_inds) > num_fg:
  63. disable_inds = npr.choice(
  64. fg_inds, size=(len(fg_inds) - num_fg), replace=False)
  65. labels[disable_inds] = -1
  66. # subsample negative labels if we have too many
  67. num_bg = cfg.TRAIN.RPN_BATCHSIZE - np.sum(labels == 1)
  68. bg_inds = np.where(labels == 0)[0]
  69. if len(bg_inds) > num_bg:
  70. disable_inds = npr.choice(
  71. bg_inds, size=(len(bg_inds) - num_bg), replace=False)
  72. labels[disable_inds] = -1
  73. bbox_targets = np.zeros((len(inds_inside), 4), dtype=np.float32)
  74. bbox_targets = _compute_targets(anchors, gt_boxes[argmax_overlaps, :])
  75. bbox_inside_weights = np.zeros((len(inds_inside), 4), dtype=np.float32)
  76. # only the positive ones have regression targets
  77. bbox_inside_weights[labels == 1, :] = np.array(cfg.TRAIN.RPN_BBOX_INSIDE_WEIGHTS)
  78. bbox_outside_weights = np.zeros((len(inds_inside), 4), dtype=np.float32)
  79. if cfg.TRAIN.RPN_POSITIVE_WEIGHT < 0:
  80. # uniform weighting of examples (given non-uniform sampling)
  81. num_examples = np.sum(labels >= 0)
  82. positive_weights = np.ones((1, 4)) * 1.0 / num_examples
  83. negative_weights = np.ones((1, 4)) * 1.0 / num_examples
  84. else:
  85. assert ((cfg.TRAIN.RPN_POSITIVE_WEIGHT > 0) &
  86. (cfg.TRAIN.RPN_POSITIVE_WEIGHT < 1))
  87. positive_weights = (cfg.TRAIN.RPN_POSITIVE_WEIGHT /
  88. np.sum(labels == 1))
  89. negative_weights = ((1.0 - cfg.TRAIN.RPN_POSITIVE_WEIGHT) /
  90. np.sum(labels == 0))
  91. bbox_outside_weights[labels == 1, :] = positive_weights
  92. bbox_outside_weights[labels == 0, :] = negative_weights
  93. # map up to original set of anchors
  94. labels = _unmap(labels, total_anchors, inds_inside, fill=-1)
  95. bbox_targets = _unmap(bbox_targets, total_anchors, inds_inside, fill=0)
  96. bbox_inside_weights = _unmap(bbox_inside_weights, total_anchors, inds_inside, fill=0)
  97. bbox_outside_weights = _unmap(bbox_outside_weights, total_anchors, inds_inside, fill=0)
  98. # labels
  99. labels = labels.reshape((1, height, width, A)).transpose(0, 3, 1, 2)
  100. labels = labels.reshape((1, 1, A * height, width))
  101. rpn_labels = labels
  102. # bbox_targets
  103. bbox_targets = bbox_targets \
  104. .reshape((1, height, width, A * 4))
  105. rpn_bbox_targets = bbox_targets
  106. # bbox_inside_weights
  107. bbox_inside_weights = bbox_inside_weights \
  108. .reshape((1, height, width, A * 4))
  109. rpn_bbox_inside_weights = bbox_inside_weights
  110. # bbox_outside_weights
  111. bbox_outside_weights = bbox_outside_weights \
  112. .reshape((1, height, width, A * 4))
  113. rpn_bbox_outside_weights = bbox_outside_weights
  114. return rpn_labels, rpn_bbox_targets, rpn_bbox_inside_weights, rpn_bbox_outside_weights
  115. def _unmap(data, count, inds, fill=0):
  116. """ Unmap a subset of item (data) back to the original set of items (of
  117. size count) """
  118. if len(data.shape) == 1:
  119. ret = np.empty((count,), dtype=np.float32)
  120. ret.fill(fill)
  121. ret[inds] = data
  122. else:
  123. ret = np.empty((count,) + data.shape[1:], dtype=np.float32)
  124. ret.fill(fill)
  125. ret[inds, :] = data
  126. return ret
  127. def _compute_targets(ex_rois, gt_rois):
  128. """Compute bounding-box regression targets for an image."""
  129. assert ex_rois.shape[0] == gt_rois.shape[0]
  130. assert ex_rois.shape[1] == 4
  131. assert gt_rois.shape[1] == 5
  132. return bbox_transform(ex_rois, gt_rois[:, :4]).astype(np.float32, copy=False)