"""Decoding utilities for span-based topic/NER prediction heads."""
- import numpy as np
- import re
- from collections import defaultdict
def sigmoid(x):
    """Numerically stable elementwise logistic sigmoid.

    The naive form ``1 / (1 + np.exp(-x))`` overflows ``np.exp`` (RuntimeWarning,
    precision loss) for large negative ``x``; evaluating ``exp`` only on ``-|x|``
    keeps the argument non-positive so it can never overflow.

    :param x: scalar or array-like of logits
    :return: sigmoid of ``x`` — a scalar for scalar input, ndarray otherwise
    """
    x = np.asarray(x, dtype=float)
    z = np.exp(-np.abs(x))  # always in (0, 1], safe for any finite x
    result = np.where(x >= 0, 1.0 / (1.0 + z), z / (1.0 + z))
    # Preserve the original scalar-in/scalar-out behavior (np.where yields 0-d).
    return result[()] if result.ndim == 0 else result
def topic_ner_decode_1(start_logits, end_logits, raw_text_list, id2label):
    """Decode topic spans for the first batch of 400 test samples.

    (Label positions differ slightly for that batch, hence this variant.)
    Topics are cut purely at predicted start positions: each span runs from
    one start to the next, and two topics can never overlap.

    BUG FIX vs. the original: when two or more starts were predicted, the
    final topic (from the last start to the end of the text) was never
    emitted, because the single-topic `elif` could only trigger at nn == 0.
    The sibling `topic_ner_decode` already emits that trailing segment.

    :param start_logits: per-sentence start probabilities (array-like)
    :param end_logits: per-sentence end probabilities (array-like)
    :param raw_text_list: list of sentences
    :param id2label: label map; only the single-label case (len == 1) decodes
    :return: defaultdict mapping label -> list of (sentence_list, start_index)
    """
    predict_entities = defaultdict(list)
    if len(id2label) == 1:
        # raw_text_list is a list of sentences
        start_pred = np.where(start_logits > 0.5, 1, 0)
        end_pred = np.where(end_logits > 0.5, 1, 0)
        # Local correction: for adjacent starts keep the earlier one;
        # for adjacent ends keep the later one.
        start_pred = [0 if i and n > 0 and start_pred[n-1] else i for n, i in enumerate(start_pred)]
        end_pred = [0 if i and n+1 < len(end_pred) and end_pred[n+1] else i for n, i in enumerate(end_pred)]
        start_idx = [i for i, j in enumerate(start_pred) if j]
        end_idx = [i for i, j in enumerate(end_pred) if j]
        print(start_pred, len(start_pred))
        print("start_pred索引位置:", start_idx)
        print(end_pred)
        print("end_pred索引位置:", end_idx)
        # Two TOPIC spans can never overlap.
        # (Consulting end_pred when start recall is low was tried but
        # lowered precision, so ends are only printed, not used.)
        if len(start_idx) == 1:
            # Only one topic: the whole text is the span.
            predict_entities[id2label[0]].append((raw_text_list, 0))
        elif len(start_idx) > 1:
            # Interior topics: from each start up to the next start.
            for nn in range(1, len(start_idx)):
                tmp_ent = raw_text_list[start_idx[nn-1]: start_idx[nn]]
                predict_entities[id2label[0]].append((tmp_ent, start_idx[nn-1]))
            # Trailing topic: from the final start to the end of the text
            # (dropped entirely by the original implementation).
            tmp_ent = raw_text_list[start_idx[-1]:]
            predict_entities[id2label[0]].append((tmp_ent, start_idx[-1]))
    return predict_entities
def content_bc_decode(con_logits):
    """Binarize content logits: 1 where the probability exceeds 0.5
    (i.e. the line is judged to be exam-question content), else 0.

    :param con_logits: array of per-line content probabilities
    :return: int array of 0/1 flags, same shape as the input
    """
    return (con_logits > 0.5).astype(int)
