# decode.py
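"""
Decode utilities: turn start/end (and content) logits from span-based
extraction models into entity spans, including sentence-level topic
segmentation for splitting exam papers into individual questions.
"""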

import re
from collections import defaultdict

import numpy as np


def sigmoid(x):
    return 1 / (1 + np.exp(-x))

def topic_ner_decode_1(start_logits, end_logits, raw_text_list, id2label):
    """
    For the first batch of 400 test samples; since the label positions differ,
    this variant decodes them slightly differently.
    """
    predict_entities = defaultdict(list)
    if len(id2label) == 1:
        # raw_text_list is a list of sentences
        start_pred = np.where(start_logits > 0.5, 1, 0)
        end_pred = np.where(end_logits > 0.5, 1, 0)
        # Local correction: for adjacent starts keep the earlier one;
        # for adjacent ends keep the later one
        start_pred = [0 if i and n > 0 and start_pred[n - 1] else i for n, i in enumerate(start_pred)]
        end_pred = [0 if i and n + 1 < len(end_pred) and end_pred[n + 1] else i for n, i in enumerate(end_pred)]
        start_idx = [i for i, j in enumerate(start_pred) if j]
        end_idx = [i for i, j in enumerate(end_pred) if j]
        print(start_pred, len(start_pred))
        print("start_pred indices:", start_idx)
        print(end_pred)
        print("end_pred indices:", end_idx)
        # Two TOPIC spans can never overlap.
        # When start (or end) recall is low, the other signal can serve as a reference.
        last_st = 0
        for nn, pos in enumerate(start_idx):
            if nn > 0:
                # Taking end_pred into account lowers precision:
                # if pos - start_idx[nn - 1] > 15:
                #     may_split = [i for i in end_idx if i < pos and i > start_idx[nn - 1]]
                #     for may_pos in may_split:
                #         if may_pos - last_st >= 2:  # a question spans at least 3 sentences
                #             tmp_ent = raw_text_list[last_st: may_pos + 1]
                #             predict_entities[id2label[0]].append((tmp_ent, last_st))
                #             last_st = may_pos + 1
                # else:
                tmp_ent = raw_text_list[start_idx[nn - 1]: pos]
                predict_entities[id2label[0]].append((tmp_ent, start_idx[nn - 1]))
            elif nn == len(start_idx) - 1:  # only one question
                tmp_ent = raw_text_list
                predict_entities[id2label[0]].append((tmp_ent, 0))
    return predict_entities
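
# A minimal call sketch (hypothetical values; "TOPIC" is an assumed label name).
# The logits here are per-sentence probabilities, compared directly against 0.5:
#   topic_ner_decode_1(np.array([0.9, 0.1, 0.8]), np.array([0.1, 0.9, 0.2]),
#                      ["1. Question one", "Its answer", "2. Question two"],
#                      {0: "TOPIC"})
# yields {"TOPIC": [(["1. Question one", "Its answer"], 0)]}. Note that this
# variant does not emit the segment after the last start; topic_ner_decode
# below handles that case.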

def content_bc_decode(con_logits):
    """
    Apply a threshold to the content logits to decide whether each line
    is exam question content.
    """
    con_pred = np.where(con_logits > 0.5, 1, 0)
    return con_pred
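
# E.g. content_bc_decode(np.array([0.9, 0.2, 0.7])) -> array([1, 0, 1]),
# one 0/1 flag per line (the values here are illustrative only).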

def topic_ner_decode(start_logits, end_logits, con_logits, raw_text_list, id2label):
    """
    Ideally the number of recalled starts equals the number of recalled ends,
    but in practice they differ.
    """
    predict_entities = defaultdict(list)
    topic_item_pred = []
    if len(id2label) == 1:
        # raw_text_list is a list of sentences
        start_pred = np.where(start_logits > 0.5, 1, 0)  # a higher threshold raises precision but lowers recall
        end_pred = np.where(end_logits > 0.5, 1, 0)
        con_pred = np.where(con_logits > 0.5, 1, 0)
        topic_item_pred = con_pred.tolist()
        # Local correction: for adjacent starts keep the earlier one;
        # for adjacent ends keep the later one
        # start_pred = [0 if i and n > 0 and start_pred[n - 1] else i for n, i in enumerate(start_pred)]
        end_pred = [0 if i and n + 1 < len(end_pred) and end_pred[n + 1] else i for n, i in enumerate(end_pred)]
        start_idx = [i for i, j in enumerate(start_pred) if j]
        end_idx = [i for i, j in enumerate(end_pred) if j]
        not_con_idx = [i for i, j in enumerate(con_pred) if not j]
        print("start_logits:", start_logits)
        print(start_pred, len(start_pred))
        print("start_pred indices:", start_idx)
        print(end_pred)
        print("end_pred indices:", end_idx)
        print(con_pred)
        print("not_con_idx indices:", not_con_idx)
        # Two TOPIC spans can never overlap
        last_st = 0
        last_txt = ""
        for nn, pos in enumerate(start_idx):
            if nn > 0:
                # Taking end_pred into account lowers precision:
                # if pos - start_idx[nn - 1] > 15:
                #     may_split = [i for i in end_idx if i < pos and i > start_idx[nn - 1]]
                #     for may_pos in may_split:
                #         if may_pos - last_st >= 2:  # a question spans at least 3 sentences
                #             tmp_ent = raw_text_list[last_st: may_pos + 1]
                #             predict_entities[id2label[0]].append((tmp_ent, last_st))
                #             last_st = may_pos + 1
                # else:
                tmp_ent = raw_text_list[start_idx[nn - 1]: pos]
                # When the span is a single very short line: a bare question number
                # is attached to the next line, anything else to the previous one
                if len(tmp_ent) == 1 and con_pred[pos - 1]:
                    if re.match(r"\d\s*[^\u4e00-\u9fa5]?$", tmp_ent[0]):
                        last_txt = tmp_ent[0]
                        continue
                    if re.match(r"\d\s*[.、.::]\s*[\u4e00-\u9fa5]{,5}$", tmp_ent[0]):
                        last_txt = tmp_ent[0]
                        continue
                    if len(tmp_ent[0]) < 4 and predict_entities[id2label[0]]:
                        predict_entities[id2label[0]][-1][0].append(tmp_ent[0])
                        continue
                if last_txt:
                    tmp_ent.insert(0, last_txt)
                    start_idx[nn - 1] -= 1
                    last_txt = ""
                # Judge together with con_pred:
                true_con_idx = [i for i in range(start_idx[nn - 1], pos) if i not in not_con_idx]
                # After dropping non-content lines, the kept indices must be contiguous
                if len(true_con_idx) < pos - start_idx[nn - 1] and true_con_idx and true_con_idx[-1] - true_con_idx[0] == len(true_con_idx) - 1:
                    if len(true_con_idx) == 1:
                        tmp_ent = [raw_text_list[true_con_idx[0]]]
                    elif len(true_con_idx) > 1:
                        tmp_ent = raw_text_list[true_con_idx[0]:true_con_idx[-1] + 1]
                    predict_entities[id2label[0]].append((tmp_ent, true_con_idx[0]))
                else:
                    predict_entities[id2label[0]].append((tmp_ent, start_idx[nn - 1]))
                if nn == len(start_idx) - 1:  # the last question
                    tmp_ent = raw_text_list[pos:]
                    if len(tmp_ent) == 1 and len(tmp_ent[0]) < 4:
                        if predict_entities[id2label[0]]:
                            predict_entities[id2label[0]][-1][0].append(tmp_ent[0])
                        else:
                            predict_entities[id2label[0]].append((tmp_ent, pos))
                    else:
                        predict_entities[id2label[0]].append((tmp_ent, pos))
            elif nn == len(start_idx) - 1:  # only one question
                tmp_ent = raw_text_list
                predict_entities[id2label[0]].append((tmp_ent, 0))
            # else:
            #     last_st = nn
        # start_pred is empty but end_pred is not: fall back to end_pred
        if not start_idx and end_idx:
            last_st = 0
            for nn, pos in enumerate(end_idx):
                tmp_ent = raw_text_list[last_st: pos + 1]
                # print("tmp_ent::", tmp_ent)
                predict_entities[id2label[0]].append((tmp_ent, last_st))
                last_st = pos + 1
                if nn == len(end_idx) - 1:
                    tmp_ent = raw_text_list[pos + 1:]
                    # print("tmp_ent222::", tmp_ent)
                    predict_entities[id2label[0]].append((tmp_ent, pos + 1))
    # pprint(predict_entities)
    return predict_entities, topic_item_pred
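
# Return contract, as built above: a mapping {label: [(sentence_list, start_index), ...]}
# plus the raw per-line 0/1 content flags, e.g. (illustrative values only):
#   ({"TOPIC": [(["1. ...", "..."], 0), (["2. ..."], 2)]}, [1, 1, 1])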

def ner_decode(start_logits, end_logits, raw_text_list, id2label):
    predict_entities = defaultdict(list)
    # print(start_pred)
    # print(end_pred)
    if len(id2label) > 1:  # ner
        for label_id in range(len(id2label)):
            start_pred = np.where(start_logits[label_id] > 0.5, 1, 0)
            end_pred = np.where(end_logits[label_id] > 0.5, 1, 0)
            # print(raw_text)
            # print(start_pred)
            # print(end_pred)
            for i, s_type in enumerate(start_pred):
                if s_type == 0:
                    continue
                for j, e_type in enumerate(end_pred[i:]):
                    if s_type == e_type:
                        tmp_ent = raw_text_list[i:i + j + 1]
                        if tmp_ent == '':
                            continue
                        predict_entities[id2label[label_id]].append((tmp_ent, i))
                        break
    else:  # raw_text_list is a list of sentences
        start_pred = np.where(start_logits > 0.5, 1, 0)
        end_pred = np.where(end_logits > 0.5, 1, 0)
        # Local correction: for adjacent starts keep the earlier one;
        # for adjacent ends keep the later one
        start_pred = [0 if i and n > 0 and start_pred[n - 1] else i for n, i in enumerate(start_pred)]
        end_pred = [0 if i and n + 1 < len(end_pred) and end_pred[n + 1] else i for n, i in enumerate(end_pred)]
        print(start_pred, len(start_pred))
        print("start_pred indices:", [i for i, j in enumerate(start_pred) if j])
        print(end_pred)
        print("end_pred indices:", [i for i, j in enumerate(end_pred) if j])
        # Two TOPIC spans can never overlap.
        # When start (or end) recall is low, the other signal can serve as a reference.
        last_st = None
        next_pos = 0
        for i, s_type in enumerate(start_pred):
            if s_type == 0:
                # Reached the last start_pred position without having found an end
                if i == len(start_pred) - 1 and last_st is not None:
                    tmp_ent = raw_text_list[last_st:]
                    predict_entities[id2label[0]].append((tmp_ent, last_st))
                continue
            if i < next_pos:
                continue
            if last_st is not None:
                tmp_ent = raw_text_list[last_st:i]
                predict_entities[id2label[0]].append((tmp_ent, last_st))
                next_pos = i
                last_st = None
            for j, e_type in enumerate(end_pred[i:]):
                if s_type == e_type:
                    end_num = sum(end_pred[i + j:])
                    start_num = sum(start_pred[i + 1:])
                    if j > 10 and end_num < start_num - 1:
                        # End recall is low here; use the next start as the boundary instead
                        last_st = i
                    else:
                        tmp_ent = raw_text_list[i:i + j + 1]
                        predict_entities[id2label[0]].append((tmp_ent, i))
                        next_pos = i + j + 1
                    break
                # Scanned all of end_pred without a match: cut at the next start
                if j == len(end_pred[i:]) - 1:
                    last_st = i
    return predict_entities

def ner_decode2(start_logits, end_logits, length, id2label):
    predict_entities = {x: [] for x in id2label.values()}
    # predict_entities = defaultdict(list)
    # print(start_pred)
    # print(end_pred)
    for label_id in range(len(id2label)):
        start_logit = np.where(sigmoid(start_logits[label_id]) > 0.5, 1, 0)
        end_logit = np.where(sigmoid(end_logits[label_id]) > 0.5, 1, 0)
        # print(start_logit)
        # print(end_logit)
        # print("=" * 100)
        # Skip the first position and truncate to the real sequence length
        start_pred = start_logit[1:length + 1]
        end_pred = end_logit[1:length + 1]
        for i, s_type in enumerate(start_pred):
            if s_type == 0:
                continue
            for j, e_type in enumerate(end_pred[i:]):
                if s_type == e_type:
                    predict_entities[id2label[label_id]].append((i, i + j + 1))
                    break
    return predict_entities

def bj_decode(start_logits, end_logits, length, id2label):
    predict_entities = {x: [] for x in id2label.values()}
    start_logit = np.where(sigmoid(start_logits) > 0.5, 1, 0)
    end_logit = np.where(sigmoid(end_logits) > 0.5, 1, 0)
    start_pred = start_logit[1:length + 1]
    end_pred = end_logit[1:length + 1]
    # print(start_pred)
    # print(end_pred)
    for i, s_type in enumerate(start_pred):
        if s_type == 0:
            continue
        for j, e_type in enumerate(end_pred[i:]):
            if s_type == e_type:
                predict_entities[id2label[0]].append((i, i + j + 1))
                break
    return predict_entities

if __name__ == '__main__':
    # print(np.zeros([1, 3]))
    # A plain list does not support ">" against a float, so convert to an array first
    start_logits = {0: [0.1, 0.3, 0.7, 0.4]}
    start_logit = np.where(np.array(start_logits[0]) > 0.5, 1, 0)
    print(start_logit)
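
    # A minimal smoke test for bj_decode (hypothetical logits; "TOPIC" is an
    # assumed label name). These are raw logits, passed through sigmoid inside
    # the function; index 0 of each vector is dropped by the [1:length + 1] slice.
    demo_start = np.array([-5.0, 3.0, -4.0, -4.0, -5.0, -5.0])
    demo_end = np.array([-5.0, -4.0, -4.0, 3.0, -5.0, -5.0])
    print(bj_decode(demo_start, demo_end, length=4, id2label={0: "TOPIC"}))
    # -> {'TOPIC': [(0, 3)]}: one entity from position 0 up to (not including) 3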