12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576 |
- # @Author : lightXu
- # @File : tf_sess.py
- # @Time : 2018/11/23 0023 下午 14:26
- import os
- import tensorflow as tf
- from segment.sheet_resolve.lib.model.config import cfg
- from segment.sheet_resolve.lib.nets.resnet_v1 import resnetv1
- from segment.sheet_resolve.tools.tf_settings import model_dict
class TfSess:
    """Build a ResNet-101 detection network in a private graph and restore its weights.

    The model is selected by ``resolve_type`` from ``model_dict``. The entry's
    ``'classes'``, ``'anchor_scales'``, ``'anchor_ratios'`` and ``'path'`` keys
    are read.

    Attributes:
        graph: the ``tf.Graph`` that holds the network.
        net: the ``resnetv1(num_layers=101)`` network, built in ``"TEST"`` mode.
        saver: the ``tf.train.Saver`` used to restore the checkpoint.
        sess: a ``tf.Session`` bound to ``graph`` with the restored weights.
    """

    def __init__(self, resolve_type):
        """Construct the network and restore it from its checkpoint.

        Args:
            resolve_type: key into ``model_dict`` selecting the model config.

        Raises:
            KeyError: if ``resolve_type`` (or a required key) is missing from
                ``model_dict``.
            IOError: if the checkpoint's ``.meta`` file does not exist.
            Any error raised by ``saver.restore``. The session is closed
                before it propagates.
        """
        model_info = model_dict[resolve_type]
        tfmodel = model_info['path']
        # Fail fast: check the checkpoint before paying for graph construction.
        if not os.path.isfile(tfmodel + '.meta'):
            raise IOError('{:s} not found.'.format(tfmodel + '.meta'))

        # NOTE: this mutates the shared, imported ``cfg`` object, not a local copy.
        cfg.TEST.HAS_RPN = True  # Use RPN for proposals

        self.graph = tf.Graph()
        with self.graph.as_default():
            tfconfig = tf.ConfigProto(allow_soft_placement=True)
            # Let GPU memory grow on demand instead of reserving it all up front.
            tfconfig.gpu_options.allow_growth = True
            self.net = resnetv1(num_layers=101)
            self.net.create_architecture(
                "TEST",
                len(model_info['classes']),
                tag='default',
                anchor_scales=model_info['anchor_scales'],
                anchor_ratios=model_info['anchor_ratios'],
            )
            # Created inside the graph so the saver covers the network's variables.
            self.saver = tf.train.Saver()

        self.sess = tf.Session(config=tfconfig, graph=self.graph)
        try:
            with self.sess.as_default(), self.graph.as_default():
                # Restore parameters from the checkpoint.
                self.saver.restore(self.sess, tfmodel)
        except BaseException:
            # __init__ is failing, so no caller will ever hold this object to
            # call sess_close(). Release the session here, then re-raise.
            self.sess.close()
            raise
        print('\n Loaded network {:s}\n'.format(tfmodel))

    def sess_close(self):
        """Close the underlying TensorFlow session."""
        self.sess.close()
class SsdSess:
    """Load a serialized TensorFlow ``GraphDef`` model and open a session on it.

    ``ssd_type`` selects an entry of ``model_dict``. Its ``'path'`` names the
    serialized graph file to import.

    Attributes:
        graph: the ``tf.Graph`` the serialized graph was imported into.
        sess: a ``tf.Session`` bound to ``graph``.
    """

    def __init__(self, ssd_type):
        """Import the serialized graph for ``ssd_type`` and create its session.

        Args:
            ssd_type: key into ``model_dict`` selecting the model config.
        """
        model_info = model_dict[ssd_type]
        detection_graph = tf.Graph()
        with detection_graph.as_default():
            graph_def = tf.GraphDef()
            with tf.gfile.GFile(model_info['path'], 'rb') as model_file:
                graph_def.ParseFromString(model_file.read())
            # name='' imports the ops without adding a name-scope prefix.
            tf.import_graph_def(graph_def, name='')

        self.sess = tf.Session(graph=detection_graph)
        # Inside detection_graph.as_default() the default graph *is*
        # detection_graph, so assign it directly rather than looking it up.
        self.graph = detection_graph
        print('\n Loaded network {:s}\n'.format(model_info['path']))

    def sess_close(self):
        """Close the underlying TensorFlow session."""
        self.sess.close()
|