eng_emb.py

from transformers import AutoTokenizer, AutoModel
from sentence_transformers import util
import torch
import torch.nn.functional as F

from my_config import LANG_EMB_MODEL
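# NOTE (assumption): my_config is not shown in this file. LANG_EMB_MODEL is
# assumed to map language codes to Hugging Face model names, e.g. (hypothetical):
#   LANG_EMB_MODEL = {"eng": "sentence-transformers/all-MiniLM-L6-v2"}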
tokenizer = AutoTokenizer.from_pretrained(LANG_EMB_MODEL["eng"])
model = AutoModel.from_pretrained(LANG_EMB_MODEL["eng"])
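# AutoModel returns the base transformer; model_output[0] below is
# last_hidden_state, the per-token embeddings that mean_pooling averages.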


# Mean pooling: average token embeddings, weighting by the attention mask so
# that padding tokens do not contribute to the sentence vector.
def mean_pooling(model_output, attention_mask):
    token_embeddings = model_output[0]  # first element: all token embeddings
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
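# Illustrative shapes (assuming a 384-dim encoder such as MiniLM): for a batch
# of 2 sentences padded to 8 tokens, token_embeddings is (2, 8, 384),
# attention_mask is (2, 8), and the pooled result is (2, 384).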


def item2emb_eng(sentences):
    # Tokenize the sentences, padding/truncating to a shared length.
    encoded_input = tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')

    # Compute token embeddings without tracking gradients.
    with torch.no_grad():
        model_output = model(**encoded_input)

    # Pool token embeddings into one vector per sentence.
    sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])

    # L2-normalize so that cosine similarity reduces to a dot product.
    sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)
    return sentence_embeddings
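# item2emb_eng returns a (batch_size, hidden_size) tensor of unit-length rows,
# so util.cos_sim on its output is simply a matrix of dot products.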


if __name__ == '__main__':
    # Sample Chinese product titles (baby eczema-care toiletries, a building-block
    # table, children's sand art, a Buydeem G56 steamer/stewpot), kept as inputs
    # for the Chinese counterpart item2emb_cn, which lives outside this file.
    sentences = ['现货团湿疹克星来啦纽强二代顺峰宝宝洗护全系列开团宝宝的肌肤守护天使刮码发货正品保证',
                 '简单直接的欢乐客积木桌来了最新品哦',
                 '宝宝拍新品儿童创意沙画激发孩子的想象力和颜色搭配能力预售18号发货',
                 '北鼎多功能G56家用蒸炖锅电蒸锅隔水炖盅全自动可预约好收纳高颜值']
    # item2emb_cn(sentences)

    ss = item2emb_eng(["A highly recommended book", "a highly recommendable book"])
    print(ss)

    # Pairwise cosine similarities between the two embeddings.
    b = util.cos_sim(ss, ss)
    print(b)
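    # Expected: b is a 2x2 matrix with ~1.0 on the diagonal (unit-length
    # embeddings) and a high off-diagonal score, since the two sentences are
    # near-paraphrases.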