""" Span-decoding helpers that turn per-position start/end (and content) scores into spans. NOTE(review): this copy of the module appears to have been collapsed onto a few very long lines, so it does not parse as-is (for example `import numpy as np import re` sits on a single line). Text also appears to be missing in topic_ner_decode_1, topic_ner_decode and the single-label branch of ner_decode: each breaks off right after `end_pred = [0 if i and n+1`, which looks as if everything from a less-than sign up to the next greater-than sign was stripped (as by an HTML-tag filter). Those functions therefore use names (e.g. start_idx, end_idx, not_con_idx, start_num, end_num) whose definitions are not visible. Restore the file from version control rather than editing it in this form. """ import numpy as np import re from collections import defaultdict def sigmoid(x): """ Element-wise logistic sigmoid, 1 / (1 + exp(-x)). """ return 1 / (1 + np.exp(-x)) def topic_ner_decode_1(start_logits, end_logits, raw_text_list, id2label): """ Variant of the question-splitting decoder used for the first batch of 400 test samples; because the label positions differ, its output differs slightly. Only the single-label case (len(id2label) == 1) is handled: start/end scores are compared with 0.5 directly (no sigmoid) and raw_text_list is treated as a list of sentences. Returns a defaultdict(list) mapping id2label[0] to (sentences, start_index) tuples (empty when len(id2label) != 1). NOTE(review): the body appears truncated in this copy -- see the module note. """ predict_entities = defaultdict(list) if len(id2label) == 1: # raw_text_list is a list of sentences start_pred = np.where(start_logits > 0.5, 1, 0) end_pred = np.where(end_logits > 0.5, 1, 0) # Local correction: among adjacent starts keep the earlier one; among adjacent ends keep the later one start_pred = [0 if i and n>0 and start_pred[n-1] else i for n, i in enumerate(start_pred)] end_pred = [0 if i and n+1 0: # Taking end_pred into account lowers precision, # if pos - start_idx[nn-1] > 15: # may_split = [i for i in end_idx if i < pos and i > start_idx[nn-1]] # for may_pos in may_split: # if may_pos - last_st >= 2: # a question has at least 3 sentences # tmp_ent = raw_text_list[last_st: may_pos+1] # predict_entities[id2label[0]].append((tmp_ent, last_st)) # last_st = may_pos + 1 # else: tmp_ent = raw_text_list[start_idx[nn-1]: pos] predict_entities[id2label[0]].append((tmp_ent, start_idx[nn-1])) elif nn == len(start_idx) - 1: # only one question tmp_ent = raw_text_list predict_entities[id2label[0]].append((tmp_ent, 0)) return predict_entities def content_bc_decode(con_logits): """ Threshold content scores at 0.5 to decide, item by item, whether it is question content. Returns np.where(con_logits > 0.5, 1, 0): 1 = content, 0 = not content. The scores are compared with 0.5 directly (no sigmoid); presumably con_logits is a NumPy array of probabilities -- confirm against the caller. """ con_pred = np.where(con_logits > 0.5, 1, 0) return con_pred def topic_ner_decode(start_logits, end_logits, con_logits, raw_text_list, id2label): """ Split raw_text_list (a list of sentences) into question spans from start, end and content scores, each compared with 0.5 directly (no sigmoid). Ideally the number of predicted starts would equal the number of predicted ends, but in practice it does not, so spans appear to run from one predicted start to the next and are then refined with the content predictions; when no start but at least one end is predicted, the list is cut after each predicted end instead. Returns (predict_entities, topic_item_pred): a defaultdict(list) mapping id2label[0] to (sentences, start_index) tuples, and the 0/1 content prediction as a list (empty when len(id2label) != 1). NOTE(review): the body appears truncated in this copy -- see the module note. """ predict_entities = defaultdict(list) topic_item_pred = [] if len(id2label) == 1: # raw_text_list is a list of sentences start_pred = np.where(start_logits > 0.5, 1, 0) # a higher threshold raises precision but lowers recall end_pred = np.where(end_logits > 0.5, 1, 0) con_pred = np.where(con_logits > 0.5, 1, 0) topic_item_pred = con_pred.tolist() # Local correction: among adjacent starts keep the earlier one; among adjacent ends keep the later one # start_pred = [0 if i and n>0 and start_pred[n-1] else i for n, i in enumerate(start_pred)] end_pred = [0 if i and n+1 0: # Taking end_pred into account lowers precision, # if pos - start_idx[nn-1] > 15: # may_split = [i for i in end_idx if i < pos and i > start_idx[nn-1]] # for may_pos in may_split: #
if may_pos - last_st >= 2: # a question has at least 3 sentences # tmp_ent = raw_text_list[last_st: may_pos+1] # predict_entities[id2label[0]].append((tmp_ent, last_st)) # last_st = may_pos + 1 # else: tmp_ent = raw_text_list[start_idx[nn-1]: pos] # Single-sentence span with very short text: text that looks like a question number (a digit, optionally followed by punctuation and up to five Chinese characters) is carried over to the start of the next span; other text shorter than 4 characters is appended to the previous span, if any if len(tmp_ent) == 1 and con_pred[pos-1]: if re.match(r"\d\s*[^\u4e00-\u9fa5]?$", tmp_ent[0]): last_txt = tmp_ent[0] continue if re.match(r"\d\s*[.、.::]\s*[\u4e00-\u9fa5]{,5}$", tmp_ent[0]): last_txt = tmp_ent[0] continue if len(tmp_ent[0]) < 4 and predict_entities[id2label[0]]: predict_entities[id2label[0]][-1][0].append(tmp_ent[0]) continue if last_txt: tmp_ent.insert(0, last_txt) start_idx[nn-1] -= 1 last_txt = "" # Decide together with con_pred: true_con_idx = [i for i in range(start_idx[nn-1], pos) if i not in not_con_idx] # after removing not_con, the true_con indices must be contiguous if len(true_con_idx) < pos - start_idx[nn-1] and true_con_idx and true_con_idx[-1] - true_con_idx[0] == len(true_con_idx)-1: if len(true_con_idx) == 1: tmp_ent = [raw_text_list[true_con_idx[0]]] elif len(true_con_idx) > 1: tmp_ent = raw_text_list[true_con_idx[0]:true_con_idx[-1]+1] predict_entities[id2label[0]].append((tmp_ent, true_con_idx[0])) else: predict_entities[id2label[0]].append((tmp_ent, start_idx[nn-1])) if nn == len(start_idx) - 1: # last question tmp_ent = raw_text_list[pos:] if len(tmp_ent) == 1 and len(tmp_ent[0]) < 4: if predict_entities[id2label[0]]: predict_entities[id2label[0]][-1][0].append(tmp_ent[0]) else: predict_entities[id2label[0]].append((tmp_ent, pos)) else: predict_entities[id2label[0]].append((tmp_ent, pos)) elif nn == len(start_idx) - 1: # only one question tmp_ent = raw_text_list predict_entities[id2label[0]].append((tmp_ent, 0)) # else: # last_st = nn # start_pred is empty but end_pred is not: fall back to end_pred if not start_idx and end_idx: last_st = 0 for nn, pos in enumerate(end_idx): tmp_ent = raw_text_list[last_st: pos+1] # print("tmp_ent::", tmp_ent) predict_entities[id2label[0]].append((tmp_ent, last_st)) last_st = pos+1 if nn == len(end_idx) - 1: tmp_ent = raw_text_list[pos+1:] #
print("tmp_ent222::", tmp_ent) predict_entities[id2label[0]].append((tmp_ent, pos+1)) # pprint(predict_entities) return predict_entities, topic_item_pred def ner_decode(start_logits, end_logits, raw_text_list, id2label): """ Decode spans from start/end scores compared with 0.5 directly (no sigmoid). With more than one label, start_logits[label_id] and end_logits[label_id] are decoded per label: each predicted start i is paired with the first predicted end at or after it and (raw_text_list[i:end + 1], i) is recorded. With a single label, raw_text_list is treated as a list of sentences and runs of adjacent predicted starts are first collapsed to their earliest position. Returns a defaultdict(list) mapping label name to (span, start_index) tuples. NOTE(review): the single-label branch appears truncated in this copy -- see the module note. """ predict_entities = defaultdict(list) # print(start_pred) # print(end_pred) if len(id2label) > 1: # ner for label_id in range(len(id2label)): start_pred = np.where(start_logits[label_id] > 0.5, 1, 0) end_pred = np.where(end_logits[label_id] > 0.5, 1, 0) # print(raw_text) # print(start_pred) # print(end_pred) for i, s_type in enumerate(start_pred): if s_type == 0: continue for j, e_type in enumerate(end_pred[i:]): if s_type == e_type: tmp_ent = raw_text_list[i:i + j + 1] if tmp_ent == '': continue predict_entities[id2label[label_id]].append((tmp_ent, i)) break else: # raw_text_list is a list of sentences start_pred = np.where(start_logits > 0.5, 1, 0) end_pred = np.where(end_logits > 0.5, 1, 0) # Local correction: among adjacent starts keep the earlier one; among adjacent ends keep the later one start_pred = [0 if i and n>0 and start_pred[n-1] else i for n, i in enumerate(start_pred)] end_pred = [0 if i and n+1 10 and end_num < start_num - 1: # end recall is low: use the next start as the end instead last_st = i else: tmp_ent = raw_text_list[i:i + j + 1] predict_entities[id2label[0]].append((tmp_ent, i)) next = i+j+1 break # no end found even after scanning to the last end_pred entry: cut off at the next start instead if j == len(end_pred[i:]) - 1: last_st = i return predict_entities def ner_decode2(start_logits, end_logits, length, id2label): """ Decode half-open token spans for every label. For each label_id in range(len(id2label)), sigmoid(start_logits[label_id]) and sigmoid(end_logits[label_id]) are thresholded at 0.5, position 0 is dropped and the next `length` positions are kept (presumably to skip a leading special token such as [CLS] -- confirm against the caller). Each predicted start i is paired with the first predicted end at or after it, giving (i, end + 1) in the trimmed coordinates. Returns {label_name: [(start, end_exclusive), ...]} with a (possibly empty) list for every label in id2label. """ predict_entities = {x:[] for x in list(id2label.values())} # predict_entities = defaultdict(list) # print(start_pred) # print(end_pred) for label_id in range(len(id2label)): start_logit = np.where(sigmoid(start_logits[label_id]) > 0.5, 1, 0) end_logit = np.where(sigmoid(end_logits[label_id]) > 0.5, 1, 0) # print(start_logit) # print(end_logit) # print("="*100) start_pred = start_logit[1:length + 1] end_pred = end_logit[1:length+ 1] for i, s_type in enumerate(start_pred): if s_type == 0: continue for j, e_type in enumerate(end_pred[i:]): if s_type == e_type:
predict_entities[id2label[label_id]].append((i, i+j+1)) break return predict_entities def bj_decode(start_logits, end_logits, length, id2label): """ Single-label counterpart of ner_decode2: start_logits/end_logits are not indexed by label and every span is recorded under id2label[0]. Scores go through sigmoid and are thresholded at 0.5, position 0 is dropped and the next `length` positions are decoded; each predicted start i is paired with the first predicted end at or after it, giving (i, end + 1). Returns {label_name: [...]} with a list for every label in id2label; only id2label[0]'s list is filled. """ predict_entities = {x:[] for x in list(id2label.values())} start_logit = np.where(sigmoid(start_logits) > 0.5, 1, 0) end_logit = np.where(sigmoid(end_logits) > 0.5, 1, 0) start_pred = start_logit[1:length + 1] end_pred = end_logit[1:length+ 1] # print(start_pred) # print(end_pred) for i, s_type in enumerate(start_pred): if s_type == 0: continue for j, e_type in enumerate(end_pred[i:]): if s_type == e_type: predict_entities[id2label[0]].append((i, i+j+1)) break return predict_entities if __name__ == '__main__': # print(np.zeros([1, 3])) start_logits = {0:[0.1,0.3,0.7,0.4]} start_logit = np.where(start_logits[0] > 0.5, 1, 0) print(start_logit) # NOTE(review): start_logits[0] is a plain list, so the comparison start_logits[0] > 0.5 raises TypeError under Python 3; wrap the list in np.array(...) for this demo to run.