# -*- coding: utf-8 -*-
import os
import pickle

import jieba
import numpy as np
import tensorflow as tf

from util import data_split_by_label


def separate_items_answers(path):
    """
    A test-paper txt file comes in three layouts:
        1. the answers follow a separator line of '#' characters at the end of the file
        2. the answers are mixed into the paper itself
        3. there are no answers; the paper ends with a separator line of '*' characters
    :param path: path of the txt file
    :return: two lists, items and answers; each element is one line of the file
    """
    items = []
    answers = []
    with open(path, 'r', encoding='utf-8') as f:
        content = f.readlines()

    answers_start = 0
    if '####################' in ''.join(content):
        # Everything before the '#' separator is a question line, everything after it is an answer line.
        # Files whose answers are mixed into the paper (layout 2) are not handled here.
        for i, line in enumerate(content):
            Nline = line.replace(' ', '').replace('#', '')
            if line == '' or line.rstrip() == '':         # blank line
                pass
            elif Nline.rstrip() != '':                    # a real content line, not the separator
                items.append(line)
            elif '####################' in line and Nline.rstrip() == '':
                answers_start = i + 1                     # only keep data before the separator
                break
        for line in content[answers_start:]:
            if line == '' or line.rstrip() == '':         # keep blank lines as empty answers
                answers.append('')
            else:
                answers.append(line)
    elif '*******************' in ''.join(content):       # paper without answers
        for i, line in enumerate(content):
            Nline = line.replace(' ', '').replace('*', '')
            if line == '' or line.rstrip() == '':         # blank line
                pass
            elif Nline.rstrip() != '':                    # a real content line, not the separator
                items.append(line)
            elif '*************************' in line and Nline.rstrip() == '':
                answers_start = i + 1                     # only keep data before the separator
                break
    return items, answers


def load_data(x_data, x_vec_path):
    """
    :param x_data: list of word-segmented strings (tokens joined by spaces)
    :param x_vec_path: path of the vocabulary processor pickled during training
    :return: numpy array of word-id sequences ready for prediction
    """
    with open(x_vec_path, 'rb') as f:
        vocab_processor = pickle.load(f)
    # use transform (not fit_transform) so the vocabulary learned during training is reused
    x_test = np.array(list(vocab_processor.transform(x_data)))
    return x_test


def load_model(checkpoint_path, x_test, model_index):
    """
    Restore a checkpoint (meta graph + weights) and run the prediction op on every line.
    :param checkpoint_path: directory containing model-<model_index>.meta and the weight files
    :param x_test: array of word-id sequences
    :param model_index: global step of the checkpoint to restore
    :return: list with one predicted label per line
    """
    graph = tf.Graph()
    with graph.as_default():
        # allow_soft_placement lets TensorFlow fall back to CPU when an op has no GPU kernel
        session_conf = tf.ConfigProto(allow_soft_placement=True)
        sess = tf.Session(config=session_conf)
        with sess.as_default():
            saver = tf.train.import_meta_graph(
                os.path.join(checkpoint_path, 'model-' + str(model_index) + '.meta'))
            saver.restore(sess, os.path.join(checkpoint_path, 'model-' + str(model_index)))

            input_x = graph.get_operation_by_name("input_x").outputs[0]
            dropout_keep_prob = graph.get_operation_by_name("dropout_keep_prob").outputs[0]
            predictions = graph.get_operation_by_name("output/predictions").outputs[0]

            label = []
            for each_line in x_test:
                each_line_T = each_line.reshape(1, -1)    # batch of one sample
                prediction = sess.run(predictions,
                                      {input_x: each_line_T, dropout_keep_prob: 1.0})
                label.append(prediction.tolist()[0])
        sess.close()
    return label


def label_predict(x_origin, x_vec_path, checkpoint_path, model_index):
    """
    Segment each line with jieba, vectorise it and predict one label per line.
    """
    x_data = [' '.join(jieba.cut(i)) for i in x_origin]
    x_test = load_data(x_data, x_vec_path)
    label = load_model(checkpoint_path, x_test, model_index)
    return label
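

# Illustrative (hypothetical) input layouts that separate_items_answers expects,
# based on the separator handling above; the actual files may differ in detail.
#
# Layout 1 - answers at the end, behind a '#' separator line:
#   1. What time is it? ...
#   A. 8:00   B. 9:00   C. 10:00
#   ############################
#   1-5 ABCBA
#
# Layout 3 - no answers, the file ends with a '*' separator line:
#   1. What time is it? ...
#   A. 8:00   B. 9:00   C. 10:00
#   ****************************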
if __name__ == '__main__':
    txt_file_path = r'test_paper.txt'

    # step 1: split the paper into question lines and answer lines,
    # then predict a question-type label for every line
    x_vec_path = r'./des_content_label_predict/x_vec'
    checkpoint_path = r'./des_content_label_predict/checkpoints'
    x_origin, ans = separate_items_answers(txt_file_path)
    label = label_predict(x_origin, x_vec_path, checkpoint_path, 5550)
    one_file_split = data_split_by_label(x_origin, label)
    # question-type of each section, in the order the sections appear in the paper
    item_type = ['听力', 'Na', '听力', '听力', '听力', '单项填空', '完形填空', '阅读理解',
                 'Na', 'Na', 'Na', '短文改错', '书面表达']
    # type_content_dict = dict(zip(item_type, one_file_split))

    # step 2: inside the 单项填空 (single filling) section, predict which lines
    # are question stems and which are answer options, then split by those labels.
    # NOTE: this assumes data_split_by_label returns one list of lines per section,
    # so single_filling holds the lines of the 单项填空 section.
    single_filling = one_file_split[item_type.index('单项填空')]
    x_vec_path = r'./single_filling_stem_opt_predict/x_vec'
    checkpoint_path = r'./single_filling_stem_opt_predict/checkpoints'
    stem_opt_label = label_predict(single_filling, x_vec_path, checkpoint_path, 2280)
    single_filling_stem_opt_split = data_split_by_label(single_filling, stem_opt_label)