tf_sess.py 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778
  1. # @Author : lightXu
  2. # @File : tf_sess.py
  3. # @Time : 2018/11/23 0023 下午 14:26
  4. import os
  5. import tensorflow as tf
  6. from segment.sheet_resolve.lib.model.config import cfg
  7. from segment.sheet_resolve.lib.nets.resnet_v1 import resnetv1
  8. from segment.sheet_resolve.tools.tf_settings import model_dict
  9. # os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'
  10. class TfSess:
  11. def __init__(self, resolve_type):
  12. model_info = model_dict[resolve_type]
  13. self.graph = tf.Graph()
  14. with self.graph.as_default():
  15. cfg.TEST.HAS_RPN = True # Use RPN for proposals
  16. # set config
  17. tfconfig = tf.ConfigProto(allow_soft_placement=True)
  18. tfconfig.gpu_options.allow_growth = True
  19. # tfconfig.gpu_options.per_process_gpu_memory_fraction = 0.4
  20. self.net = resnetv1(num_layers=101)
  21. self.net.create_architecture("TEST", len(model_info['classes']),
  22. tag='default', anchor_scales=model_info['anchor_scales'],
  23. anchor_ratios=model_info['anchor_ratios'])
  24. self.saver = tf.train.Saver()
  25. tfmodel = model_info['path']
  26. if not os.path.isfile(tfmodel + '.meta'):
  27. raise IOError('{:s} not found.'.format(tfmodel + '.meta'))
  28. self.sess = tf.Session(config=tfconfig, graph=self.graph)
  29. # self.sess.as_default()
  30. # try:
  31. # self.saver.restore(self.sess, tfmodel)
  32. # except Exception:
  33. # traceback.print_exc()
  34. with self.sess.as_default():
  35. with self.graph.as_default():
  36. self.saver.restore(self.sess, tfmodel) # 从恢复点恢复参数
  37. print('\n Loaded network {:s}\n'.format(tfmodel))
  38. def sess_close(self):
  39. self.sess.close()
  40. class SsdSess:
  41. def __init__(self, ssd_type):
  42. model_info = model_dict[ssd_type]
  43. graph = tf.Graph()
  44. with graph.as_default():
  45. od_graph_def = tf.GraphDef()
  46. with tf.gfile.GFile(model_info['path'], 'rb') as fid:
  47. serialized_graph = fid.read()
  48. od_graph_def.ParseFromString(serialized_graph)
  49. tf.import_graph_def(od_graph_def, name='')
  50. self.sess = tf.Session(graph=graph)
  51. with graph.as_default():
  52. with self.sess.as_default():
  53. self.graph = tf.get_default_graph()
  54. # with self.sess.as_default():
  55. # self.graph = graph
  56. print('\n Loaded network {:s}\n'.format(model_info['path']))
  57. def sess_close(self):
  58. self.sess.close()