def topic_ner_decode(start_logits, end_logits, con_logits, raw_text_list, id2label):
    """Decode topic spans from start/end/content logits.

    Ideally the number of recalled starts equals the number of recalled
    ends, but in practice they differ, so decoding is driven mainly by the
    start positions; the content head corrects span boundaries and the end
    positions serve as a fallback when no start is predicted at all.

    :param start_logits: per-sentence probabilities that a sentence starts a topic
    :param end_logits: per-sentence probabilities that a sentence ends a topic
    :param con_logits: per-sentence probabilities that a sentence is topic content
    :param raw_text_list: list of sentences (indexed in parallel with the logits)
    :param id2label: label map; only the single-label case (len == 1) is decoded
    :return: (predict_entities, topic_item_pred) — predict_entities maps
        label -> list of (sentence_list, start_index); topic_item_pred is the
        binarized content prediction as a plain list (empty if len(id2label) != 1)
    """
    predict_entities = defaultdict(list)
    topic_item_pred = []
    if len(id2label) == 1:
        # raw_text_list is a list of sentences
        start_pred = np.where(start_logits > 0.5, 1, 0)  # a higher threshold raises precision but lowers recall
        end_pred = np.where(end_logits > 0.5, 1, 0)
        con_pred = np.where(con_logits > 0.5, 1, 0)
        topic_item_pred = con_pred.tolist()
        # Local correction: for adjacent starts keep the earlier one; for adjacent ends keep the later one.
        # (Start smoothing is disabled here, unlike ner_decode/topic_ner_decode_1.)
        # start_pred = [0 if i and n>0 and start_pred[n-1] else i for n, i in enumerate(start_pred)]
        end_pred = [0 if i and n+1<len(end_pred) and end_pred[n+1] else i for n, i in enumerate(end_pred)]
        start_idx = [i for i, j in enumerate(start_pred) if j]
        end_idx = [i for i, j in enumerate(end_pred) if j]
        not_con_idx = [i for i, j in enumerate(con_pred) if not j]
        print("start_logits:", start_logits)
        print(start_pred, len(start_pred))
        print("start_pred索引位置:", start_idx)
        print(end_pred)
        print("end_pred索引位置:", end_idx)
        print(con_pred)
        print("not_con_idx索引位置:", not_con_idx)
        # Two TOPIC spans can never overlap.
        last_st = 0
        last_txt = ""  # a deferred short "question number" line, prepended to the next span
        for nn, pos in enumerate(start_idx):
            if nn > 0:
                # Consulting end_pred here lowered precision, hence disabled:
                # if pos - start_idx[nn-1] > 15:
                #     may_split = [i for i in end_idx if i < pos and i > start_idx[nn-1]]
                #     for may_pos in may_split:
                #         if may_pos - last_st >= 2: # a topic has at least 3 sentences
                #             tmp_ent = raw_text_list[last_st: may_pos+1]
                #             predict_entities[id2label[0]].append((tmp_ent, last_st))
                #             last_st = may_pos + 1
                # else:
                tmp_ent = raw_text_list[start_idx[nn-1]: pos]
                # Single very short line: a bare number goes with the NEXT span,
                # anything else is glued onto the PREVIOUS span.
                if len(tmp_ent) == 1 and con_pred[pos-1]:
                    # e.g. "12" possibly followed by one non-CJK char — a question number
                    if re.match(r"\d\s*[^\u4e00-\u9fa5]?$", tmp_ent[0]):
                        last_txt = tmp_ent[0]
                        continue
                    # e.g. "12." / "12、" / "12:" plus up to 5 CJK chars — also a question number
                    if re.match(r"\d\s*[.、.::]\s*[\u4e00-\u9fa5]{,5}$", tmp_ent[0]):
                        last_txt = tmp_ent[0]
                        continue
                    # other very short lines: append onto the last emitted span
                    if len(tmp_ent[0]) < 4 and predict_entities[id2label[0]]:
                        predict_entities[id2label[0]][-1][0].append(tmp_ent[0])
                        continue
                if last_txt:
                    # Prepend the deferred question-number line to this span and
                    # shift the recorded start index back by one to match.
                    tmp_ent.insert(0, last_txt)
                    start_idx[nn-1] -= 1
                    last_txt = ""
                # Cross-check the span against the content head:
                true_con_idx = [i for i in range(start_idx[nn-1], pos) if i not in not_con_idx]
                # After dropping non-content lines the surviving indices must be
                # contiguous for the content head to be trusted over the span.
                if len(true_con_idx) < pos - start_idx[nn-1] and true_con_idx and true_con_idx[-1] - true_con_idx[0] == len(true_con_idx)-1:
                    if len(true_con_idx) == 1:
                        tmp_ent = [raw_text_list[true_con_idx[0]]]
                    elif len(true_con_idx) > 1:
                        tmp_ent = raw_text_list[true_con_idx[0]:true_con_idx[-1]+1]
                    predict_entities[id2label[0]].append((tmp_ent, true_con_idx[0]))
                else:
                    predict_entities[id2label[0]].append((tmp_ent, start_idx[nn-1]))
                if nn == len(start_idx) - 1:  # last topic: everything from the final start onwards
                    tmp_ent = raw_text_list[pos:]
                    # a lone very short trailing line is merged into the previous span
                    if len(tmp_ent) == 1 and len(tmp_ent[0]) < 4:
                        if predict_entities[id2label[0]]:
                            predict_entities[id2label[0]][-1][0].append(tmp_ent[0])
                        else:
                            predict_entities[id2label[0]].append((tmp_ent, pos))
                    else:
                        predict_entities[id2label[0]].append((tmp_ent, pos))
            elif nn == len(start_idx) - 1:  # only one topic: the whole text is the span
                tmp_ent = raw_text_list
                predict_entities[id2label[0]].append((tmp_ent, 0))
            # else:
            #     last_st = nn
        # start_pred empty but end_pred not: fall back to cutting at end positions,
        # each span running from the previous cut through the predicted end.
        if not start_idx and end_idx:
            last_st = 0
            for nn, pos in enumerate(end_idx):
                tmp_ent = raw_text_list[last_st: pos+1]
                # print("tmp_ent::", tmp_ent)
                predict_entities[id2label[0]].append((tmp_ent, last_st))
                last_st = pos+1
                if nn == len(end_idx) - 1:
                    # remainder after the last predicted end
                    tmp_ent = raw_text_list[pos+1:]
                    # print("tmp_ent222::", tmp_ent)
                    predict_entities[id2label[0]].append((tmp_ent, pos+1))
    # pprint(predict_entities)
    return predict_entities, topic_item_pred
