# @Author  : lightXu
# @File    : tf_sess.py
# @Time    : 2018/11/23 14:26
import os

import tensorflow as tf

from segment.sheet_resolve.lib.model.config import cfg
from segment.sheet_resolve.lib.nets.resnet_v1 import resnetv1
from segment.sheet_resolve.tools.tf_settings import model_dict

# To pin the process to specific GPUs, set this before any session is created:
# os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'


class TfSess:
    """Own a private graph and session holding a restored ``resnetv1`` network.

    Attributes:
        graph: The ``tf.Graph`` the network was built in.
        net: The ``resnetv1`` instance (101 layers) created in ``graph``.
        saver: The ``tf.train.Saver`` used to restore the checkpoint.
        sess: The ``tf.Session`` bound to ``graph`` with the weights restored.
    """

    def __init__(self, resolve_type):
        """Build the network and restore its checkpoint.

        Args:
            resolve_type: Key into ``model_dict``. The selected entry must
                provide ``'classes'``, ``'anchor_scales'``, ``'anchor_ratios'``
                and ``'path'`` (the checkpoint prefix).

        Raises:
            KeyError: If ``resolve_type`` is not a key of ``model_dict``.
            IOError: If ``<path>.meta`` does not exist.
        """
        model_info = model_dict[resolve_type]

        self.graph = tf.Graph()
        with self.graph.as_default():
            # NOTE: this mutates the shared, module-level ``cfg`` object.
            cfg.TEST.HAS_RPN = True  # Use RPN for proposals

            # Session configuration.
            tfconfig = tf.ConfigProto(allow_soft_placement=True)
            tfconfig.gpu_options.allow_growth = True
            # Optional hard cap on GPU memory:
            # tfconfig.gpu_options.per_process_gpu_memory_fraction = 0.4

            self.net = resnetv1(num_layers=101)
            self.net.create_architecture(
                "TEST",
                len(model_info['classes']),
                tag='default',
                anchor_scales=model_info['anchor_scales'],
                anchor_ratios=model_info['anchor_ratios'],
            )
            # Created inside the graph context so it covers this graph's
            # variables rather than those of the global default graph.
            self.saver = tf.train.Saver()

        tfmodel = model_info['path']
        if not os.path.isfile(tfmodel + '.meta'):
            raise IOError('{:s} not found.'.format(tfmodel + '.meta'))

        self.sess = tf.Session(config=tfconfig, graph=self.graph)
        try:
            with self.sess.as_default(), self.graph.as_default():
                # Restore the parameters from the checkpoint.
                self.saver.restore(self.sess, tfmodel)
        except Exception:
            # The constructor is about to fail, so the caller will never get
            # a handle to close; release the session here before re-raising.
            self.sess.close()
            raise

        print('\n Loaded network {:s}\n'.format(tfmodel))

    def sess_close(self):
        """Close the underlying TensorFlow session."""
        self.sess.close()

    def __enter__(self):
        """Return ``self`` so the wrapper can be used in a ``with`` block."""
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        """Close the session on leaving the ``with`` block."""
        self.sess_close()
        return False


class SsdSess:
    """Own a private graph and session holding an imported serialized GraphDef.

    Attributes:
        graph: The ``tf.Graph`` the serialized model was imported into.
        sess: The ``tf.Session`` bound to ``graph``.
    """

    def __init__(self, ssd_type):
        """Load the serialized graph and open a session on it.

        Args:
            ssd_type: Key into ``model_dict``. The selected entry must provide
                ``'path'``, the location of the serialized ``GraphDef`` file.

        Raises:
            KeyError: If ``ssd_type`` is not a key of ``model_dict``.
        """
        model_info = model_dict[ssd_type]

        graph = tf.Graph()
        with graph.as_default():
            od_graph_def = tf.GraphDef()
            with tf.gfile.GFile(model_info['path'], 'rb') as fid:
                od_graph_def.ParseFromString(fid.read())
            # name='' imports the nodes without adding a name prefix.
            tf.import_graph_def(od_graph_def, name='')

        self.sess = tf.Session(graph=graph)
        # Same object the original fetched via tf.get_default_graph() while
        # inside graph.as_default().
        self.graph = graph

        print('\n Loaded network {:s}\n'.format(model_info['path']))

    def sess_close(self):
        """Close the underlying TensorFlow session."""
        self.sess.close()

    def __enter__(self):
        """Return ``self`` so the wrapper can be used in a ``with`` block."""
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        """Close the session on leaving the ``with`` block."""
        self.sess_close()
        return False