snippets.py 1.3 KB

123456789101112131415161718192021222324252627282930313233
  1. # --------------------------------------------------------
  2. # Tensorflow Faster R-CNN
  3. # Licensed under The MIT License [see LICENSE for details]
  4. # Written by Xinlei Chen
  5. # --------------------------------------------------------
  6. from __future__ import absolute_import
  7. from __future__ import division
  8. from __future__ import print_function
  9. import tensorflow as tf
  10. import numpy as np
  11. from segment.sheet_resolve.lib.layer_utils.generate_anchors import generate_anchors
  12. def generate_anchors_pre(height, width, feat_stride=16, anchor_scales=(8, 16, 32), anchor_ratios=(0.5, 1, 2)):
  13. shift_x = tf.range(width) * feat_stride # width
  14. shift_y = tf.range(height) * feat_stride # height
  15. shift_x, shift_y = tf.meshgrid(shift_x, shift_y)
  16. sx = tf.reshape(shift_x, shape=(-1,))
  17. sy = tf.reshape(shift_y, shape=(-1,))
  18. shifts = tf.transpose(tf.stack([sx, sy, sx, sy]))
  19. K = tf.multiply(width, height)
  20. shifts = tf.transpose(tf.reshape(shifts, shape=[1, K, 4]), perm=(1, 0, 2))
  21. anchors = generate_anchors(ratios=np.array(anchor_ratios), scales=np.array(anchor_scales))
  22. A = anchors.shape[0]
  23. anchor_constant = tf.constant(anchors.reshape((1, A, 4)), dtype=tf.int32)
  24. length = K * A
  25. anchors_tf = tf.reshape(tf.add(anchor_constant, shifts), shape=(length, 4))
  26. return tf.cast(anchors_tf, dtype=tf.float32), length