import torch
import json


class dtypeEncoder(json.JSONEncoder):
    def default(self, obj):
        if isinstance(obj, torch.dtype):  # might want to add np / jnp dtypes too?
            return str(obj)
        return json.JSONEncoder.default(self, obj)

# d = {"torch_dtype": torch.float16}
#
# json.dumps(d)  ## fail: TypeError: Object of type dtype is not JSON serializable
# json.dumps(d, cls=dtypeEncoder)  ## woot: {"torch_dtype": "torch.float16"}

import sys
import re

sys.path.append(r'/home/cv/workspace/tujintao/document_segmentation')
# from Utils.read_data import read_data
from tqdm import tqdm
import numpy as np
from pprint import pprint
from torch.utils.data import DataLoader, Dataset
from concurrent.futures import ThreadPoolExecutor


class ListDataset(Dataset):
    def __init__(self, file_path=None, data=None, tokenizer=None, max_len=None, label_list=None, **kwargs):
        self.kwargs = kwargs
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.label_list = label_list
        if isinstance(file_path, (str, list)):
            self.data = self.load_data(file_path, tokenizer, max_len, label_list)
        elif isinstance(data, list):
            self.data = data
        elif isinstance(data, tuple):
            # Large dataset, so format it with a pool of worker threads
            all_data, all_callback_info = [], []
            executor = ThreadPoolExecutor(max_workers=20)  # even a couple of threads is slightly faster
            for res in executor.map(self.format_data, zip(data[0], data[1])):
                all_data.append(res[0])
                all_callback_info.append(res[1])
            self.data = (all_data, all_callback_info)
            # Single-process alternative:
            # self.data = format_data(data, label_list, tokenizer, max_len)
        else:
            raise ValueError('The input args shall be a str/list file_path or a list/tuple dataset')

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index]

    def format_data(self, doc_seg_labels):
        """
        doc_seg_labels: (doc_list: list, seg_labels: list)
        Re-formats data that has already been split into train/dev/test sets into the expected format.
        """
        one_d, d_labels = doc_seg_labels[0], doc_seg_labels[1]
        inputs = self.tokenizer(one_d, padding='max_length', truncation=True,
                                max_length=self.max_len, return_tensors='pt')
        label = []
        label_dict = {x: [] for x in self.label_list}
        for lab in d_labels:
            label.append([lab[0], lab[1], "TOPIC"])
            # setdefault (rather than .get) so the span is kept even if "TOPIC" is missing from label_list
            label_dict.setdefault("TOPIC", []).append((one_d[lab[0]:lab[1]], lab[0]))
        # label is [[start, end, entity], ...]; one_d[start:end] is one topic_item
        return (inputs, label), (one_d, label_dict)


# Dataset for entity (exam-question) recognition
class NerDataset(ListDataset):
    @staticmethod
    def load_data1(filename, tokenizer, max_len, label_list):
        data = []
        callback_info = []  # used for computing evaluation metrics
        with open(filename, encoding='utf-8') as f:
            f = json.loads(f.read())
            for d in f:
                text = d['text']
                if len(text) == 0:
                    continue
                labels = d['labels']
                tokens = [i for i in text]
                if len(tokens) > max_len - 2:
                    tokens = tokens[:max_len - 2]
                    text = text[:max_len]
                tokens = ['[CLS]'] + tokens + ['[SEP]']
                token_ids = tokenizer.convert_tokens_to_ids(tokens)
                label = []
                label_dict = {x: [] for x in label_list}
                for lab in labels:
                    # shift the start by one for [CLS]; lab[3] needs no +1 because it already points one past the entity end
                    label.append([lab[2] + 1, lab[3], lab[1]])
                    label_dict.setdefault(lab[1], []).append((text[lab[2]:lab[3]], lab[2]))
                data.append((token_ids, label))  # label is [[start, end, entity], ...]
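                # Worked example (illustrative values, not from the source data): for
                # text "ABCDE" and a raw label lab = [_, "TOPIC", 1, 3] (char span
                # text[1:3] == "BC"), the stored entry is [2, 3, "TOPIC"]: the start
                # shifts right by one for [CLS], while the exclusive char end already
                # equals the inclusive index of the last entity token after that shift.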
                callback_info.append((text, label_dict))
        return data, callback_info


from transformers import BertTokenizer

model_dir = r'/home/cv/workspace/tujintao/PointerNet_Chinese_Information_Extraction/UIE/model_hub/chinese-bert-wwm-ext/'
tokenizer = BertTokenizer.from_pretrained(model_dir)


def format_data(doc_seg_labels, max_seq_len=240):
    """
    doc_seg_labels: (doc_list, seg_labels)
    Re-formats data that has already been split into train/dev/test sets into the expected format.
    """
    # one_d, d_labels = doc_seg_labels["input_txt"], doc_seg_labels["segment_label"]
    # input_ids = []
    # seg_labels = []
    # input_txts = []
    # data, callback_info = [], []
    one_d, d_labels = doc_seg_labels
    # for one_d, d_labels in zip(doc_seg_labels["input_txt"], doc_seg_labels["segment_label"]):

    # Handle documents that contain empty sentences
    if any([True for sent in one_d if not sent.strip()]):
        sentences, new_labels = [], []
        for lab in d_labels:
            # print(lab)
            # If a segment boundary falls on an empty sentence, move the marker to an adjacent sentence
            if not one_d[lab[0]].strip():
                one_d[lab[0] + 1] = "【start:1】" + one_d[lab[0] + 1]
            else:
                one_d[lab[0]] = "【start:1】" + one_d[lab[0]]
            if not one_d[lab[1] - 1].strip():
                one_d[lab[1] - 2] += "【end:1】"
            else:
                one_d[lab[1] - 1] += "【end:1】"
        # print(one_d, 999999999999999999999)
        # Drop empty sentences, then rebuild the (start, end) labels from the markers
        all_sents = [sent for sent in one_d if sent.replace("【start:1】", "").replace("【end:1】", "").strip()]
        st = 0
        for n, sentence in enumerate(all_sents):
            if sentence.startswith("【start:1】"):
                sentence = sentence.replace("【start:1】", "")
                st = n
            if sentence.endswith("【end:1】"):
                sentence = sentence.replace("【end:1】", "")
                new_labels.append((st, n + 1))
                # pprint(all_sents[st:n+1])
            sentences.append(sentence)
        one_d = sentences
        d_labels = new_labels

    # -----------Sanity checks-----------------------------------
    # lst = 0
    # for dd in d_labels:
    #     if dd[1] < dd[0] or dd[0]
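    # Illustrative trace of the empty-sentence handling above (assumed toy data, not
    # from the source corpus): with one_d = ["Q1 line 1", "", "Q1 line 2", "Q2 line 1"]
    # and d_labels = [(0, 3), (3, 4)], the empty sentence is dropped and the labels
    # are re-indexed to [(0, 2), (2, 3)].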