def ner_decode(start_logits, end_logits, raw_text_list, id2label):
    """Decode entity spans from start/end logits.

    Two modes:
    - multi-label (len(id2label) > 1): classic pointer-network decoding — each
      predicted start is paired with the nearest matching end at or after it.
      NOTE(review): raw_text_list is sliced and compared to '' here, which
      suggests a character string in this branch vs. a sentence list in the
      other branch — confirm against the callers.
    - single-label: raw_text_list is a list of sentences; starts/ends are
      smoothed, and low end-recall is compensated by cutting at the next start.

    :param start_logits: start probabilities (per label in multi-label mode)
    :param end_logits: end probabilities (per label in multi-label mode)
    :param raw_text_list: text to slice spans from (see NOTE above)
    :param id2label: map of label id -> label name
    :return: defaultdict mapping label -> list of (span_text, start_index)
    """
    predict_entities = defaultdict(list)
    # print(start_pred)
    # print(end_pred)
    if len(id2label) > 1: # ner
        for label_id in range(len(id2label)):
            start_pred = np.where(start_logits[label_id] > 0.5, 1, 0)
            end_pred = np.where(end_logits[label_id] > 0.5, 1, 0)
            # print(raw_text)
            # print(start_pred)
            # print(end_pred)
            for i, s_type in enumerate(start_pred):
                if s_type == 0:
                    continue
                # pair this start with the nearest end at or after it
                for j, e_type in enumerate(end_pred[i:]):
                    if s_type == e_type:
                        tmp_ent = raw_text_list[i:i + j + 1]
                        if tmp_ent == '':
                            continue
                        predict_entities[id2label[label_id]].append((tmp_ent, i))
                        break
    else: # raw_text_list is a list of sentences
        start_pred = np.where(start_logits > 0.5, 1, 0)
        end_pred = np.where(end_logits > 0.5, 1, 0)
        # Local correction: for adjacent starts keep the earlier one;
        # for adjacent ends keep the later one.
        start_pred = [0 if i and n>0 and start_pred[n-1] else i for n, i in enumerate(start_pred)]
        end_pred = [0 if i and n+1<len(end_pred) and end_pred[n+1] else i for n, i in enumerate(end_pred)]
        print(start_pred, len(start_pred))
        print("start_pred索引位置:", [i for i, j in enumerate(start_pred) if j])
        print(end_pred)
        print("end_pred索引位置:", [i for i, j in enumerate(end_pred) if j])
        # Two TOPIC spans can never overlap.
        # When start or end recall is low, use the other head's positions.
        last_st = None   # pending start whose end was not found yet
        next = 0         # first index not yet consumed (NOTE(review): shadows builtin next)
        for i, s_type in enumerate(start_pred):
            if s_type == 0:
                # reached the last position with a start still pending: flush to end of text
                if i == len(start_pred) - 1 and last_st is not None:
                    tmp_ent = raw_text_list[last_st:]
                    predict_entities[id2label[0]].append((tmp_ent, last_st))
                continue
            if i < next:
                # position already consumed by a previously emitted span
                continue
            if last_st is not None:
                # a pending start is closed by this new start
                tmp_ent = raw_text_list[last_st:i]
                predict_entities[id2label[0]].append((tmp_ent, last_st))
                next = i
                last_st = None
            for j, e_type in enumerate(end_pred[i:]):
                if s_type == e_type:
                    end_num = sum(end_pred[i+j:])
                    start_num = sum(start_pred[i+1:])
                    if j > 10 and end_num < start_num - 1:
                        # end recall is low here: defer, and let the NEXT start close this span
                        last_st = i
                    else:
                        tmp_ent = raw_text_list[i:i + j + 1]
                        predict_entities[id2label[0]].append((tmp_ent, i))
                        next = i+j+1
                    break
                # scanned to the last end position without a match: cut at the next start
                if j == len(end_pred[i:]) - 1:
                    last_st = i
    return predict_entities
