train_script.py

  1. """
  2. Train script for a single file
  3. Need to set the TPU address first:
  4. export XRT_TPU_CONFIG="localservice;0;localhost:51011"
  5. """
  6. import torch.multiprocessing as mp
  7. import threading
  8. import time
  9. import random
  10. import sys
  11. import argparse
  12. import gzip
  13. import json
  14. import logging
  15. import tqdm
  16. import torch
  17. from torch import nn
  18. from torch.utils.data import DataLoader
  19. import torch
  20. import torch_xla
  21. import torch_xla.core
  22. import torch_xla.core.functions
  23. import torch_xla.core.xla_model as xm
  24. import torch_xla.distributed.xla_multiprocessing as xmp
  25. import torch_xla.distributed.parallel_loader as pl
  26. import os
  27. from shutil import copyfile
  28. from transformers import (
  29. AdamW,
  30. AutoModel,
  31. AutoTokenizer,
  32. get_linear_schedule_with_warmup,
  33. set_seed,
  34. )
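
# Dependency note (an assumption, not part of the original script): the imports above
# need torch, transformers, tqdm, and torch_xla. torch_xla is TPU-specific and is
# usually installed from a dedicated wheel; see the torch_xla installation docs for
# the wheel matching your torch version and TPU runtime.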

class AutoModelForSentenceEmbedding(nn.Module):
    """Wraps a Hugging Face transformer and mean-pools its token embeddings into sentence embeddings."""

    def __init__(self, model_name, tokenizer, normalize=True):
        super(AutoModelForSentenceEmbedding, self).__init__()

        self.model = AutoModel.from_pretrained(model_name)
        self.normalize = normalize
        self.tokenizer = tokenizer

    def forward(self, **kwargs):
        model_output = self.model(**kwargs)
        embeddings = self.mean_pooling(model_output, kwargs['attention_mask'])
        if self.normalize:
            embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)

        return embeddings

    def mean_pooling(self, model_output, attention_mask):
        token_embeddings = model_output[0]  # First element of model_output contains all token embeddings
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

    def save_pretrained(self, output_path):
        if xm.is_master_ordinal():
            self.tokenizer.save_pretrained(output_path)
            self.model.config.save_pretrained(output_path)

        xm.save(self.model.state_dict(), os.path.join(output_path, "pytorch_model.bin"))
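
# Usage sketch for the wrapper above (illustrative only; the sentence is a placeholder,
# the model name is just the script's default, and this runs on CPU rather than the
# XLA device used in train_function):
#
#   tokenizer = AutoTokenizer.from_pretrained("nreimers/MiniLM-L6-H384-uncased")
#   model = AutoModelForSentenceEmbedding("nreimers/MiniLM-L6-H384-uncased", tokenizer)
#   features = tokenizer(["An example sentence"], return_tensors="pt", padding=True, truncation=True)
#   embeddings = model(**features)  # shape (1, hidden_size), L2-normalized by default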

def train_function(index, args, queue):
    tokenizer = AutoTokenizer.from_pretrained(args.model)
    model = AutoModelForSentenceEmbedding(args.model, tokenizer)

    ### Train Loop
    device = xm.xla_device()
    model = model.to(device)

    # Instantiate optimizer
    optimizer = AdamW(params=model.parameters(), lr=2e-5, correct_bias=True)

    lr_scheduler = get_linear_schedule_with_warmup(
        optimizer=optimizer,
        num_warmup_steps=500,
        num_training_steps=args.steps,
    )

    # Now we train the model
    cross_entropy_loss = nn.CrossEntropyLoss()
    max_grad_norm = 1

    model.train()

    for global_step in tqdm.trange(args.steps, disable=not xm.is_master_ordinal()):
        #### Get the batch data
        batch = queue.get()
        #print(index, "batch {}x{}".format(len(batch), ",".join([str(len(b)) for b in batch])))

        if len(batch[0]) == 2:  #(anchor, positive)
            text1 = tokenizer([b[0] for b in batch], return_tensors="pt", max_length=args.max_length, truncation=True, padding="max_length")
            text2 = tokenizer([b[1] for b in batch], return_tensors="pt", max_length=args.max_length, truncation=True, padding="max_length")

            ### Compute embeddings
            embeddings_a = model(**text1.to(device))
            embeddings_b = model(**text2.to(device))

            ### Gather all embeddings
            embeddings_a = torch_xla.core.functions.all_gather(embeddings_a)
            embeddings_b = torch_xla.core.functions.all_gather(embeddings_b)

            ### Compute similarity scores 512 x 512
            scores = torch.mm(embeddings_a, embeddings_b.transpose(0, 1)) * args.scale

            ### Compute cross-entropy loss
            labels = torch.tensor(range(len(scores)), dtype=torch.long, device=embeddings_a.device)  # Example a[i] should match with b[i]

            ## Symmetric loss as in CLIP
            loss = (cross_entropy_loss(scores, labels) + cross_entropy_loss(scores.transpose(0, 1), labels)) / 2

        else:   #(anchor, positive, negative)
            text1 = tokenizer([b[0] for b in batch], return_tensors="pt", max_length=args.max_length, truncation=True, padding="max_length")
            text2 = tokenizer([b[1] for b in batch], return_tensors="pt", max_length=args.max_length, truncation=True, padding="max_length")
            text3 = tokenizer([b[2] for b in batch], return_tensors="pt", max_length=args.max_length, truncation=True, padding="max_length")

            embeddings_a = model(**text1.to(device))
            embeddings_b1 = model(**text2.to(device))
            embeddings_b2 = model(**text3.to(device))

            embeddings_a = torch_xla.core.functions.all_gather(embeddings_a)
            embeddings_b1 = torch_xla.core.functions.all_gather(embeddings_b1)
            embeddings_b2 = torch_xla.core.functions.all_gather(embeddings_b2)

            embeddings_b = torch.cat([embeddings_b1, embeddings_b2])

            ### Compute similarity scores 512 x 1024
            scores = torch.mm(embeddings_a, embeddings_b.transpose(0, 1)) * args.scale

            ### Compute cross-entropy loss
            labels = torch.tensor(range(len(scores)), dtype=torch.long, device=embeddings_a.device)  # Example a[i] should match with b[i]

            ## One-way loss
            loss = cross_entropy_loss(scores, labels)

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

        xm.optimizer_step(optimizer, barrier=True)
        lr_scheduler.step()

        # Save model
        if (global_step+1) % args.save_steps == 0:
            output_path = os.path.join(args.output, str(global_step+1))
            xm.master_print("save model: " + output_path)
            model.save_pretrained(output_path)

    output_path = os.path.join(args.output, "final")
    xm.master_print("save model final: " + output_path)
    model.save_pretrained(output_path)
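
# Shape note (following the comments in train_function, with the default batch_size=64
# and nprocs=8): all_gather collects 64 * 8 = 512 embeddings per side, so the score
# matrix is 512 x 512 for (anchor, positive) batches and 512 x 1024 for
# (anchor, positive, negative) batches, where the extra 512 columns are the gathered
# hard negatives.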

def produce_data(args, queue, filepaths, dataset_indices):
    global_batch_size = args.batch_size*args.nprocs     #Global batch size
    size_per_dataset = int(global_batch_size / args.datasets_per_batch)     #Samples each dataset contributes to the global batch
    num_same_dataset = int(size_per_dataset / args.batch_size)
    print("producer", "global_batch_size", global_batch_size)
    print("producer", "size_per_dataset", size_per_dataset)
    print("producer", "num_same_dataset", num_same_dataset)

    datasets = []
    for filepath in filepaths:
        if "reddit_" in filepath:       #Special dataset class for Reddit files
            data_obj = RedditDataset(filepath)
        else:
            data_obj = Dataset(filepath)
        datasets.append(iter(data_obj))

    # Store if dataset is in a 2 col or 3 col format
    num_cols = {idx: len(next(dataset)) for idx, dataset in enumerate(datasets)}

    while True:
        texts_in_batch = set()
        batch_format = None     #2 vs 3 col format for this batch

        #Add data from several sub datasets
        for _ in range(args.datasets_per_batch):
            valid_dataset = False   #Check that datasets have the same 2/3 col format
            while not valid_dataset:
                data_idx = random.choice(dataset_indices)
                if batch_format is None:
                    batch_format = num_cols[data_idx]
                    valid_dataset = True
                else:   #Check that this dataset has the same format
                    valid_dataset = (batch_format == num_cols[data_idx])

            #Get data from this dataset
            dataset = datasets[data_idx]
            for _ in range(num_same_dataset):
                for _ in range(args.nprocs):
                    batch_device = []   #A batch for one device
                    while len(batch_device) < args.batch_size:
                        sample = next(dataset)
                        in_batch = False
                        for text in sample:
                            if text in texts_in_batch:
                                in_batch = True
                                break

                        if not in_batch:
                            for text in sample:
                                texts_in_batch.add(text)
                            batch_device.append(sample)

                    queue.put(batch_device)
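
# Batch-size arithmetic for the producer above, using the script defaults
# (batch_size=64, nprocs=8, datasets_per_batch=2):
#   global_batch_size = 64 * 8  = 512
#   size_per_dataset  = 512 / 2 = 256
#   num_same_dataset  = 256 / 64 = 4
# so each dataset drawn in the outer loop yields num_same_dataset * nprocs = 32
# consecutive per-device batches on the queue before the next dataset is chosen.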

class RedditDataset:
    """
    A class that handles the reddit data files
    """
    def __init__(self, filepath):
        self.filepath = filepath

    def __iter__(self):
        while True:
            with gzip.open(self.filepath, "rt") as fIn:
                for line in fIn:
                    data = json.loads(line)

                    if "response" in data and "context" in data:
                        yield [data["response"], data["context"]]
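
# Example of a line RedditDataset expects in its gzipped JSON-lines file
# (field values are placeholders; only the keys come from the code above):
#   {"context": "What is your favorite ...?", "response": "Probably ..."}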

class Dataset:
    """
    A class that handles one dataset
    """
    def __init__(self, filepath):
        self.filepath = filepath

    def __iter__(self):
        max_dataset_size = 10*1000*1000     #Cache small datasets in memory
        dataset = []
        data_format = None

        while dataset is None or len(dataset) == 0:
            with gzip.open(self.filepath, "rt") as fIn:
                for line in fIn:
                    data = json.loads(line)
                    if isinstance(data, dict):
                        data = data['texts']

                    if data_format is None:
                        data_format = len(data)

                    #Ensure that all entries are of the same 2/3 col format
                    assert len(data) == data_format

                    if dataset is not None:
                        dataset.append(data)
                        if len(dataset) >= max_dataset_size:
                            dataset = None

                    yield data

        # Data loaded. Now stream to the queue
        # Shuffle for each epoch
        while True:
            random.shuffle(dataset)
            for data in dataset:
                yield data
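
# Example lines Dataset accepts in a gzipped JSON-lines file (text values are
# placeholders; the structure follows the parsing logic above):
#   ["anchor text", "positive text"]                      # 2-col format
#   ["anchor text", "positive text", "hard negative"]     # 3-col format
#   {"texts": ["anchor text", "positive text"]}           # dict form; only 'texts' is used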

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', default='nreimers/MiniLM-L6-H384-uncased')
    parser.add_argument('--steps', type=int, default=2000)
    parser.add_argument('--save_steps', type=int, default=10000)
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--max_length', type=int, default=128)
    parser.add_argument('--nprocs', type=int, default=8)
    parser.add_argument('--datasets_per_batch', type=int, default=2, help="Number of datasets per batch")
    parser.add_argument('--scale', type=float, default=20, help="Use 20 for cossim, and 1 when you work with unnormalized embeddings with dot product")
    parser.add_argument('--data_folder', default="/data", help="Folder with your dataset files")
    parser.add_argument('data_config', help="A data_config.json file")
    parser.add_argument('output')
    args = parser.parse_args()

    # Ensure the global batch size is divisible by datasets_per_batch
    assert (args.batch_size*args.nprocs) % args.datasets_per_batch == 0

    logging.info("Output: "+args.output)
    if os.path.exists(args.output):
        print("Output folder already exists.")
        input("Continue?")

    # Write train script to output path
    os.makedirs(args.output, exist_ok=True)

    data_config_path = os.path.join(args.output, 'data_config.json')
    copyfile(args.data_config, data_config_path)

    train_script_path = os.path.join(args.output, 'train_script.py')
    copyfile(__file__, train_script_path)
    with open(train_script_path, 'a') as fOut:
        fOut.write("\n\n# Script was called via:\n#python " + " ".join(sys.argv))

    #Load data config
    with open(args.data_config) as fIn:
        data_config = json.load(fIn)
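
    # Example data_config.json contents (file names and weights are placeholders;
    # the loop below reads only the 'name' and 'weight' fields, and 'weight'
    # controls how often a dataset is sampled by the producer):
    # [
    #   {"name": "stackexchange_title_body.jsonl.gz", "weight": 2},
    #   {"name": "reddit_2021.jsonl.gz", "weight": 1}
    # ]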

    queue = mp.Queue(maxsize=100*args.nprocs)

    filepaths = []
    dataset_indices = []
    for idx, data in enumerate(data_config):
        filepaths.append(os.path.join(os.path.expanduser(args.data_folder), data['name']))
        dataset_indices.extend([idx]*data['weight'])

    # Start producer
    p = mp.Process(target=produce_data, args=(args, queue, filepaths, dataset_indices))
    p.start()

    # Run training
    print("Start processes:", args.nprocs)
    xmp.spawn(train_function, args=(args, queue), nprocs=args.nprocs, start_method='fork')
    print("Training done")
    print("Note: not all processes may exit automatically. In that case, kill this process manually.")
    print("With 'pkill python' you can kill all remaining python processes")
    p.kill()
    exit()



# Script was called via:
#python train_many_data_files_v2.py --steps 1000000 --batch_size 128 --model nreimers/MiniLM-L6-H384-uncased train_data_configs/all_datasets_v4.json output/all_datasets_v4_MiniLM-L6-H384-uncased-batch128