123456789101112131415161718192021222324252627282930313233343536373839404142434445464748 |
- from sentence_transformers import SentenceTransformer, util
- from my_config import LANG_EMB_MODEL
# Shared embedding model, constructed once at import time and reused by
# item2emb_all(). LANG_EMB_MODEL["all"] is presumably a model name or local
# path defined in my_config — confirm there.
model = SentenceTransformer(LANG_EMB_MODEL["all"])
def item2emb_all(items_list, **encode_kwargs):
    """Encode one or more text items with the shared sentence-embedding model.

    Typical inputs are short textual item descriptions, e.g. the math-exercise
    phrase ``"求集合点集与集合点集的交集"`` ("find the intersection of two
    point sets").

    :param items_list: a single string, or an iterable of strings. A single
        string is wrapped in a one-element list, so the result is always a
        batch with one embedding per item. Any other non-list iterable
        (tuple, generator, pandas Series, ...) is materialized into a plain
        ``list`` first, so ``model.encode`` always receives the same input
        type regardless of how the caller supplied the items.
    :param encode_kwargs: optional keyword arguments forwarded unchanged to
        ``model.encode`` (e.g. ``batch_size``, ``normalize_embeddings``).
        With none given, the call is identical to ``model.encode(items)``.
    :return: the value of ``model.encode`` for the list of items. With the
        library defaults this is presumably a NumPy array of shape
        ``(len(items), dim)``; confirm against the installed
        sentence-transformers version.
    """
    if isinstance(items_list, str):
        # Keep batch semantics even for one item: the output stays 2-D.
        items_list = [items_list]
    elif not isinstance(items_list, list):
        # Generators / label-indexed containers -> plain positional list.
        items_list = list(items_list)
    return model.encode(items_list, **encode_kwargs)
if __name__ == '__main__':
    # Manual smoke check: embed two near-synonymous Chinese phrases
    # ("read widely" / "extensive reading") and print their cosine-similarity
    # matrix. cos_sim(ss, ss) compares every row with every row, so the
    # result is 2x2: the diagonal is each phrase against itself (~1.0) and
    # the off-diagonal entries are the cross-phrase similarity of interest.
    ss = item2emb_all(["广泛地阅读", "泛读"])
    a = util.cos_sim(ss, ss)
    print(a)
|