# Configuration module: NER model settings, a rotating-file logging wrapper, and server addresses.
import logging
import logging.handlers
import os
import sys

import torch
from transformers import BertTokenizer
# The project root is taken to be the current working directory of the process;
# paths built from ``rootdir`` elsewhere in this file are resolved against it.
rootdir = os.getcwd()

# Expose the root on sys.path. The membership check keeps a re-import of this
# module from appending the same entry again.
if rootdir not in sys.path:
    sys.path.append(rootdir)
class NerArgs:
    """Static settings for the NER model: paths, label maps, device, tokenizer.

    All attributes are computed once, when the class body executes at import
    time. That includes reading the label file at ``label_path`` and loading
    the BERT tokenizer from ``bert_dir``, so both must exist on disk.
    """

    tasks = ["ner"]
    bert_dir = os.path.join(rootdir, "model_hub/chinese-bert-wwm-ext/")
    save_dir = "checkpoints/PointerNet_topic_ner_model_with_2_task_release_2.pt"
    optimizer_save_dir = "checkpoints/optimizer_further_finetuned.pt"
    label_path = "PointerNet/data/labels.txt"

    # One label per line. The encoding is pinned so that decoding does not
    # depend on the platform's locale default.
    with open(label_path, "r", encoding="utf-8") as fp:
        labels = fp.read().strip().split("\n")

    # Bidirectional label <-> index maps; the index is the line position.
    label2id = {label: idx for idx, label in enumerate(labels)}
    id2label = dict(enumerate(labels))
    ner_num_labels = len(labels)

    # Training settings, currently disabled (kept for reference):
    # train_epoch = 3
    # train_batch_size = 1
    # eval_batch_size = 1
    # eval_step = 1000
    # max_seq_len = 240  # 100 sentences: 220; 90 sentences: 228-245; 240 chosen
    # # max_encoder_sent_len = 100
    # max_input_sent_num = 90
    # weight_decay = 0.01
    # adam_epsilon = 1e-8
    # max_grad_norm = 5.0
    # lr = 1e-5
    # other_lr = 3e-4
    # warmup_proportion = 0.01

    # Use the GPU when one is visible to torch, otherwise fall back to CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = BertTokenizer.from_pretrained(bert_dir)
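"""
Debugging notes:
max_seq_len and batch_size are the two parameters that affect GPU memory consumption the most.
With batch_size = 2, GPU memory easily runs out; moving some intermediate results to the CPU during the forward pass saves GPU memory.
>>> In practice, .detach().cpu() was applied to BERT's last_hidden_state output during the forward pass;
detach() cuts the gradient flow between the tensor and the graph that produced it, which makes training unstable.
With batch_size > 1, training stability cannot be verified for now because GPU memory is insufficient.
Whether training converges when max_seq_len is very large remains to be verified.
On a single GPU with 24 GB of memory and bs=1 (one document), max_seq_len can be at most 68, and max_encoder_sent_len must be set fairly small.
>>> Documents therefore need to be processed in chunks, e.g. 100 sentences per batch.
"""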
class myLog(object):
    """Wrapper that configures a ``logging`` logger writing to a rotating file.

    Log files live under ``<rootdir>/logs`` (optionally in a ``subject``
    sub-directory) and are named ``<log_cate>.log``.
    """

    def __init__(self, logger=None, log_cate='my_log', subject=''):
        """Set up the log directory, the logger and its rotating file handler.

        Args:
            logger: Name passed to ``logging.getLogger``; ``None`` selects the
                root logger.
            log_cate: Base name (without extension) of the log file.
            subject: Optional sub-directory of the ``logs`` directory.
        """
        log_dir = os.path.join(rootdir, "logs")
        if subject:
            log_dir += '/{}'.format(subject)
        # makedirs also creates a missing parent "logs" directory and does not
        # fail when the directory already exists (or is created concurrently).
        os.makedirs(log_dir, exist_ok=True)

        self.logger = logging.getLogger(logger)
        self.logger.setLevel(logging.INFO)  # DEBUG
        self.log_name = os.path.join(log_dir, '{}.log'.format(log_cate))  # log file path

        # logging.getLogger returns the same object for the same name, so a
        # second myLog(...) would otherwise stack another handler on it and
        # every record would be written more than once.
        target = os.path.abspath(self.log_name)
        already_attached = any(
            isinstance(handler, logging.handlers.RotatingFileHandler)
            and handler.baseFilename == target
            for handler in self.logger.handlers
        )
        if not already_attached:
            # Rotate at ~120 MB, keeping 4 backups. With delay=True the file is
            # only opened when the first record is emitted, so the handler is
            # left attached and open-on-demand rather than closed here.
            fh = logging.handlers.RotatingFileHandler(
                self.log_name, maxBytes=120000000, backupCount=4,
                mode='a', encoding='utf-8', delay=True)
            fh.setLevel(logging.INFO)
            # NOTE(review): %(message)s is inserted unquoted into a JSON-like
            # template, so callers presumably pass an already JSON-encoded
            # message -- confirm against call sites.
            formatter = logging.Formatter(
                '{"log-msg": %(message)s, "other-msg": "%(filename)s-%(lineno)s-%(asctime)s"}')
            fh.setFormatter(formatter)
            self.logger.addHandler(fh)

    def getlog(self):
        """Return the configured ``logging.Logger``."""
        return self.logger
# Module-level default server address and port.
server_ip = '0.0.0.0'
server_port = 10622
class TestingCfg:
    """Server address settings for the testing environment."""

    internal_ip = '0.0.0.0'  # internal address
    external_ip = '192.168.1.204'  # external address
    server_port = 10622
class ProductionCfg:
    """Server address settings for the production environment."""

    internal_ip = '0.0.0.0'  # internal address
    external_ip = '10.19.1.14'  # external address
    server_port = 10622