# data_loader.py

import json
import torch
import sys
import re
import random
# sys.path.append(r'/home/cv/workspace/tujintao/document_segmentation')
from Utils.read_data import read_data
from tqdm import tqdm
import numpy as np
from pprint import pprint
# from Data.paper_combination import generate_paper_math, generate_paper_phy
from torch.utils.data import DataLoader, Dataset
from concurrent.futures import ThreadPoolExecutor


class ListDataset(Dataset):
    def __init__(self, file_path=None, data=None, tokenizer=None, max_len=None, label_list=None, **kwargs):
        self.kwargs = kwargs
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.label_list = label_list
        if isinstance(file_path, (str, list)):
            self.data = self.load_data(file_path, tokenizer, max_len, label_list)
        elif isinstance(data, list):
            self.data = data
        elif isinstance(data, tuple):
            # The dataset is large, so format the samples in parallel.
            all_data, all_callback_info = [], []
            with ThreadPoolExecutor(max_workers=10) as executor:  # even 2 workers is already a bit faster
                for res in executor.map(self.format_data, zip(data[0], data[1])):
                    all_data.append(res[0])
                    all_callback_info.append(res[1])
            self.data = (all_data, all_callback_info)
            # Single-process alternative:
            # self.data = format_data(data, label_list, tokenizer, max_len)
        else:
            raise ValueError('file_path must be a str/list path, or data must be a list/tuple dataset')

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index]

    def format_data(self, doc_seg_labels):
        """
        doc_seg_labels: (doc_list: list, seg_labels: list)
        Re-format data that has already been split into train/valid/test sets
        into the structure expected by the model.
        """
        one_d, d_labels = doc_seg_labels[0], doc_seg_labels[1]
        # Re-handle empty sentences in one_d (the earlier check was wrong: the
        # 【start:1】/【end:1】 markers were not stripped before testing).
        if any([True for sent in one_d if not sent.strip()]):
            sentences, new_labels = [], []
            for lab in d_labels:
                # print(lab)
                if not one_d[lab[0]].strip():
                    one_d[lab[0] + 1] = "【start:1】" + one_d[lab[0] + 1]
                else:
                    one_d[lab[0]] = "【start:1】" + one_d[lab[0]]
                if not one_d[lab[1] - 1].strip():
                    one_d[lab[1] - 2] += "【end:1】"
                else:
                    one_d[lab[1] - 1] += "【end:1】"
            # print(one_d, 999999999999999999999)
            all_sents = [sent for sent in one_d if sent.replace("【start:1】", "").replace("【end:1】", "").strip()]
            st = 0
            for n, sentence in enumerate(all_sents):
                if sentence.startswith("【start:1】"):
                    sentence = sentence.replace("【start:1】", "")
                    st = n
                if sentence.endswith("【end:1】"):
                    sentence = sentence.replace("【end:1】", "")
                    new_labels.append((st, n + 1))
                    # pprint(all_sents[st:n+1])
                sentences.append(sentence)
            one_d = sentences
            d_labels = new_labels
        # Tokenize and encode the sentences.
        inputs = self.tokenizer(one_d, padding='max_length', truncation=True,
                                max_length=self.max_len, return_tensors='pt')
        label = []
        label_dict = {x: [] for x in self.label_list}
        for lab in d_labels:
            label.append([lab[0], lab[1], "TOPIC"])
            label_dict.get("TOPIC", []).append((one_d[lab[0]:lab[1]], lab[0]))
        # label is [[start, end, entity], ...]; one_d[start:end] is one topic_item
        return (inputs, label), (one_d, label_dict)

    @staticmethod
    def load_data(file_path, tokenizer, max_len, label_list):
        return file_path
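
# Hedged usage sketch for ListDataset with pre-split (documents, span_labels) data.
# The tokenizer name and the sample values below are illustrative, not taken from this repo:
# from transformers import BertTokenizer
# tokenizer = BertTokenizer.from_pretrained("bert-base-chinese")
# docs = [["第一题 ...", "A. ...", "第二题 ..."]]   # one document = a list of sentences
# spans = [[(0, 2), (2, 3)]]                        # (start, end) sentence span per question
# dataset = ListDataset(data=(docs, spans), tokenizer=tokenizer, max_len=50, label_list=["TOPIC"])
# encoded, callback = dataset.data                  # (inputs, label) pairs and (sentences, label_dict) pairs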


# Dataset for entity (exam question) recognition
class NerDataset(ListDataset):
    @staticmethod
    def load_data1(filename, tokenizer, max_len, label_list):
        data = []
        callback_info = []  # used to compute evaluation metrics
        with open(filename, encoding='utf-8') as f:
            f = f.read()
            f = json.loads(f)
            for d in f:
                text = d['text']
                if len(text) == 0:
                    continue
                labels = d['labels']
                tokens = [i for i in text]
                if len(tokens) > max_len - 2:
                    tokens = tokens[:max_len - 2]
                    text = text[:max_len]
                tokens = ['[CLS]'] + tokens + ['[SEP]']
                token_ids = tokenizer.convert_tokens_to_ids(tokens)
                label = []
                label_dict = {x: [] for x in label_list}
                for lab in labels:  # shift start by 1 for [CLS]; lab[3] needs no +1 because it already points one past the entity end
                    label.append([lab[2] + 1, lab[3], lab[1]])
                    label_dict.get(lab[1], []).append((text[lab[2]:lab[3]], lab[2]))
                data.append((token_ids, label))  # label is [[start, end, entity], ...]
                callback_info.append((text, label_dict))
        return data, callback_info

    @staticmethod
    def load_data(filename):
        """
        label_list: all entity labels; this project only uses TOPIC.
        """
        # Build the samples.
        all_documents, all_labels = read_data(filename)
        all_pointer_labels = []
        for one_d_labels in all_labels:
            new_labels = []
            st = 0  # indices start at 0
            for n, s_label in enumerate(one_d_labels):
                if s_label:
                    new_labels.append((st, n + 1))  # end index is exclusive (+1)
                    st = n + 1
            all_pointer_labels.append(new_labels)
        train_data, valid_data, test_data = split_dataset(all_documents, all_pointer_labels)
        # train_doc, train_seg_labels = train_data
        # valid_doc, valid_seg_labels = valid_data
        # test_doc, test_seg_labels = test_data
        return train_data, valid_data, test_data
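
# Worked example of the boundary-to-span conversion above (illustrative values):
# per-sentence end-of-topic flags [0, 0, 1, 0, 1] from read_data become pointer
# spans [(0, 3), (3, 5)], i.e. sentences[0:3] and sentences[3:5] are two topics.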


def format_data(doc_seg_labels, label_list, max_seq_len, tokenizer):
    """
    doc_seg_labels: (doc_list, seg_labels)
    Re-format data that has already been split into train/valid/test sets
    into the structure expected by the model.
    """
    data = []
    callback_info = []  # used to compute evaluation metrics
    for one_d, d_labels in zip(doc_seg_labels[0], doc_seg_labels[1]):
        # for sent in one_d:
        #     if not sent.strip():
        #         continue
        inputs = tokenizer(one_d, padding='max_length', truncation=True,
                           max_length=max_seq_len, return_tensors='pt')
        label = []
        label_dict = {x: [] for x in label_list}
        for lab in d_labels:
            label.append([lab[0], lab[1], "TOPIC"])
            label_dict.get("TOPIC", []).append((one_d[lab[0]:lab[1]], lab[0]))
        data.append((inputs, label))  # label is [[start, end, entity], ...]
        callback_info.append((one_d, label_dict))  # one_d[start:end] is one topic_item
    return data, callback_info


def get_papers(paper_num):
    """
    Generate synthetic exam-paper samples: each iteration produces one math and one
    physics paper, and the result is dumped to train_data.json.
    NOTE: relies on generate_paper_math / generate_paper_phy from Data.paper_combination
    (that import is commented out at the top of this file).
    """
    def txt_split(content):
        labels = []
        sentences = []
        st = 0
        # all_sents = [i.strip() for i in content.split("\n") if i.strip()]
        all_sents = [i.strip() for i in content.split("\n") if i.replace("【start:1】", "").replace("【end:1】", "").strip()]
        for n, sentence in enumerate(all_sents):
            if sentence.startswith("【start:1】"):
                sentence = sentence.replace("【start:1】", "")
                st = n
            if sentence.endswith("【end:1】"):
                sentence = sentence.replace("【end:1】", "")
                labels.append((st, n + 1))
            sentences.append(sentence)
        # pprint(sentences)
        # pprint(labels)
        # a, b = labels[1]
        # print(sentences[a:b])
        return sentences, labels

    input_txts = []
    segment_labels = []
    for n in range(paper_num):
        print(n)
        paper_content1 = generate_paper_math(max_questions_num=7, min_questions_num=1)
        paper_content2 = generate_paper_phy(max_questions_num=7, min_questions_num=1)
        sentences_1, labels_1 = txt_split(paper_content1)
        input_txts.append(sentences_1)
        segment_labels.append(labels_1)
        sentences_2, labels_2 = txt_split(paper_content2)
        input_txts.append(sentences_2)
        segment_labels.append(labels_2)
    with open("/home/cv/workspace/tujintao/document_segmentation/Data/samples/train_data.json", "w", encoding="utf-8") as f1:
        json.dump({"input_txts": input_txts, "segment_labels": segment_labels}, f1, ensure_ascii=False)


def get_paper_for_predict():
    """
    Build a sample for prediction testing: one randomly generated paper (math or physics) per call.
    """
    def txt_split(content):
        labels = []
        sentences = []
        st = 0
        all_sents = [i.strip() for i in content.split("\n") if i.replace("【start:1】", "").replace("【end:1】", "").strip()]
        for n, sentence in enumerate(all_sents):
            if sentence.startswith("【start:1】"):
                sentence = sentence.replace("【start:1】", "")
                st = n
            if sentence.endswith("【end:1】"):
                sentence = sentence.replace("【end:1】", "")
                labels.append((st, n + 1))
            sentences.append(sentence)
        return sentences, labels

    if random.choice([0, 1]) == 0:
        paper_content = generate_paper_math(max_questions_num=7, min_questions_num=1)
    else:
        paper_content = generate_paper_phy(max_questions_num=7, min_questions_num=1)
    pprint(paper_content)
    sentences, labels = txt_split(paper_content)
    return sentences, labels
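
# Worked example of txt_split's marker parsing (illustrative content):
# "【start:1】Q1 line1\nQ1 line2【end:1】\n【start:1】Q2【end:1】"
# -> sentences = ["Q1 line1", "Q1 line2", "Q2"], labels = [(0, 2), (2, 3)]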


def load_and_split_dataset(filename, train_ratio=0.7, valid_ratio=0.1):
    # ----------- small sample batch for quick testing ---------------
    # input_texts, all_labels = read_data(filename)
    # segment_labels = []
    # for one_d_labels in all_labels[:1]:
    #     new_labels = []
    #     st = 0  # indices start at 0
    #     for n, s_label in enumerate(one_d_labels):
    #         if s_label:
    #             new_labels.append((st, n + 1))  # end index is exclusive (+1)
    #             st = n + 1
    #     segment_labels.append(new_labels)
    # print("segment_labels:::", segment_labels)
    # a, b = segment_labels[0][0]
    # print(input_texts[0][a: b])
    # ----------- full-scale sample set ----------------------
    with open(filename, "r", encoding="utf-8") as f1:
        sample5w = json.load(f1)
    input_texts = sample5w["input_txts"]
    segment_labels = sample5w["segment_labels"]
    # print("input_texts:::", input_texts[:1])
    # json.dump({"input_txts": input_txts, "segment_labels": segment_labels}, f1, ensure_ascii=False)
    # Split the data into train/valid/test sets.
    total_samples = len(input_texts)
    train_size = int(total_samples * train_ratio)
    valid_size = int(total_samples * valid_ratio)
    test_size = total_samples - train_size - valid_size
    train_doc = input_texts[:train_size]
    train_seg_labels = segment_labels[:train_size]
    valid_doc = input_texts[train_size:train_size + valid_size]
    valid_seg_labels = segment_labels[train_size:train_size + valid_size]
    test_doc = input_texts[-test_size:]
    test_seg_labels = segment_labels[-test_size:]
    return (train_doc, train_seg_labels), (valid_doc, valid_seg_labels), (test_doc, test_seg_labels)


def split_dataset(input_texts, segment_labels, train_ratio=0.7, valid_ratio=0.1):
    """Split the data into train/valid/test sets."""
    total_samples = len(input_texts)
    train_size = int(total_samples * train_ratio)
    valid_size = int(total_samples * valid_ratio)
    test_size = total_samples - train_size - valid_size
    train_doc = input_texts[:train_size]
    train_seg_labels = segment_labels[:train_size]
    valid_doc = input_texts[train_size:train_size + valid_size]
    valid_seg_labels = segment_labels[train_size:train_size + valid_size]
    test_doc = input_texts[-test_size:]
    test_seg_labels = segment_labels[-test_size:]
    return (train_doc, train_seg_labels), (valid_doc, valid_seg_labels), (test_doc, test_seg_labels)
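
# Illustrative check of the split sizes: with 100 documents and the default ratios,
# train/valid/test receive 70, 10 and 20 documents; the test set takes the remainder,
# so the three parts always cover every sample exactly once.
# train, valid, test = split_dataset(docs, labels)  # len(train[0]) == 70, len(valid[0]) == 10, len(test[0]) == 20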


def convert_list_to_tensor(alist, dtype=torch.long):
    # return torch.tensor(np.array(alist) if isinstance(alist, list) else alist, dtype=dtype)
    return [torch.tensor(np.array(a) if isinstance(a, list) else a, dtype=dtype).squeeze(0) for a in alist]
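
# Illustrative behaviour of convert_list_to_tensor on the numpy label arrays built by
# NerCollate: each (num_labels, num_sentences) array becomes a tensor, and a leading
# dimension of size 1 (e.g. the single TOPIC label) is squeezed away:
# convert_list_to_tensor([np.zeros((1, 12))], dtype=torch.float)[0].shape  # torch.Size([12])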


class NerCollate:
    def __init__(self, max_len, label2id):
        self.maxlen = max_len
        self.label2id = label2id

    def collate_fn(self, batch):
        batch_token_ids = []
        batch_attention_mask = []
        # batch_token_type_ids = []
        batch_start_labels = []
        batch_end_labels = []
        batch_content_labels = []  # 0/1 labels: whether each sentence belongs to a question
        for i, (inputs, sent_labels) in enumerate(batch):
            # a = inputs['input_ids']
            token_ids = inputs['input_ids']  # .squeeze(0)
            attention_mask = inputs['attention_mask']
            sent_num = token_ids.size()[0]
            start_labels = np.zeros((len(self.label2id), sent_num), dtype=np.int64)
            end_labels = np.zeros((len(self.label2id), sent_num), dtype=np.int64)
            content_labels = np.zeros((len(self.label2id), sent_num), dtype=np.int64)
            # token_type_ids = [0] * self.maxlen
            assert attention_mask.size()[1] == self.maxlen
            # assert len(token_type_ids) == self.maxlen
            assert token_ids.size()[1] == self.maxlen
            batch_token_ids.append(token_ids)  # length was already capped during tokenization
            batch_attention_mask.append(attention_mask)
            # batch_token_type_ids.append(token_type_ids)
            # Fill the pointer start/end labels.
            for start, end, label in sent_labels:  # NER labels
                label_id = self.label2id[label]
                start_labels[label_id][start] = 1
                # if end < self.maxlen - 1:
                end_labels[label_id][end - 1] = 1
                content_labels[label_id][start:end] = [1] * (end - start)
            batch_start_labels.append(start_labels)
            batch_end_labels.append(end_labels)
            batch_content_labels.append(content_labels)
        # batch_token_ids = convert_list_to_tensor(batch_token_ids)
        # batch_token_type_ids = convert_list_to_tensor(batch_token_type_ids)
        # batch_attention_mask = convert_list_to_tensor(batch_attention_mask)
        batch_start_labels = convert_list_to_tensor(batch_start_labels, dtype=torch.float)
        batch_end_labels = convert_list_to_tensor(batch_end_labels, dtype=torch.float)
        batch_content_labels = convert_list_to_tensor(batch_content_labels, dtype=torch.float)
        # print("batch_end_labels:::", batch_end_labels[0].size())
        res = {
            "input_ids": batch_token_ids,
            # "token_type_ids": batch_token_type_ids,
            "attention_mask": batch_attention_mask,
            "ner_start_labels": batch_start_labels,
            "ner_end_labels": batch_end_labels,
            "ner_content_labels": batch_content_labels,
        }
        return res
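
# Hedged usage sketch for NerCollate (train_dataset here stands for any sequence of
# (inputs, sent_labels) pairs such as the ones built by ListDataset.format_data):
# collate = NerCollate(max_len=50, label2id={"TOPIC": 0})
# loader = DataLoader(train_dataset, batch_size=2, shuffle=False, collate_fn=collate.collate_fn)
# for batch in loader:
#     print(batch["ner_start_labels"][0].size())  # per-sentence start flags for the first document
#     break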


if __name__ == "__main__":
    import sys
    sys.path.append(r'/home/cv/workspace/tujintao/document_segmentation')
    # sys.path.append(r'/home/cv/workspace/tujintao/PointerNet_Chinese_Information_Extraction')
    from transformers import BertTokenizer
    from Utils.read_data import read_data
    # Needed by get_paper_for_predict() below; the module-level import is commented out.
    from Data.paper_combination import generate_paper_math, generate_paper_phy

    # model_dir = r'/home/cv/workspace/tujintao/PointerNet_Chinese_Information_Extraction/UIE/model_hub/chinese-bert-wwm-ext/'
    # tokenizer = BertTokenizer.from_pretrained(model_dir)

    # Quick test
    # max_seq_len = 50
    # label_path = "PointerNet/data/labels.txt"
    # with open(label_path, "r") as fp:
    #     labels = fp.read().strip().split("\n")
    # train_dataset, train_callback = NerDataset(file_path=r"Data/samples",
    #                                            tokenizer=tokenizer,
    #                                            max_len=max_seq_len,
    #                                            label_list=labels)
    # print(train_dataset[0])

    # Test entity recognition
    # ============================
    # max_seq_len = 150
    # label_path = "PointerNet_Chinese_Information_Extraction/UIE/data/ner/cner/labels.txt"
    # with open(label_path, "r") as fp:
    #     labels = fp.read().strip().split("\n")
    # train_dataset, train_callback = NerDataset(file_path='PointerNet_Chinese_Information_Extraction/UIE/data/ner/cner/train.json',
    #                                            tokenizer=tokenizer,
    #                                            max_len=max_seq_len,
    #                                            label_list=labels)
    # print(train_dataset[1])  # ([101, 2382, 2456, 5679, 8024, 4511, 8024, 102], [[1, 3, 'NAME']])
    # print(train_callback[1])
    # ('常建良,男,', {'TITLE': [], 'RACE': [], 'CONT': [], 'ORG': [], 'NAME': [('常建良', 0)], 'EDU': [], 'PRO': [], 'LOC': []})
    # ------------------------------------------------------------------
    # id2tag = {}
    # tag2id = {}
    # for i, label in enumerate(labels):
    #     id2tag[i] = label
    #     tag2id[label] = i
    # collate = NerCollate(max_len=max_seq_len, label2id=tag2id)
    # batch_size = 2
    # train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate.collate_fn)
    # for batch in tqdm(train_dataloader):
    #     print(222222222222222222, batch)
    #     # print(batch["ner_start_labels"].shape)  # [bs, label_n, maxlen]
    #     # for k, v in batch.items():
    #     #     print(k, v.shape)
    # ============================

    # get_papers(50000)  # generate exam papers
    a, b = get_paper_for_predict()
    # print(a)
    # print(b)
    # print(a[6:8])
    for i in b:
        print(a[i[0]: i[1]])

    # ------------------------------------------
    # load_and_split_dataset(r"Data/samples/临时样本")
    # load_and_split_dataset("")
    from PointerNet.config import NerArgs
    args = NerArgs()
    # load_and_split_dataset(args.train_path, train_ratio=0.995, valid_ratio=0.003)