all_lang_emb.py 1.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748
  1. from sentence_transformers import SentenceTransformer, util
  2. from my_config import LANG_EMB_MODEL
  3. model = SentenceTransformer(LANG_EMB_MODEL["all"])
  4. # print("model load time:{}".format(time.time()-s))
  5. # # Our sentences we like to encode
  6. def item2emb_all(items_list):
  7. """
  8. items= ["已知集合点集,集合点集,求交集",
  9. "求集合点集与集合点集的交集",
  10. "求函数根式复合是一次的定义域",
  11. "函数是奇函数,在区间上单调递增,求参数的取值范围"]
  12. :param items_list:
  13. :return:
  14. """
  15. if isinstance(items_list, str):
  16. items_list = [items_list]
  17. # # Sentences are encoded by calling model.encode()
  18. item_embeddings = model.encode(items_list)
  19. return item_embeddings
  20. if __name__ == '__main__':
  21. # ss = item2emb_all("已知集合点集,集合点集,求交集")
  22. # print(ss)
  23. # print(ss.shape),
  24. # ss = item2emb_all(["已知集合点集,集合点集,求交集",
  25. # "求集合点集与集合点集的交集",
  26. # "求函数根式复合是一次的定义域",
  27. # "函数是奇函数,在区间上单调递增,求参数的取值范围"])
  28. ss = item2emb_all(["广泛地阅读", "泛读"]) # ["visual", "scene", "situation"]
  29. a = util.cos_sim(ss, ss)
  30. # print(ss.shape)
  31. # # b = util.cos_sim(ss[:1],ss[1:])
  32. print(a)
  33. # print(b)
  34. # res = similarity(ss)
  35. # print(res)