chinese_emb.py

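"""Chinese sentence embeddings via mean pooling.

Loads the HuggingFace model configured under my_config.LANG_EMB_MODEL["cn"],
mean-pools token embeddings using the attention mask, and L2-normalizes the
result so embeddings can be compared with cosine similarity.
"""
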
from transformers import AutoTokenizer, AutoModel
import torch
import torch.nn.functional as F

from my_config import LANG_EMB_MODEL

tokenizer = AutoTokenizer.from_pretrained(LANG_EMB_MODEL["cn"])
model = AutoModel.from_pretrained(LANG_EMB_MODEL["cn"])  # pass from_tf=True if the checkpoint ships only TensorFlow weights
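
# A sketch of what my_config.py is assumed to provide; the model id below is a
# hypothetical choice, any HuggingFace sentence-embedding checkpoint would fit:
#
#     LANG_EMB_MODEL = {
#         "cn": "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
#     }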


# Mean Pooling - take the attention mask into account for correct averaging
def mean_pooling(model_output, attention_mask):
    token_embeddings = model_output[0]  # first element of model_output contains all token embeddings
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
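
# For a sentence with token embeddings t_i and attention mask m_i, the pooling
# above computes sum_i(m_i * t_i) / max(sum_i(m_i), 1e-9), so padding tokens
# drop out of the average.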


def item2emb_cn(sentences):
    """Embed a list of Chinese sentences; returns L2-normalized embeddings."""
    # Tokenize sentences
    encoded_input = tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')

    # Compute token embeddings
    with torch.no_grad():
        model_output = model(**encoded_input)

    # Perform pooling
    sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])

    # Normalize embeddings
    sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)
    return sentence_embeddings
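
# Since item2emb_cn returns unit-length vectors, cosine similarity reduces to
# a plain dot product: ss[0] @ ss[1:].T yields the same values as
# util.cos_sim(ss[0], ss[1:]) in the demo below.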


if __name__ == '__main__':
    from sentence_transformers import util

    # Example product titles (Chinese e-commerce listings):
    sentences = [
        # "In-stock group buy: the eczema buster is here! NiuQiang 2nd-gen / Shunfeng full baby bath-and-care line; guardian angel for baby skin; codes scratched before shipping; authenticity guaranteed"
        '现货团湿疹克星来啦纽强二代顺峰宝宝洗护全系列开团宝宝的肌肤守护天使刮码发货正品保证',
        # "The simple, no-frills Huanleke building-block table is here, the newest item"
        '简单直接的欢乐客积木桌来了最新品哦',
        # "Baobaopai new arrival: children's creative sand art that sparks kids' imagination and color-matching skills; presale ships on the 18th"
        '宝宝拍新品儿童创意沙画激发孩子的想象力和颜色搭配能力预售18号发货',
        # "Buydeem multifunction G56 household steam/stew pot, electric steamer, water-bath stew cup; fully automatic, schedulable, easy to store, good-looking"
        '北鼎多功能G56家用蒸炖锅电蒸锅隔水炖盅全自动可预约好收纳高颜值',
    ]

    ss = item2emb_cn(["我不会", "欣赏"])  # "I can't", "appreciate"
    b = util.cos_sim(ss[0], ss[1:])  # similarity of the first embedding vs. the rest
    print(b)
    # leftover test words: "因此" ("therefore"), "总之" ("in short")