# @Author : lightXu
# @File : tf_sess.py
# @Time : 2018/11/23 14:26
import os

import tensorflow as tf

from segment.sheet_resolve.lib.model.config import cfg
from segment.sheet_resolve.lib.nets.resnet_v1 import resnetv1
from segment.sheet_resolve.tools.tf_settings import model_dict
  9. class TfSess:
  10. def __init__(self, resolve_type):
  11. model_info = model_dict[resolve_type]
  12. self.graph = tf.Graph()
  13. with self.graph.as_default():
  14. cfg.TEST.HAS_RPN = True # Use RPN for proposals
  15. # set config
  16. tfconfig = tf.ConfigProto(allow_soft_placement=True)
  17. tfconfig.gpu_options.allow_growth = True
  18. # tfconfig.gpu_options.per_process_gpu_memory_fraction = 0.4
  19. self.net = resnetv1(num_layers=101)
  20. self.net.create_architecture("TEST", len(model_info['classes']),
  21. tag='default', anchor_scales=model_info['anchor_scales'],
  22. anchor_ratios=model_info['anchor_ratios'])
  23. self.saver = tf.train.Saver()
  24. tfmodel = model_info['path']
  25. if not os.path.isfile(tfmodel + '.meta'):
  26. raise IOError('{:s} not found.'.format(tfmodel + '.meta'))
  27. self.sess = tf.Session(config=tfconfig, graph=self.graph)
  28. # self.sess.as_default()
  29. # try:
  30. # self.saver.restore(self.sess, tfmodel)
  31. # except Exception:
  32. # traceback.print_exc()
  33. with self.sess.as_default():
  34. with self.graph.as_default():
  35. self.saver.restore(self.sess, tfmodel) # 从恢复点恢复参数
  36. print('\n Loaded network {:s}\n'.format(tfmodel))
  37. def sess_close(self):
  38. self.sess.close()
  39. class SsdSess:
  40. def __init__(self, ssd_type):
  41. model_info = model_dict[ssd_type]
  42. graph = tf.Graph()
  43. with graph.as_default():
  44. od_graph_def = tf.GraphDef()
  45. with tf.gfile.GFile(model_info['path'], 'rb') as fid:
  46. serialized_graph = fid.read()
  47. od_graph_def.ParseFromString(serialized_graph)
  48. tf.import_graph_def(od_graph_def, name='')
  49. self.sess = tf.Session(graph=graph)
  50. with graph.as_default():
  51. with self.sess.as_default():
  52. self.graph = tf.get_default_graph()
  53. # with self.sess.as_default():
  54. # self.graph = graph
  55. print('\n Loaded network {:s}\n'.format(model_info['path']))
  56. def sess_close(self):
  57. self.sess.close()