import os
import re
import sys

sys.path.append("/home/cv/workspace/tujintao/document_segmentation")
from Utils.main_clear.sci_clear import non_data_latex_iter

filename = "Data/samples"

def read_data(directory):
    all_documents = []
    all_labels = []
    for filename in os.listdir(directory)[:2]:
        if filename.endswith(".txt"):
            filepath = os.path.join(directory, filename)
            # print(filepath)
            # Read the txt file and handle each line's ending.
            with open(filepath, "r", encoding="utf-8") as file:
                lines = file.readlines()
            # for i in range(len(lines)):
            #     if not lines[i].endswith("<br/>------------------------1\n"):
            #         lines[i] = re.sub(r'------------------------1$', '<br/>------------------------1\n', lines[i])
            # Join all lines into one string and strip irrelevant markup.
            text = "".join(lines)
            text = re.sub(r'</body>|</html>|<sub>|</sub>|<p>|</p>|<td>\s*</td>', '', text)
            text = re.sub(r'<td .+?["\']>\s*</td>', '', text)
            text = re.sub(r'</td>\s*<td( .+?["\'])?>', ' ', text)
            text = re.sub(r'<tr .+?["\']>\s*</tr>|</?table>|</?tbody>|<table .+?["\']>', '', text)
            text = re.sub('</td></tr>', '', text)
            text = re.sub(r'<tr( .+?["\'])?>\s*<td( .+?["\'])?>', '<tr>', text)
            text = re.sub(r'【<img .*?"\s*/?>公式latex提取失败】', '【公式】', text)
            text = re.sub(r'<img .*?["\']\s*/?>', '【图片】', text)
            text = re.sub(r"<span style=\"color: red\">(.*?)</span>", r"\1", text)
            text = non_data_latex_iter(text)
            text = re.sub(r'<b\*?r\s*/?>', '\n', text)
            text = re.sub(r'(------------------------1)(?!\n)', r'\1\n', text)
            # print(text)
            # Extract the labels and filter out empty sentences.
            labels = []
            sentences = []
            for sentence in text.split("\n"):
                # print(sentence)
                if re.search("------------------------1", sentence.strip()):
                    if sentence.strip().startswith("------------------------1"):
                        if labels:
                            labels[-1] = 1  # set the previous non-empty sentence's label to 1
                        else:
                            labels.append(1)
                        sentence = re.sub("------------------------1", "", sentence)
                else:
                    if sentence.strip():
                        labels.append(0)
                if sentence.strip():
                    sentences.append(sentence.strip())
            print("Number of sentences:", len(sentences))
            all_documents.append(sentences)
            all_labels.append(labels)
    return all_documents, all_labels
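
# A small worked example of the labeling scheme above (made-up input): the lines
# ["sentence A", "------------------------1", "sentence B"] produce
# sentences = ["sentence A", "sentence B"] and labels = [1, 0], i.e. a label of 1
# marks the last sentence of a segment and 0 marks every other sentence.
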
def split_dataset(input_texts, segment_labels, train_ratio=0.7, valid_ratio=0.1):
    """Split the data into train/valid/test sets."""
    total_samples = len(input_texts)
    train_size = int(total_samples * train_ratio)
    valid_size = int(total_samples * valid_ratio)
    test_size = total_samples - train_size - valid_size
    train_doc = input_texts[:train_size]
    train_seg_labels = segment_labels[:train_size]
    valid_doc = input_texts[train_size:train_size + valid_size]
    valid_seg_labels = segment_labels[train_size:train_size + valid_size]
    test_doc = input_texts[-test_size:]
    test_seg_labels = segment_labels[-test_size:]
    return (train_doc, train_seg_labels), (valid_doc, valid_seg_labels), (test_doc, test_seg_labels)
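
# Example with made-up sizes: for 10 documents and the default ratios (0.7 / 0.1),
# split_dataset returns 7 training, 1 validation and 2 test documents; the test
# split is taken from the tail of the list, so the three splits do not overlap.
#   (train_doc, _), (valid_doc, _), (test_doc, _) = split_dataset(docs, labels)
#   # len(train_doc) == 7, len(valid_doc) == 1, len(test_doc) == 2  (when len(docs) == 10)
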
def get_token(sentences):
    # Assumes a `tokenizer` (e.g. the BertTokenizer loaded under __main__) is available globally.
    all_tokens = tokenizer.encode("\n".join(sentences))
    bef_sent_tokens = []
    aft_sent_tokens = all_tokens[1:-1]  # tokens of the current sentence and everything after it
    sents_token_range = []
    sent_idx = []  # "start_idx,local_sent_len": where the current sentence sits inside its context window
    tokens_per_sent = []
    for idx, one_sent in enumerate(sentences):
        token_id = tokenizer.encode(one_sent)
        # print(token_id)
        tokens_per_sent.append(token_id[1:-1])
        local_sent_len = len(token_id[1:-1])
        if not idx:
            sents_token_range.append(aft_sent_tokens[0:510])
            sent_idx.append("{},{}".format(0, local_sent_len))
        else:
            if len(token_id[1:-1]) > 200:
                # The current sentence is longer than 200 tokens: truncate the context,
                # keeping about 150 tokens before it and the rest after.
                bef_length = 150
            elif idx == len(sentences) - 1:  # last sentence
                bef_length = 200
            else:
                bef_length = int((510 - len(token_id[1:-1])) * 0.4)
            aft_length = 510 - len(bef_sent_tokens[-bef_length:])  # bef_sent_tokens may hold fewer than bef_length tokens
            sents_token_range.append(bef_sent_tokens[-bef_length:] + aft_sent_tokens[:aft_length])
            sent_idx.append("{},{}".format(len(bef_sent_tokens[-bef_length:]), local_sent_len))
        aft_sent_tokens = aft_sent_tokens[local_sent_len:]
        bef_sent_tokens.extend(token_id[1:-1])
    return tokens_per_sent, sents_token_range, sent_idx
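
# Sketch of how the three outputs fit together, assuming `tokenizer` is the
# BertTokenizer that is loaded (currently commented out) under __main__:
#   tokens_per_sent, sents_token_range, sent_idx = get_token(all_documents[0])
#   start, length = map(int, sent_idx[i].split(","))
#   window = sents_token_range[i]           # up to 510 context token ids around sentence i
#   current = window[start:start + length]  # token ids of sentence i itself
#   print(tokenizer.decode(current))        # should read roughly like all_documents[0][i]
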
if __name__ == '__main__':
    text = r"""
    <table><tr><td></td><td>物理量</td><td>1</td><td>2</td><td>3</td><td>4</td><td>5</td><td>6</td><td>7</td><td>8</td></tr><tr><td rowspan="2">纸质</td><td>h(m)</td><td>0.1226</td><td>0.1510</td><td>0.1820</td><td>0.2153</td><td>0.2517</td><td>0.2900</td><td>0.3316</td><td>0.3753</td></tr><tr><td>$\frac{{v}^{2}}{2}$(m<sup>2</sup>·s<sup>-2</sup>)</td><td></td><td>1.10</td><td>1【公式】29</td><td>1.52</td><td>1.74</td><td>2.00</td><td>2.27</td><td></td></tr><tr><td rowspan="2">木质</td><td>h(m)</td><td>0.0605</td><td>0.0825</td><td>0.1090</td><td>0【公式】1400</td><td>0.1740</td><td>0.2115</td><td>0.2530</td><td>0.2980</td></tr><tr><td>$\frac{{v}^{2}}{2}$(m<sup>2</sup>·s<sup>-2</sup>)</td><td></td><td>0.735</td><td>1.03</td><td>1.32</td><td>1.60</td><td>1.95</td><td>2.34</td><td></td></tr><tr><td rowspan="2">铁质</td><td>h(m)</td><td>0.0953</td><td>0.1244</td><td>0.1574</td><td>0.1944</td><td>0.2352</td><td>0.2799</td><td>0.3285</td><td>0.3810</td></tr><tr><td>$\frac{{v}^{2}}{2}$(m<sup>2</sup>·s<sup>-2</sup>)</td><td></td><td>1.21</td><td>1.53</td><td>1.89</td><td>2.28</td><td>2.72</td><td>3.19</td><td></td></tr></table>
    """
    # aa = [1,2,3,6,3,3,7,5,4,4,8,8]
    # print(aa[-60:])
    # print(aa[:60])
    # print(aa[-3:3])
    # print(int(470*0.4))
    from Utils.train_configs import data_dir
    # from transformers import BertTokenizer
    # tokenizer = BertTokenizer.from_pretrained('Models/bert-base-chinese', use_fast=True, do_lower_case=True)
    all_documents, all_labels = read_data(data_dir)

    # tokens_per_sent, sents_token_range, sent_idx = get_token(all_documents[0])
    # for i in range(len(tokens_per_sent)):
    #     print("Reference sentence:", all_documents[0][i])
    #     print("Context window:", tokenizer.decode(sents_token_range[i]))
    #     st, length = sent_idx[i].split(",")
    #     print("Current sentence:", tokenizer.decode(sents_token_range[i][int(st): int(st)+int(length)]))
    #     print("****************************************")

    import torch

    # last_hidden_states = torch.randn(3, 10, 20)
    # seq_idxs = [[1, 3], [2, 5], [3, 7]]
    # # last_hidden_states = map(lambda x: x[0], seq_idxs)
    # padded_sliced_hidden_states = [last_hidden_states[i, seq_idxs[i][0]:seq_idxs[i][1], :]
    #                                for i in range(last_hidden_states.size(0))]
    # for i in range(last_hidden_states.size(0)):
    #     aa = last_hidden_states[i, seq_idxs[i][0]:seq_idxs[i][0]+seq_idxs[i][1], :]
    #     print(aa.size())
    #     bb = aa.mean(dim=0)
    #     print(bb)
    # concatenated_tensor = torch.stack([last_hidden_states[i, seq_idxs[i][0]:seq_idxs[i][0]+seq_idxs[i][1], :].mean(dim=0) for i in range(last_hidden_states.size(0))], dim=0)
    # print(concatenated_tensor)
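
    # Runnable version of the experiment sketched in the comments above: slice each
    # sequence's hidden states with its [start, length] pair, mean-pool the slice,
    # and stack the pooled vectors (the shapes are the same made-up ones used above).
    last_hidden_states = torch.randn(3, 10, 20)
    seq_idxs = [[1, 3], [2, 5], [3, 7]]  # [start, length] per sequence
    pooled = torch.stack(
        [last_hidden_states[i, s:s + l, :].mean(dim=0) for i, (s, l) in enumerate(seq_idxs)],
        dim=0,
    )
    print(pooled.shape)  # torch.Size([3, 20])
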
    # for i, j in enumerate(range(2)):
    #     print(i, j)
    tensor = torch.randn(2, 3)
    tensor_expanded = tensor.unsqueeze(0)
    print(tensor_expanded.shape)
    for i in range(0):
        print(11111111111111111111111111)

    # Suppose you have two tensors:
    # tensor1 = torch.randn(300, 512, 768)  # first tensor, shape [300, 512, 768]
    # tensor2 = torch.randn(22, 512, 768)   # second tensor, shape [22, 512, 768]

    # # Concatenate the two tensors along the first dimension with torch.cat.
    # combined_tensor = torch.cat((tensor1, tensor2), dim=0)

    # print(combined_tensor.shape)  # should print torch.Size([522, 512, 768])
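
    # Small runnable check of the torch.cat note above (t1/t2 stand in for
    # tensor1/tensor2, with the shapes shrunk so it runs instantly): concatenating
    # along dim 0 simply adds the first dimensions together.
    t1 = torch.randn(3, 5, 7)
    t2 = torch.randn(2, 5, 7)
    print(torch.cat((t1, t2), dim=0).shape)  # torch.Size([5, 5, 7])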
|