# read_data.py

import os
import re
import sys

sys.path.append("/home/cv/workspace/tujintao/document_segmentation")
from Utils.main_clear.sci_clear import non_data_latex_iter

filename = "Data/samples"  # sample data directory (unused; __main__ reads data_dir from Utils.train_configs)


def read_data(directory):
    all_documents = []
    all_labels = []
    for filename in os.listdir(directory)[:2]:
        if filename.endswith(".txt"):
            filepath = os.path.join(directory, filename)
            # print(filepath)
            # Read the txt file and normalise every line ending
            with open(filepath, "r", encoding="utf-8") as file:
                lines = file.readlines()
                # for i in range(len(lines)):
                #     if not lines[i].endswith("<br/>------------------------1\n"):
                #         lines[i] = re.sub(r'------------------------1$', '<br/>------------------------1\n', lines[i])
                # Join all lines into a single string and strip irrelevant markup
                text = "".join(lines)
                text = re.sub(r'</body>|</html>|<sub>|</sub>|<p>|</p>|<td>\s*</td>', '', text)
                text = re.sub(r'<td .+?["\']>\s*</td>', '', text)
                text = re.sub(r'</td>\s*<td( .+?["\'])?>', ' ', text)
                text = re.sub(r'<tr .+?["\']>\s*</tr>|</?table>|</?tbody>|<table .+?["\']>', '', text)
                text = re.sub(r'<tr( .+?["\'])?>\s*<td( .+?["\'])?>', '<tr>', text)
                text = re.sub('</td></tr>', '', text)
                # Replace <img> tags whose LaTeX extraction failed with the 【公式】 (formula)
                # placeholder, and every other <img> with the 【图片】 (image) placeholder
                text = re.sub(r"【<img .*?\"\s*/?>公式latex提取失败】", "【公式】", text)
                text = re.sub(r"<img .*?[\"']\s*/?>", "【图片】", text)
                text = re.sub(r"<span style=\"color: red\">(.*?)</span>", r"\1", text)
                text = non_data_latex_iter(text)
                text = re.sub(r'<b\*?r\s*/?>', '\n', text)
                # Make sure every segment marker ends its own line
                text = re.sub(r'(------------------------1)(?!\n)', r'\1\n', text)
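                # Rough before/after sketch of the cleaning above (toy input; assumes
                # non_data_latex_iter leaves plain text untouched):
                #   '<p>速度<sub>1</sub></p><img src="x.png"/><br/>------------------------1'
                #   -> '速度1【图片】\n------------------------1\n'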
                # print(text)
                # Extract the segmentation labels and filter out empty sentences
                labels = []
                sentences = []
                for sentence in text.split("\n"):
                    # print(sentence)
                    if re.search("------------------------1", sentence.strip()):
                        if sentence.strip().startswith("------------------------1"):
                            if labels:
                                labels[-1] = 1  # the previous non-empty sentence ends a segment
                            else:
                                labels.append(1)
                        sentence = re.sub("------------------------1", "", sentence)
                    else:
                        if sentence.strip():
                            labels.append(0)
                    if sentence.strip():
                        sentences.append(sentence.strip())
                print("Number of sentences:", len(sentences))
                all_documents.append(sentences)
                all_labels.append(labels)
    return all_documents, all_labels
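
# Illustrative sketch of read_data's per-document output (toy lines, assuming the
# "------------------------1" marker sits on its own line after each segment):
#   cleaned text: "甲句\n乙句\n------------------------1\n丙句\n"
#   sentences  -> ["甲句", "乙句", "丙句"]
#   labels     -> [0, 1, 0]   # "乙句" closes a segment, so its label is flipped to 1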


def split_dataset(input_texts, segment_labels, train_ratio=0.7, valid_ratio=0.1):
    """Split the data into train/valid/test sets."""
    total_samples = len(input_texts)
    train_size = int(total_samples * train_ratio)
    valid_size = int(total_samples * valid_ratio)
    test_size = total_samples - train_size - valid_size
    train_doc = input_texts[:train_size]
    train_seg_labels = segment_labels[:train_size]
    valid_doc = input_texts[train_size:train_size + valid_size]
    valid_seg_labels = segment_labels[train_size:train_size + valid_size]
    test_doc = input_texts[-test_size:]
    test_seg_labels = segment_labels[-test_size:]
    return (train_doc, train_seg_labels), (valid_doc, valid_seg_labels), (
        test_doc, test_seg_labels)
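
# Quick sanity check for the default 0.7/0.1/0.2 split (dummy single-sentence
# documents, purely illustrative):
#   docs = [["s{}".format(i)] for i in range(10)]
#   labs = [[0]] * 10
#   (tr, _), (va, _), (te, _) = split_dataset(docs, labs)
#   len(tr), len(va), len(te)   # -> (7, 1, 2)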


def get_token(sentences):
    # NOTE: relies on a module-level BERT-style `tokenizer` (see the commented-out
    # BertTokenizer in __main__); it must be defined before this is called.
    all_tokens = tokenizer.encode("\n".join(sentences))
    bef_sent_tokens = []
    aft_sent_tokens = all_tokens[1:-1]  # tokens of the current sentence and every sentence after it
    sents_token_range = []
    sent_idx = []  # "start_idx,local_sent_len": locates the target sentence's tokens inside its context window
    tokens_per_sent = []
    for idx, one_sent in enumerate(sentences):
        token_id = tokenizer.encode(one_sent)
        print(token_id)
        tokens_per_sent.append(token_id[1:-1])
        local_sent_len = len(token_id[1:-1])
        if not idx:
            sents_token_range.append(aft_sent_tokens[0:510])
            sent_idx.append("{},{}".format(0, local_sent_len))
        else:
            if len(token_id[1:-1]) > 200:  # current sentence exceeds 200 tokens: truncate to ~150 tokens before and ~160 after
                bef_lenght = 150
            elif idx == len(sentences) - 1:  # last sentence
                bef_lenght = 200
            else:
                bef_lenght = int((510 - len(token_id[1:-1])) * 0.4)
            aft_lenght = 510 - len(bef_sent_tokens[-bef_lenght:])  # bef_sent_tokens may hold fewer than bef_lenght tokens
            sents_token_range.append(bef_sent_tokens[-bef_lenght:] + aft_sent_tokens[:aft_lenght])
            sent_idx.append("{},{}".format(len(bef_sent_tokens[-bef_lenght:]), local_sent_len))
        aft_sent_tokens = aft_sent_tokens[local_sent_len:]
        bef_sent_tokens.extend(token_id[1:-1])
    return tokens_per_sent, sents_token_range, sent_idx
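
# Hedged usage sketch for get_token (needs a BERT-style `tokenizer`, e.g. the
# commented-out bert-base-chinese tokenizer in __main__; sentences are illustrative):
#   tokenizer = BertTokenizer.from_pretrained('Models/bert-base-chinese')
#   tokens_per_sent, ranges, idxs = get_token(["第一句。", "第二句。", "第三句。"])
#   # ranges[i]: an up-to-510-token context window around sentence i (no [CLS]/[SEP])
#   # idxs[i]:   "start,length" string locating sentence i's own tokens inside ranges[i]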


if __name__ == '__main__':
    # Leftover sample of the raw table markup that read_data has to clean:
    text = r"""
<table><tr><td></td><td>物理量</td><td>1</td><td>2</td><td>3</td><td>4</td><td>5</td><td>6</td><td>7</td><td>8</td></tr><tr><td rowspan="2">纸质</td><td>h(m)</td><td>0.1226</td><td>0.1510</td><td>0.1820</td><td>0.2153</td><td>0.2517</td><td>0.2900</td><td>0.3316</td><td>0.3753</td></tr><tr><td>$\frac{{v}^{2}}{2}$(m<sup>2</sup>·s<sup>-2</sup>)</td><td></td><td>1.10</td><td>1【公式】29</td><td>1.52</td><td>1.74</td><td>2.00</td><td>2.27</td><td></td></tr><tr><td rowspan="2">木质</td><td>h(m)</td><td>0.0605</td><td>0.0825</td><td>0.1090</td><td>0【公式】1400</td><td>0.1740</td><td>0.2115</td><td>0.2530</td><td>0.2980</td></tr><tr><td>$\frac{{v}^{2}}{2}$(m<sup>2</sup>·s<sup>-2</sup>)</td><td></td><td>0.735</td><td>1.03</td><td>1.32</td><td>1.60</td><td>1.95</td><td>2.34</td><td></td></tr><tr><td rowspan="2">铁质</td><td>h(m)</td><td>0.0953</td><td>0.1244</td><td>0.1574</td><td>0.1944</td><td>0.2352</td><td>0.2799</td><td>0.3285</td><td>0.3810</td></tr><tr><td>$\frac{{v}^{2}}{2}$(m<sup>2</sup>·s<sup>-2</sup>)</td><td></td><td>1.21</td><td>1.53</td><td>1.89</td><td>2.28</td><td>2.72</td><td>3.19</td><td></td></tr></table>
"""
    # aa = [1, 2, 3, 6, 3, 3, 7, 5, 4, 4, 8, 8]
    # print(aa[-60:])
    # print(aa[:60])
    # print(aa[-3:3])
    # print(int(470*0.4))
    from Utils.train_configs import data_dir
    # from transformers import BertTokenizer
    # tokenizer = BertTokenizer.from_pretrained('Models/bert-base-chinese', use_fast=True, do_lower_case=True)
    all_documents, all_labels = read_data(data_dir)
    # tokens_per_sent, sents_token_range, sent_idx = get_token(all_documents[0])
    # for i in range(len(tokens_per_sent)):
    #     print("Reference sentence:", all_documents[0][i])
    #     print("Context window:", tokenizer.decode(sents_token_range[i]))
    #     st, lenght = sent_idx[i].split(",")
    #     print("Current sentence:", tokenizer.decode(sents_token_range[i][int(st): int(st)+int(lenght)]))
    #     print("****************************************")
    import torch
    # last_hidden_states = torch.randn(3, 10, 20)
    # seq_idxs = [[1, 3], [2, 5], [3, 7]]
    # # last_hidden_states = map(lambda x: x[0], seq_idxs)
    # padded_sliced_hidden_states = [last_hidden_states[i, seq_idxs[i][0]:seq_idxs[i][1], :]
    #                                for i in range(last_hidden_states.size(0))]
    # for i in range(last_hidden_states.size(0)):
    #     aa = last_hidden_states[i, seq_idxs[i][0]:seq_idxs[i][0]+seq_idxs[i][1], :]
    #     print(aa.size())
    #     bb = aa.mean(dim=0)
    #     print(bb)
    # concatenated_tensor = torch.stack([last_hidden_states[i, seq_idxs[i][0]:seq_idxs[i][0]+seq_idxs[i][1], :].mean(dim=0) for i in range(last_hidden_states.size(0))], dim=0)
    # print(concatenated_tensor)
    # for i, j in enumerate(range(2)):
    #     print(i, j)
    tensor = torch.randn(2, 3)
    tensor_expanded = tensor.unsqueeze(0)
    print(tensor_expanded.shape)
    for i in range(0):
        print(11111111111111111111111111)
    # Suppose you have two tensors:
    # tensor1 = torch.randn(300, 512, 768)  # first tensor, shape [300, 512, 768]
    # tensor2 = torch.randn(22, 512, 768)   # second tensor, shape [22, 512, 768]
    # # Concatenate the two tensors along the first dimension with torch.cat
    # combined_tensor = torch.cat((tensor1, tensor2), dim=0)
    # print(combined_tensor.shape)  # should be torch.Size([522, 512, 768])