cnn_predict.py

# -*- coding: utf-8 -*-
import tensorflow as tf
import numpy as np
import pickle
import os
import jieba
from util import data_split_by_label

def separate_items_answers(path):
    """
    Test papers come in three layouts:
        1. the answers follow a final separator line made of '#' characters
        2. the answers are mixed in with the questions
        3. there is no answer section, marked by a separator line of '*' characters
    :param path: path to the txt file
    :return: list of items and list of answers; one line of the file is one element
    """
    items = []
    answers = []
    with open(path, 'r', encoding='utf-8') as f:
        # print('full content of the txt: %s' % f.read())
        content = f.readlines()
        # print(content)
    answers_start = 0
    # If the txt has no '####' separator, it either has no answers or the answers
    # are mixed into the paper, so that kind of txt is not handled here.
    if '####################' in ''.join(content):
        for i, line in enumerate(content):
            Nline = line.replace(' ', '').replace('#', '')
            if line == '' or line.rstrip() == '':  # blank line ('\n')
                pass
            elif Nline.rstrip() != '':  # a content line, not a '####' separator
                items.append(line)
            elif '####################' in line and Nline.rstrip() == '':
                # only collect items before the '####' separator
                answers_start = i + 1
                break
        for line in content[answers_start:]:
            if line == '' or line.rstrip() == '':  # blank line ('\n')
                answers.append('')
            else:
                answers.append(line)
    elif '*******************' in ''.join(content):  # no answer section
        for i, line in enumerate(content):
            Nline = line.replace(' ', '').replace('*', '')
            if line == '' or line.rstrip() == '':  # blank line ('\n')
                pass
            elif Nline.rstrip() != '':  # a content line, not a '****' separator
                items.append(line)
            elif '*************************' in line and Nline.rstrip() == '':
                # only collect items before the '****' separator
                answers_start = i + 1
                break
    return items, answers
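
# Example usage (a minimal sketch; 'sample_paper.txt' is a hypothetical file laid
# out like the papers described in the docstring above):
#   items, answers = separate_items_answers('sample_paper.txt')
#   print(len(items), 'item lines,', len(answers), 'answer lines')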


def load_data(x_data, x_vec_path):
    """
    :param x_data: list of tokenized (space-joined) strings to predict on
    :param x_vec_path: path to the vocabulary processor saved as a pkl during training
    :return: the data to predict, transformed into an index array
    """
    with open(x_vec_path, 'rb') as f:
        vocab_processor = pickle.load(f)
    # use transform (not fit_transform) so the vocabulary learned during training
    # is reused rather than rebuilt from the prediction data
    x_test = np.array(list(vocab_processor.transform(x_data)))
    return x_test
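
# The x_vec pkl is assumed to have been produced during training roughly like the
# sketch below, using tf.contrib.learn's VocabularyProcessor (which is what the
# transform() call above expects); max_document_length and the path are assumptions:
#   from tensorflow.contrib import learn
#   vocab_processor = learn.preprocessing.VocabularyProcessor(max_document_length=100)
#   x_train = np.array(list(vocab_processor.fit_transform(x_train_text)))
#   with open('./des_content_label_predict/x_vec', 'wb') as f:
#       pickle.dump(vocab_processor, f)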


def load_model(checkpoint_path, x_test, model_index):
    """
    :param checkpoint_path: directory with the checkpoint files (meta, index, data)
        that describe the trained model's structure and parameters
    :param x_test: array of word-index sequences to predict on
    :param model_index: global step of the checkpoint to restore, e.g. 5550
    :return: list of predicted labels, one per input line
    """
    graph = tf.Graph()
    with graph.as_default():
        # soft placement lets TF fall back to the CPU when an op has no GPU kernel
        session_conf = tf.ConfigProto(allow_soft_placement=True)
        sess = tf.Session(config=session_conf)
        with sess.as_default():
            saver = tf.train.import_meta_graph(
                os.path.join(checkpoint_path, 'model-' + str(model_index) + '.meta'))
            saver.restore(sess, os.path.join(checkpoint_path, 'model-' + str(model_index)))
            input_x = graph.get_operation_by_name("input_x").outputs[0]
            dropout_keep_prob = graph.get_operation_by_name("dropout_keep_prob").outputs[0]
            predictions = graph.get_operation_by_name("output/predictions").outputs[0]
            label = []
            for each_line in x_test:
                # predict one example at a time; reshape to (1, sequence_length)
                each_line_T = each_line.reshape(-1, len(each_line))
                prediction = sess.run(predictions,
                                      {input_x: each_line_T, dropout_keep_prob: 1.0})
                label.append(prediction.tolist()[0])
            sess.close()
    return label
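
# The restore above expects the standard TF 1.x checkpoint files inside
# checkpoint_path, e.g. for model_index 5550:
#   model-5550.meta                   # graph structure
#   model-5550.index                  # variable name/shape index
#   model-5550.data-00000-of-00001    # variable values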


def label_predict(x_origin, x_vec_path, checkpoint_path, model_index):
    # tokenize each line with jieba and join the tokens with spaces, matching the
    # preprocessing the vocabulary processor saw during training
    x_data = [' '.join(jieba.cut(i)) for i in x_origin]
    x_test = load_data(x_data, x_vec_path)
    label = load_model(checkpoint_path, x_test, model_index)
    return label
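
# Example (a minimal sketch; the paths and checkpoint step are the ones used in
# the __main__ block below and are assumed to exist on disk):
#   items, _ = separate_items_answers('test_paper.txt')
#   labels = label_predict(items, './des_content_label_predict/x_vec',
#                          './des_content_label_predict/checkpoints', 5550)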


if __name__ == '__main__':
    txt_file_path = r'test_paper.txt'
    x_vec_path = r'./des_content_label_predict/x_vec'
    checkpoint_path = r'./des_content_label_predict/checkpoints'
    x_origin, ans = separate_items_answers(txt_file_path)
    label = label_predict(x_origin, x_vec_path, checkpoint_path, 5550)
    one_file_split = data_split_by_label(x_origin, label)
    # question-type classes, in label order: 听力 = listening, 单项填空 = single-choice
    # fill-in, 完形填空 = cloze, 阅读理解 = reading comprehension, 短文改错 = passage
    # error correction, 书面表达 = written expression, 'Na' = unused
    item_type = ['听力', 'Na', '听力', '听力', '听力', '单项填空', '完形填空', '阅读理解',
                 'Na', 'Na', 'Na', '短文改错', '书面表达']
    # type_content_dict = dict(zip(item_type, one_file_split))
    # single-choice fill-in section: split each question into stem and options
    single_filling = one_file_split[item_type.index('单项填空')]
    x_vec_path = r'./single_filling_stem_opt_predict/x_vec'
    checkpoint_path = r'./single_filling_stem_opt_predict/checkpoints'
    stem_opt_label = label_predict(single_filling[1], x_vec_path, checkpoint_path, 2280)
    single_filling_stem_opt_split = data_split_by_label(single_filling[1], stem_opt_label)