# dataset_tokenizer.py

import torch
import json


class dtypeEncoder(json.JSONEncoder):
    def default(self, obj):
        if isinstance(obj, torch.dtype):  # might wanna add np, jnp types too?
            return str(obj)
        return json.JSONEncoder.default(self, obj)

# d = {"torch_dtype": torch.float16}
# # json.dumps(d) ## fail: TypeError: Object of type dtype is not JSON serializable
# json.dumps(d, cls=dtypeEncoder) ## woot: {"torch_dtype": "torch.float16"}
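
# A possible extension, following the "np, jnp types too?" comment above (a sketch only;
# nothing else in this file uses it): also stringify numpy dtypes so mixed torch/numpy
# metadata dicts can be JSON-dumped.
import numpy as np


class NumpyDtypeEncoder(dtypeEncoder):
    def default(self, obj):
        if isinstance(obj, np.dtype):  # e.g. np.dtype("float32") -> "float32"
            return str(obj)
        return super().default(obj)

# json.dumps({"np_dtype": np.dtype("float32")}, cls=NumpyDtypeEncoder)  ## {"np_dtype": "float32"}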
import json
import torch
import sys
import re

sys.path.append(r'/home/cv/workspace/tujintao/document_segmentation')
# from Utils.read_data import read_data
from tqdm import tqdm
import numpy as np
from pprint import pprint
from torch.utils.data import DataLoader, Dataset
from concurrent.futures import ThreadPoolExecutor
class ListDataset(Dataset):
    def __init__(self, file_path=None, data=None, tokenizer=None, max_len=None, label_list=None, **kwargs):
        self.kwargs = kwargs
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.label_list = label_list
        if isinstance(file_path, (str, list)):
            self.data = self.load_data(file_path, tokenizer, max_len, label_list)
        elif isinstance(data, list):
            self.data = data
        elif isinstance(data, tuple):
            # Large data volume: tokenize in parallel with a thread pool
            all_data, all_allback_info = [], []
            executor = ThreadPoolExecutor(max_workers=20)  # even 2 workers is slightly faster
            for res in executor.map(self.format_data, zip(data[0], data[1])):
                all_data.append(res[0])
                all_allback_info.append(res[1])
            self.data = (all_data, all_allback_info)
            # Single-process alternative:
            # self.data = format_data(data, label_list, tokenizer, max_len)
        else:
            raise ValueError('The input args shall be a str/list file_path or a list/tuple dataset')
    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index]
    def format_data(self, doc_seg_labels):
        """
        doc_seg_labels: a (doc_list: list, seg_labels: list) pair.
        Reformat one sample of the already-split train/valid/test data into the required format.
        """
        one_d, d_labels = doc_seg_labels[0], doc_seg_labels[1]
        inputs = self.tokenizer(one_d, padding='max_length', truncation=True,
                                max_length=self.max_len, return_tensors='pt')
        label = []
        label_dict = {x: [] for x in self.label_list}
        for lab in d_labels:
            label.append([lab[0], lab[1], "TOPIC"])
            label_dict.get("TOPIC", []).append((one_d[lab[0]:lab[1]], lab[0]))
        # label is [[start, end, entity], ...]
        # one_d[start:end] is one topic_item
        return (inputs, label), (one_d, label_dict)
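
# A minimal usage sketch of the tuple branch above (toy data, not from the original corpus):
# `docs` is a list of documents, each a list of sentences, and `seg_labels` gives
# (start, end) sentence spans per document; `tokenizer` is the BertTokenizer loaded
# further down in this file.
# docs = [["1. What is 1 + 1?", "A. 1", "B. 2"], ["2. What is 2 + 2?", "A. 4"]]
# seg_labels = [[(0, 3)], [(0, 2)]]
# ds = ListDataset(data=(docs, seg_labels), tokenizer=tokenizer, max_len=240, label_list=["TOPIC"])
# tokenized, callback_info = ds.data  # per-document (inputs, label) and (text, label_dict) pairs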
# Load the entity (exam question) recognition dataset
class NerDataset(ListDataset):
    @staticmethod
    def load_data1(filename, tokenizer, max_len, label_list):
        data = []
        callback_info = []  # used for computing evaluation metrics
        with open(filename, encoding='utf-8') as f:
            f = f.read()
            f = json.loads(f)
            for d in f:
                text = d['text']
                if len(text) == 0:
                    continue
                labels = d['labels']
                tokens = [i for i in text]
                if len(tokens) > max_len - 2:
                    tokens = tokens[:max_len - 2]
                    text = text[:max_len]
                tokens = ['[CLS]'] + tokens + ['[SEP]']
                token_ids = tokenizer.convert_tokens_to_ids(tokens)
                label = []
                label_dict = {x: [] for x in label_list}
                for lab in labels:  # shift start by 1 for [CLS]; lab[3] needs no +1 since it already points one past the entity end
                    label.append([lab[2] + 1, lab[3], lab[1]])
                    label_dict.get(lab[1], []).append((text[lab[2]:lab[3]], lab[2]))
                data.append((token_ids, label))  # label is [[start, end, entity], ...]
                callback_info.append((text, label_dict))
        return data, callback_info
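
# Expected record shape for load_data1, inferred from the indexing above (an assumption,
# not checked against the original annotation files): each JSON item looks like
#   {"text": "...", "labels": [[_, entity_type, char_start, char_end], ...]}
# where lab[1] is the entity type, lab[2]/lab[3] are character offsets, and lab[0] is unused here.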
from transformers import BertTokenizer

model_dir = r'/home/cv/workspace/tujintao/PointerNet_Chinese_Information_Extraction/UIE/model_hub/chinese-bert-wwm-ext/'
tokenizer = BertTokenizer.from_pretrained(model_dir)
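
# Quick sanity check (commented out; a sketch of what the calls below rely on): passing a
# list of sentences with padding='max_length' returns one row per sentence, each padded or
# truncated to max_length.
# enc = tokenizer(["第一句。", "第二句。"], padding='max_length', truncation=True,
#                 max_length=240, return_tensors='pt')
# print(enc["input_ids"].shape)  # expected: torch.Size([2, 240])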
def format_data(doc_seg_labels, max_seq_len=240):
    """
    doc_seg_labels: a (doc_list, seg_labels) pair.
    Reformat one sample of the already-split train/valid/test data into the required format.
    """
    # one_d, d_labels = doc_seg_labels["input_txt"], doc_seg_labels["segment_label"]
    # input_ids = []
    # seg_labels = []
    # input_txts = []
    # data, callback_info = [], []
    one_d, d_labels = doc_seg_labels
    # for one_d, d_labels in zip(doc_seg_labels["input_txt"], doc_seg_labels["segment_label"]):
    # Handle documents that contain empty sentences
    if any([True for sent in one_d if not sent.strip()]):
        sentences, new_labels = [], []
        for lab in d_labels:
            # print(lab)
            if not one_d[lab[0]].strip():
                one_d[lab[0] + 1] = "【start:1】" + one_d[lab[0] + 1]
            else:
                one_d[lab[0]] = "【start:1】" + one_d[lab[0]]
            if not one_d[lab[1] - 1].strip():
                one_d[lab[1] - 2] += "【end:1】"
            else:
                one_d[lab[1] - 1] += "【end:1】"
        # print(one_d, 999999999999999999999)
        all_sents = [sent for sent in one_d if sent.replace("【start:1】", "").replace("【end:1】", "").strip()]
        st = 0
        for n, sentence in enumerate(all_sents):
            if sentence.startswith("【start:1】"):
                sentence = sentence.replace("【start:1】", "")
                st = n
            if sentence.endswith("【end:1】"):
                sentence = sentence.replace("【end:1】", "")
                new_labels.append((st, n + 1))
                # pprint(all_sents[st:n+1])
            sentences.append(sentence)
        one_d = sentences
        d_labels = new_labels
    # ----------- anomaly check -----------------------------------
    # lst = 0
    # for dd in d_labels:
    #     if dd[1] < dd[0] or dd[0] < lst:
    #         print("abnormal labels:", d_labels)
    #     lst = dd[1]
    # print("******************************************************")
    inputs = tokenizer(one_d, padding='max_length', truncation=True,
                       max_length=max_seq_len, return_tensors='pt')
    # input_ids.append(inputs)
    label = []
    label_dict = {"TOPIC": []}
    for lab in d_labels:
        label.append([lab[0], lab[1], "TOPIC"])
        label_dict.get("TOPIC", []).append((one_d[lab[0]:lab[1]], lab[0]))
    # data.append([inputs, label])  # label is [[start, end, entity], ...]
    # callback_info.append([one_d, label_dict])  # one_d[start:end] is one topic_item
    # seg_labels.append(label)
    # input_txts.append(one_d)
    # callback_info.append(label_dict)
    # return {"input_ids": input_ids, "seg_labels": seg_labels,
    #         "input_txts": input_txts, "callback_info": callback_info}
    return (inputs, label), (one_d, label_dict)
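
# A minimal sketch of calling format_data directly (toy document, not from the original
# corpus): each document is a list of sentences; each label is a (start, end) sentence
# span with an exclusive end.
# doc = ["1. What is 1 + 1?", "A. 1", "B. 2", "2. What is 2 + 2?", "A. 2", "B. 4"]
# spans = [(0, 3), (3, 6)]
# (inputs, label), (txt, label_dict) = format_data((doc, spans), max_seq_len=240)
# # inputs["input_ids"].shape -> torch.Size([6, 240])
# # label -> [[0, 3, "TOPIC"], [3, 6, "TOPIC"]]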
def load_and_split_dataset(filename, train_ratio=0.7, valid_ratio=0.1):
    # ----------- the full large-batch sample set ----------------------
    with open(filename, "r", encoding="utf-8") as f1:
        sample5w = json.load(f1)
    input_texts = sample5w["input_txts"]
    segment_labels = sample5w["segment_labels"]
    # print("input_texts:::", input_texts[:1])
    # json.dump({"input_txts": input_txts, "segment_labels": segment_labels}, f1, ensure_ascii=False)
    # Split the data into train / valid / test sets
    total_samples = len(input_texts)
    train_size = int(total_samples * train_ratio)
    valid_size = int(total_samples * valid_ratio)
    test_size = total_samples - train_size - valid_size
    train_doc = input_texts[:train_size]
    train_seg_labels = segment_labels[:train_size]
    valid_doc = input_texts[train_size:train_size + valid_size]
    valid_seg_labels = segment_labels[train_size:train_size + valid_size]
    test_doc = input_texts[-test_size:]
    test_seg_labels = segment_labels[-test_size:]
    return (train_doc, train_seg_labels), (valid_doc, valid_seg_labels), (
        test_doc, test_seg_labels)
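
# Worked example of the split sizes under the defaults (assuming 50,000 samples, which is
# presumably what the "5w" in sample5w refers to): train 35,000 / valid 5,000 / test 10,000;
# the test slice is simply whatever the first two slices leave over.
# train_data, valid_data, test_data = load_and_split_dataset(
#     "path/to/token_dataset.json")  # hypothetical path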
def main():
    # with open(filename, "r", encoding="utf-8") as f1:
    #     sample5w = json.load(f1)
    #     input_texts = sample5w["input_txts"]
    #     segment_labels = sample5w["segment_labels"]
    path0 = "/home/cv/workspace/tujintao/document_segmentation/Data/samples/train_data.jsonl"
    train_dataset = load_dataset("json", data_files=path0)["train"][:]
    all_data, all_allback_info = [], []
    executor = ThreadPoolExecutor(max_workers=20)  # even 2 workers is slightly faster
    for res in executor.map(format_data, zip(train_dataset["input_txt"], train_dataset["segment_label"])):
        all_data.append(res[0])
        all_allback_info.append(res[1])
    # Save to disk
    with open("/home/cv/workspace/tujintao/document_segmentation/Data/samples/token_datasets_6w.json", "w", encoding="utf-8") as f1:
        json.dump({"input": all_data, "allback_info": all_allback_info}, f1, ensure_ascii=False)
    # return all_data, all_allback_info
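
# Note (an observation, not verified against the original run): all_data holds BatchEncoding
# objects containing torch tensors, which json.dump cannot serialize as-is; the dtypeEncoder
# at the top of this file only covers torch.dtype, so the tensors would likely need to be
# converted (e.g. via .tolist()) or handled by a custom encoder before saving.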
if __name__ == "__main__":
    from datasets import load_dataset
    import datasets

    # path0 = "/home/cv/workspace/tujintao/document_segmentation/Data/samples/train_data.jsonl"
    # train_dataset = load_dataset("json", data_files=path0)["train"]
    # path = "/home/cv/workspace/tujintao/document_segmentation/Data/samples/train_data.json"
    main()

    # Dataset preprocessing
    # train_dataset = train_dataset.map(
    #     format_data,
    #     keep_in_memory=True,
    #     remove_columns=list(train_dataset.features),
    #     batched=True,
    #     batch_size=1,
    #     num_proc=2,
    #     desc="Running tokenizer on dataset"
    # )
    # The data saved this way is too large
    # train_dataset.save_to_disk("/home/cv/workspace/tujintao/document_segmentation/Data/samples/dataset_6w")
    # train_data, valid_data, test_data = load_and_split_dataset(path, train_ratio=0.995, valid_ratio=0.003)
    # Data loading
    # dataset = datasets.load_from_disk("/home/cv/workspace/tujintao/document_segmentation/Data/samples/dataset_6w")
    # a = dataset[0]['input_ids']['attention_mask']
    # print(len(a))