# -*- coding: utf-8 -*-
import tensorflow as tf
import numpy as np
import pickle
import os
import jieba
from util import data_split_by_label


def separate_items_answers(path):
    """
    A test paper file comes in three layouts:
      1. the answers follow a separator line of '#': ####################
      2. the answers are interleaved with the questions
      3. there is no answer section; a line of '*' ends the paper: ************************
    Layout 2 is not handled: files with no '#' separator and no '*'
    terminator are left untouched.
    :param path: path of the txt file
    :return: two lists, items and answers; each element is one line of the file
    """
    items = []
    answers = []
    with open(path, 'r', encoding='utf-8') as f:
        content = f.readlines()
        answers_start = 0
        if '####################' in ''.join(content):
            # Layout 1: question lines come before the '#' separator,
            # answer lines after it.
            for i, line in enumerate(content):
                # Nline is empty for a separator made only of spaces and '#'
                Nline = line.replace(' ', '').replace('#', '')
                if line == '' or line.rstrip() == '':  # blank line
                    pass
                elif Nline.rstrip() != '':  # ordinary content line
                    items.append(line)
                elif '####################' in line and Nline.rstrip() == '':  # separator
                    answers_start = i + 1
                    break
            for line in content[answers_start:]:
                if line == '' or line.rstrip() == '':  # keep blanks as placeholders
                    answers.append('')
                else:
                    answers.append(line)
        elif '*******************' in ''.join(content):
            # Layout 3: no answer section; collect question lines until the
            # '*' terminator. The separator here is made of '*', so '*'
            # (not '#') has to be stripped when testing for it.
            for i, line in enumerate(content):
                Nline = line.replace(' ', '').replace('*', '')
                if line == '' or line.rstrip() == '':  # blank line
                    pass
                elif Nline.rstrip() != '':  # ordinary content line
                    items.append(line)
                elif '*************************' in line and Nline.rstrip() == '':  # terminator
                    answers_start = i + 1
                    break
    return items, answers
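
# A minimal usage sketch for separate_items_answers; the file contents below
# are hypothetical, illustrating layout 1 with a '#' separator:
#
#   test_paper.txt:
#       1. What is the capital of France?
#       ####################
#       1. B
#
#   items, answers = separate_items_answers('test_paper.txt')
#   # items   -> ['1. What is the capital of France?\n']
#   # answers -> ['1. B\n']
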
def load_data(x_data, x_vec_path):
    """
    :param x_data: list of whitespace-tokenised strings to predict on
    :param x_vec_path: path of the vocabulary processor pickled during training
    :return: the data to predict, transformed into a word-id array
    """
    with open(x_vec_path, 'rb') as f:
        vocab_processor = pickle.load(f)
    # Use transform, not fit_transform: refitting at prediction time would
    # rebuild the vocabulary and invalidate the word ids learned in training.
    x_test = np.array(list(vocab_processor.transform(x_data)))
    return x_test
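
# The pickle at x_vec_path is assumed to be a tf.contrib.learn
# VocabularyProcessor fitted and saved on the training side, along these
# lines (a hypothetical training-side sketch, not part of this script):
#
#   from tensorflow.contrib import learn
#   vocab_processor = learn.preprocessing.VocabularyProcessor(max_document_length=100)
#   x_train = np.array(list(vocab_processor.fit_transform(x_text)))
#   with open('./des_content_label_predict/x_vec', 'wb') as f:
#       pickle.dump(vocab_processor, f)
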
def load_model(checkpoint_path, x_test, model_index):
    """
    Restore a trained model from its checkpoint files (the .meta file holds
    the graph structure, the data/index files hold the parameters) and
    predict a label for every row of x_test.
    :param checkpoint_path: directory containing the checkpoint files
    :param x_test: array of word-id vectors, one row per input line
    :param model_index: global step of the checkpoint to restore
    :return: list of predicted labels, one per row of x_test
    """
    graph = tf.Graph()
    with graph.as_default():
        # allow_soft_placement lets TensorFlow fall back to a supported
        # device when the one recorded in the checkpointed graph is
        # unavailable. It is passed directly rather than through tf.flags:
        # defining the same flag again on a second call would raise
        # DuplicateFlagError.
        session_conf = tf.ConfigProto(allow_soft_placement=True)
        sess = tf.Session(config=session_conf)
        with sess.as_default():
            saver = tf.train.import_meta_graph(
                os.path.join(checkpoint_path, 'model-' + str(model_index) + '.meta'))
            saver.restore(sess, os.path.join(checkpoint_path, 'model-' + str(model_index)))
            input_x = graph.get_operation_by_name("input_x").outputs[0]
            dropout_keep_prob = graph.get_operation_by_name("dropout_keep_prob").outputs[0]
            predictions = graph.get_operation_by_name("output/predictions").outputs[0]
            label = []
            for each_line in x_test:
                # reshape the 1-D word-id vector into a batch of size 1
                each_line_T = each_line.reshape(-1, len(each_line))
                prediction = sess.run(predictions,
                                      {input_x: each_line_T, dropout_keep_prob: 1.0})
                label.append(prediction.tolist()[0])
        sess.close()
    return label
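
# Note: the loop above runs one session call per input line. Because input_x
# is a batched placeholder, the same labels could be computed in a single
# call (a sketch, assuming x_test fits in memory as one batch):
#
#   batch_predictions = sess.run(predictions,
#                                {input_x: x_test, dropout_keep_prob: 1.0})
#   label = batch_predictions.tolist()
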
def label_predict(x_origin, x_vec_path, checkpoint_path, model_index):
    # Tokenise each line with jieba and join the tokens with spaces, the
    # format the vocabulary processor expects.
    x_data = [' '.join(jieba.cut(i)) for i in x_origin]
    x_test = load_data(x_data, x_vec_path)
    label = load_model(checkpoint_path, x_test, model_index)
    return label
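
# Usage sketch (paths and checkpoint step as in the main block below; the
# input line and the returned class id are hypothetical):
#
#   label = label_predict(['他喜欢跑步。'], './des_content_label_predict/x_vec',
#                         './des_content_label_predict/checkpoints', 5550)
#   # -> one predicted class id per input line, e.g. [5]
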
if __name__ == '__main__':
    txt_file_path = r'test_paper.txt'
    x_vec_path = r'./des_content_label_predict/x_vec'
    checkpoint_path = r'./des_content_label_predict/checkpoints'
    x_origin, ans = separate_items_answers(txt_file_path)
    label = label_predict(x_origin, x_vec_path, checkpoint_path, 5550)
    one_file_split = data_split_by_label(x_origin, label)
    # Question-type categories, in label order: '听力' = listening,
    # '单项填空' = single filling, '完形填空' = cloze, '阅读理解' = reading
    # comprehension, '短文改错' = error correction, '书面表达' = writing.
    item_type = ['听力', 'Na', '听力', '听力', '听力', '单项填空', '完形填空',
                 '阅读理解', 'Na', 'Na', 'Na', '短文改错', '书面表达']

    # single filling: classify each line of that section as stem or option.
    # Predict on the single-filling lines themselves so the labels line up
    # with the lines being split below.
    single_filling = one_file_split[item_type.index('单项填空')]
    x_vec_path = r'./single_filling_stem_opt_predict/x_vec'
    checkpoint_path = r'./single_filling_stem_opt_predict/checkpoints'
    stem_opt_label = label_predict(single_filling[1], x_vec_path, checkpoint_path, 2280)
    single_filling_stem_opt_split = data_split_by_label(single_filling[1], stem_opt_label)
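
# data_split_by_label (imported from util) is assumed to bucket the input
# lines by their predicted label, one list of lines per label value, e.g.:
#
#   data_split_by_label(['a', 'b', 'c'], [0, 0, 1])  # -> [['a', 'b'], ['c']]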