config.py

import logging
import logging.handlers
import os
import sys

import torch
from transformers import BertTokenizer

rootdir = os.getcwd()
sys.path.append(rootdir)

class NerArgs:
    tasks = ["ner"]
    bert_dir = os.path.join(rootdir, "model_hub/chinese-bert-wwm-ext/")
    save_dir = "checkpoints/PointerNet_topic_ner_model_with_2_task_release_2.pt"
    optimizer_save_dir = "checkpoints/optimizer_further_finetuned.pt"
    label_path = "PointerNet/data/labels.txt"
    # Load the label set and build both directions of the label <-> id mapping.
    with open(label_path, "r", encoding="utf-8") as fp:
        labels = fp.read().strip().split("\n")
    label2id = {}
    id2label = {}
    for i, label in enumerate(labels):
        label2id[label] = i
        id2label[i] = label
    ner_num_labels = len(labels)
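    # For illustration only: labels.txt is assumed to contain one label name per
    # line. With a hypothetical file
    #     TOPIC
    #     SUB_TOPIC
    # the loop above yields label2id = {"TOPIC": 0, "SUB_TOPIC": 1} and
    # id2label = {0: "TOPIC", 1: "SUB_TOPIC"}.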
    # train_epoch = 3
    # train_batch_size = 1
    # eval_batch_size = 1
    # eval_step = 1000
    # max_seq_len = 240  # 100 sentences: 220; 90 sentences: 228-245, so 240 was chosen
    # # max_encoder_sent_len = 100
    # max_input_sent_num = 90
    # weight_decay = 0.01
    # adam_epsilon = 1e-8
    # max_grad_norm = 5.0
    # lr = 1e-5
    # other_lr = 3e-4
    # warmup_proportion = 0.01
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = BertTokenizer.from_pretrained(bert_dir)
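    # Illustrative use of the tokenizer together with the commented-out max_seq_len
    # above (this sample call is an assumption, not part of the original config):
    #     enc = NerArgs.tokenizer(sentence, max_length=240, truncation=True,
    #                             padding="max_length", return_tensors="pt")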
  38. """
  39. 调试记录:
  40. max_seq_len 和 batch_size这两个参数最影响显存消耗
  41. 当batch_size为2时, 易出现显存不够;此时在forward计算过程中将部分结果移至cpu中,可以节约显存
  42. >>>实际中在前馈计算时,将bert的输出last_hidden_state使用了.detach().cpu(),
  43. detach()方法会切断张量与其前导图之间的梯度流,导致模型训练不稳定
  44. 当batch_size>1时,训练稳定性暂时无法验证,显存不够
  45. max_seq_len太大时,是否收敛带验证
  46. 在一张卡24G显存中,取bs=1:一个文档,max_seq_len最多取到68,且要求max_encoder_sent_len设置比较小
  47. >>>需将文档分批计算,如100个句子为一个批次
  48. """

class myLog(object):
    """
    A thin wrapper around logging.
    """
    def __init__(self, logger=None, log_cate='my_log', subject=''):
        """
        Set the log file path, the log level, and the calling module,
        and write the log records to the given file.
        """
        log_dir = os.path.join(rootdir, "logs")
        if subject:
            log_dir += '/{}'.format(subject)
        if not os.path.exists(log_dir):
            os.makedirs(log_dir)  # makedirs also creates the parent "logs" directory
        # create a logger
        self.logger = logging.getLogger(logger)
        self.logger.setLevel(logging.INFO)  # DEBUG
        self.log_name = os.path.join(log_dir, '{}.log'.format(log_cate))  # log file path
        fh = logging.handlers.RotatingFileHandler(self.log_name, maxBytes=120000000, backupCount=4,
                                                  mode='a', encoding='utf-8', delay=True)
        # fh = logging.FileHandler(self.log_name, mode='a', encoding='utf-8', delay=True)
        fh.setLevel(logging.INFO)
        # formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
        # NOTE: %(message)s is not quoted, so the output line is valid JSON only when the
        # logged message is itself a JSON value.
        formatter = logging.Formatter('{"log-msg": %(message)s, "other-msg": "%(filename)s-%(lineno)s-%(asctime)s"}')
        fh.setFormatter(formatter)
        self.logger.addHandler(fh)

    def getlog(self):
        return self.logger
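
# Example usage of myLog; the category and subject names here are illustrative only:
#     logger = myLog(logger=__name__, log_cate="ner_service", subject="ner").getlog()
#     logger.info('"model loaded"')  # quote the message so the JSON formatter stays valid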

server_ip = '0.0.0.0'
server_port = 10622


class TestingCfg:  # testing
    internal_ip = '0.0.0.0'  # internal
    external_ip = '192.168.1.204'  # external
    server_port = 10622


class ProductionCfg:  # production
    internal_ip = '0.0.0.0'  # internal
    external_ip = '10.19.1.14'  # external
    server_port = 10622
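
# A minimal sketch of how a deployment might select one of the two configs; the
# DEPLOY_ENV environment variable is an assumption, not part of the original code.
cfg = ProductionCfg if os.environ.get("DEPLOY_ENV") == "production" else TestingCfg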