def ner_decode2(start_logits, end_logits, length, id2label):
    """Pair boundary predictions into (start, end) index spans per label.

    Logits are squashed with a sigmoid and thresholded at 0.5; only the
    window [1, length] is decoded (slot 0 — e.g. a [CLS] position — is
    skipped). Each predicted start is matched with the nearest predicted
    end at or after it.

    :param start_logits: per-label start logits, indexed by label id
    :param end_logits: per-label end logits, indexed by label id
    :param length: number of real positions after the skipped slot 0
    :param id2label: map of label id -> label name
    :return: dict mapping label name -> list of (start, end), end exclusive
    """
    spans = {label: [] for label in id2label.values()}
    for label_id in range(len(id2label)):
        # sigmoid is inlined; probability > 0.5 is equivalent to logit > 0
        start_flags = np.where(1 / (1 + np.exp(-start_logits[label_id])) > 0.5, 1, 0)[1:length + 1]
        end_flags = np.where(1 / (1 + np.exp(-end_logits[label_id])) > 0.5, 1, 0)[1:length + 1]
        for begin, s_flag in enumerate(start_flags):
            if s_flag == 0:
                continue
            for offset, e_flag in enumerate(end_flags[begin:]):
                if e_flag == s_flag:
                    spans[id2label[label_id]].append((begin, begin + offset + 1))
                    break
    return spans
def bj_decode(start_logits, end_logits, length, id2label):
    """Decode (start, end) index spans for the single label id2label[0].

    Sigmoid-thresholds the logits at 0.5, keeps only positions 1..length
    (slot 0 skipped), and pairs every predicted start with the nearest
    predicted end at or after it.

    :param start_logits: start logits over all positions
    :param end_logits: end logits over all positions
    :param length: number of real positions after the skipped slot 0
    :param id2label: map of label id -> label name (only id 0 is used)
    :return: dict mapping label name -> list of (start, end), end exclusive
    """
    spans = {label: [] for label in id2label.values()}
    # sigmoid is inlined; probability > 0.5 is equivalent to logit > 0
    start_flags = np.where(1 / (1 + np.exp(-start_logits)) > 0.5, 1, 0)[1:length + 1]
    end_flags = np.where(1 / (1 + np.exp(-end_logits)) > 0.5, 1, 0)[1:length + 1]
    for begin, s_flag in enumerate(start_flags):
        if s_flag == 0:
            continue
        for offset, e_flag in enumerate(end_flags[begin:]):
            if e_flag == s_flag:
                spans[id2label[0]].append((begin, begin + offset + 1))
                break
    return spans
if __name__ == '__main__':
    # Smoke test for the thresholding used throughout this module.
    # print(np.zeros([1, 3]))
    start_logits = {0: [0.1, 0.3, 0.7, 0.4]}
    # BUG FIX: the dict value is a plain Python list, and `list > 0.5` raises
    # TypeError before np.where is even called — convert to an ndarray first.
    start_logit = np.where(np.asarray(start_logits[0]) > 0.5, 1, 0)
    print(start_logit)