#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Text semantic similarity, e.g.:
sts([
    ('看图猜一电影名', '看图猜电影'),
    ('无线路由器怎么无线上网', '无线上网卡和无线路由器怎么用'),
    ('北京到上海的动车票', '上海到北京的动车票'),
])
"""
import re
import time
from concurrent.futures import ThreadPoolExecutor

from Utils.util import phrase_classify
from my_config import sts, pos, dict_tags


def batch_tag(hw_list, ans_given):
    """
    Tag parts of speech in batch.
    :param hw_list:
    :param ans_given:
    :return:
    """
    all_tags_with_str = {}
    # fetch POS tags concurrently (thread pool)
    hw_l = [len(i) for i in hw_list]
    ans_given_choosed = []
    if len(ans_given) > 80:  # only tag a subset of the candidates
        ans_given_choosedbylen = [a for a in ans_given if len(a) in hw_l]
        num = 80 - len(ans_given_choosedbylen)
        ans_given_choosed.extend(ans_given_choosedbylen[:80])
        if num > 0:
            ans_given_choosed.extend(list(set(ans_given) - set(ans_given_choosedbylen))[:num])
    else:
        ans_given_choosed = ans_given
    all_ch = hw_list.copy()
    all_ch.extend(ans_given_choosed)
    all_ch = list(set(all_ch))
    print("all_ch:::", all_ch)
    with ThreadPoolExecutor(max_workers=3) as executor:
        # pos_tag_han(w) returns (w, pos(w)), so each result unpacks directly
        for w, tags in executor.map(pos_tag_han, all_ch):
            all_tags_with_str[w] = tags
    return all_tags_with_str
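

# Hedged usage sketch (hypothetical demo helper, not part of the original
# module); the value shape depends on `pos` from my_config, which isn't shown.
def _demo_batch_tag():
    tags = batch_tag(["画廊"], ["美术画廊", "画廊"])
    # One entry per distinct input string, e.g.
    # {'画廊': <tags from pos('画廊')>, '美术画廊': <tags from pos('美术画廊')>}
    print(tags)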


def groups_choose(hw_list, ans_given):
    """
    Cap the number of pairs sent to similarity scoring.
    :param hw_list:
    :param ans_given:
    :return:
    """
    if len(hw_list) * len(ans_given) > 100:
        length_hw = [len(h.strip()) for h in hw_list]
        new_anss = [a for a in ans_given if len(a) in length_hw]
        anss_rest = [a for a in ans_given if len(a) not in length_hw]
        if len(hw_list) * len(new_anss) <= 90:
            new_anss2 = [a for a in anss_rest if (len(a) > 2 and len(a) - 1 in length_hw) or len(a) + 1 in length_hw]
            if len(hw_list) * (len(new_anss) + len(new_anss2)) > 100:
                new_anss.extend(new_anss2[:int(100 / len(hw_list)) - len(new_anss)])
            elif len(hw_list) * (len(new_anss) + len(new_anss2)) < 50:
                new_anss.extend(anss_rest[:int(100 / len(hw_list)) - len(new_anss)])
        elif len(hw_list) * len(new_anss) > 100:
            new_anss = new_anss[:int(100 / len(hw_list))]
        return new_anss
    return ans_given
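

# Worked example of the <=100-pair cap above (hypothetical demo helper):
# 2 headwords x 60 candidates = 120 pairs, so candidates whose length does
# not match a headword's length are dropped first.
def _demo_groups_choose():
    hw_list = ["画廊", "破除"]                    # both of length 2
    ans_given = ["美术画廊"] * 30 + ["画廊"] * 30  # 2 * 60 = 120 pairs > 100
    kept = groups_choose(hw_list, ans_given)      # keeps the 30 length-2 items
    print(len(hw_list) * len(kept))               # 60 pairs, now under the cap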


def han_similarity(en, hw_list, ans_given, cutted_words, is_token=0):
    """
    Short-text similarity via hanlp.
    When a string contains "的", also check its part of speech.
    is_token: whether hw_list is the output of word segmentation
    :return:
    """
    scores_byrow = []
    scores_byrow_rawshape = []
    num_per_row = []
    all_groups = []
    part_of_speech_s = []
    double_groups_locate = {}
    # ---------- first cap the number of similarity pairs: <= 100 ----------
    ans_given = groups_choose(hw_list, ans_given)
    # -----------------------------------------------------------------------
    st5 = time.time()
    # tag both lists in single calls
    all_hw_tags = pos_tag_han(hw_list, flag="by_list")
    all_ans_tags = pos_tag_han(ans_given, flag="by_list")
    # print(all_hw_tags, ans_given, all_ans_tags)
    # fetch per-string POS tags concurrently
    all_tags_with_str = batch_tag(hw_list, ans_given)
    print("POS tagging time (batch):", time.time() - st5)
    print("m*n:::", len(hw_list), hw_list, len(ans_given))
    # ---------------------------------------------
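    # Data layout from here on: all_groups is a flat list of (hw, ans) pairs;
    # part_of_speech_s holds one 0/1 flag per pair (0 later zeroes the score);
    # num_per_row records how many pairs each headword contributed; and
    # double_groups_locate maps each headword index to the [start, end) span
    # of every original pair's expanded variants within its row.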
    for idi, hw in enumerate(hw_list):
        # print(idi, len(hw_list))
        print("is_token:", is_token)
        if is_token and all_hw_tags[idi] in ["d", "u", "ud", "c"]:  # skip adverbs/particles/conjunctions for segmented tokens
            continue
        hw = re.sub(r"^([\u4e00-\u9fa5])\1", r"\1", hw)  # collapse a doubled leading char, e.g. 速率率
        hw = re.sub(r"^([使得有我令对让向将和与]+)…+$", r"\1", hw)
        row_groups = []
        # part_of_speech_s = []
        row_groups_locate = []
        for idj, ans in enumerate(ans_given):
            double_groups = 1
            # print(ans, pos_tag_han(ans), hw, pos_tag_han(hw))
            row_groups.append((hw, ans))
            if cutted_words and len(en.split(" ")) > 1:
                ans0, hw0 = ans, hw
                is_repl = 0
                for j in cutted_words[::-1]:  # segmentation tokens of hw
                    if j + ";" in ans + ";":
                        ans0 = re.sub(j + "$", "", ans0)
                        hw0 = re.sub(j + "$", "", hw0)
                        is_repl = 1
                if is_repl and ans0 and hw0:
                    row_groups.append((hw0, ans0))
                    double_groups += 1
            if "…" in hw or "…" in ans:
                row_groups.append((hw.replace("…", ""), ans.replace("…", "")))
                double_groups += 1
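            # double_groups counts how many (hw, ans) variants this pair has
            # expanded into so far; every branch below must extend
            # part_of_speech_s by exactly that many flags to stay aligned
            # with row_groups.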
            if re.sub(r"(.+)的$", r"\1", ans) == hw and all_hw_tags[idi] != "a":
                # ans is hw plus a trailing "的" while hw is not an adjective
                part_of_speech_s.extend([0] * double_groups)
            elif re.search(r"(.+)的$", ans) and re.search(r"(.+)的$", hw) is None or (
                    re.search(r"(.+)的$", hw) and re.search(r"(.+)的$", ans) is None and all_ans_tags[idj] != 'a'):
                part_of_speech_s.extend([0] * double_groups)
            elif re.search(r".{2,}地$", ans):
                # if pos_tag_han(ans) in ["d", "v"] and pos_tag_han(hw) not in ["d", "v"]:
                if all_ans_tags[idj] in ["d", "v"] and all_hw_tags[idi] not in ["d", "v"]:
                    part_of_speech_s.extend([0] * double_groups)
                elif re.search(r"(.+)地$", ans) is None and (all_hw_tags[idi] == 'a' or all_ans_tags[idj] != 'n'):
                    part_of_speech_s.extend([0] * double_groups)
                elif re.search(r"(.+)的$", hw):
                    part_of_speech_s.extend([0] * double_groups)
                else:
                    part_of_speech_s.extend([1] * double_groups)
            elif ans[-1] in ["地", "的"] and hw[-1] in ["地", "的"] and ans[-1] != hw[-1]:
                part_of_speech_s.extend([0] * double_groups)
            elif ans[-1] not in ["地", "的"] and hw[-1] in ["地", "的"] and ans == hw[:-1]:
                part_of_speech_s.extend([0] * double_groups)
            elif "的人" in ans and all_ans_tags[idj] == 'n' and all_hw_tags[idi] == 'a' \
                    or ("的人" in hw and all_hw_tags[idi] == 'n' and all_ans_tags[idj] == 'a'):
                part_of_speech_s.extend([0] * double_groups)
            elif "…" + hw + ";" in ans + ";" and phrase_classify(en) in ["prep-phrase", "v-phrase"]:
                part_of_speech_s.extend([0] * double_groups)
            elif re.search(r"[()()]", hw) is None and re.search(hw + "…+[\u4e00-\u9fa5]", ans):
                part_of_speech_s.extend([0] * double_groups)
            elif all_ans_tags[idj] == 'v' and re.search("^使", ans):
                # also compare against ans without its leading "使"
                row_groups.append((hw, re.sub("^使", "", ans)))
                double_groups += 1
                part_of_speech_s.extend([1] * double_groups)
            else:
                hw_pos = [[all_hw_tags[idi]]]  # hw in dict_tags: hw_pos = [pos([hw])]
                ans_pos = [[all_ans_tags[idj]]]
                if hw not in dict_tags and hw in all_tags_with_str:
                    # if hw not in all_tags_with_str:
                    #     a_pos = pos(hw)
                    #     all_tags_with_str[hw] = a_pos
                    hw_pos.extend(all_tags_with_str[hw])
                if ans not in dict_tags and ans in all_tags_with_str:
                    # ans_pos = [pos([ans])]
                    # if ans not in all_tags_with_str:
                    #     b_pos = pos(ans)
                    #     all_tags_with_str[ans] = b_pos
                    ans_pos.extend(all_tags_with_str[ans])
                # print(hw_pos, ans_pos)
                if all(i not in sum(ans_pos, []) for i in sum(hw_pos, [])):  # no POS in common, e.g. 'a' vs 'v'; logic TBD
                    part_of_speech_s.extend([0] * double_groups)
                else:
                    part_of_speech_s.extend([1] * double_groups)
            if not row_groups_locate:
                row_groups_locate.append([0, double_groups])
            else:
                bef = row_groups_locate[-1][1]
                row_groups_locate.append([bef, bef + double_groups])
        double_groups_locate[idi] = row_groups_locate
        # print("row_groups:", row_groups)
        num_per_row.append(len(row_groups))
        all_groups.extend(row_groups)
        # Originally each hw was scored against its answers as a separate batch:
        # simi_score = sts(row_groups)
        # print("part_of_speech_s:", part_of_speech_s)
        # print("simi_score[{}]:".format(row_groups), simi_score)
        # scores_byrow.append(simi_score)
        # if 0 in part_of_speech_s:
        #     scores_byrow[-1] = list(map(lambda x, y: x * y, scores_byrow[-1], part_of_speech_s))
    # Score all combinations in one call instead
    print("POS tagging took:", time.time() - st5)
    st1 = time.time()
    print(all_groups)
    # print(part_of_speech_s)
    simi_score = sts(all_groups) if all_groups else []  # guard the empty batch
    print("sts scoring took:", time.time() - st1)
    # print(num_per_row)
    # -------------- locally update similarities ---------------------------
    # ------------ swap pairs scoring > 0.9 and rescore them ------------
    groups_09 = [[i, all_groups[i]] for i, j in enumerate(simi_score) if j > 0.9 and part_of_speech_s[i]]
    rejudge_groups = [(s[1][1], s[1][0]) for s in groups_09]
    print("rejudge_groups>0.9:", rejudge_groups)
    if groups_09 and len(rejudge_groups) <= 3 and len(re.findall(r"[a-zA-Z'\-()()]+", en.strip())) == 1:
        simi_score_reversed = sts(rejudge_groups)
        simi_score_le_09 = [i for i, si in enumerate(simi_score_reversed) if si < 0.6]
        if simi_score_le_09:
            for i in simi_score_le_09:
                simi_score[groups_09[i][0]] = simi_score_reversed[i]
    # ------- otherwise swap pairs scoring > 0.8, rescore, and take the max -------
    elif not groups_09:
        groups_08 = [[i, all_groups[i]] for i, j in enumerate(simi_score) if j > 0.8 and part_of_speech_s[i]]
        rejudge_groups = [(s[1][1], s[1][0]) for s in groups_08]
        print("rejudge_groups>0.8:", rejudge_groups)
        simi_score_reversed = sts(rejudge_groups) if rejudge_groups else []  # guard the empty batch
        simi_score_ge_08 = [i for i, si in enumerate(simi_score_reversed) if si > 0.9]
        if simi_score_ge_08:
            for i in simi_score_ge_08:
                simi_score[groups_08[i][0]] = simi_score_reversed[i]
    # ------------------------------------------------------------------------------
    idx = 0
    for n in num_per_row:
        score_a_row = simi_score[idx: idx + n]
        if 0 in part_of_speech_s[idx: idx + n]:
            score_a_row = list(map(lambda x, y: x * y, score_a_row, part_of_speech_s[idx: idx + n]))
        idx += n
        scores_byrow.append(score_a_row)
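    # Worked example of the masking above (hypothetical numbers): with
    # num_per_row = [2], part_of_speech_s = [1, 0] and simi_score = [0.9, 0.8],
    # scores_byrow becomes [[0.9, 0.0]]: a POS mismatch zeroes the similarity.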
    # --------------- restore scores to the original group shape -------------
    if not is_token and len(double_groups_locate) == len(scores_byrow) and all(
            k1 < len(scores_byrow) and v1[-1][1] == len(scores_byrow[k1])
            for k1, v1 in double_groups_locate.items()):
        for idn, row in enumerate(list(double_groups_locate.values())):
            new_row = []
            score_row = scores_byrow[idn]
            if any(r[1] - r[0] > 1 for r in row):
                for r in row:
                    if r[1] - r[0] > 1:
                        # the pair was expanded into variants: keep the best score
                        new_row.append(max(score_row[r[0]: r[1]]))
                    else:
                        new_row.append(score_row[r[0]])
            else:
                new_row = score_row
            scores_byrow_rawshape.append(new_row)
    # print(scores_byrow)
    return scores_byrow, scores_byrow_rawshape
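

# Hedged end-to-end sketch (hypothetical demo helper; "gallery" is an invented
# input). Actual scores depend on the sts model from my_config, and this
# assumes phrase_classify accepts a bare word.
def _demo_han_similarity():
    scores_byrow, scores_rawshape = han_similarity(
        "gallery", ["画廊"], ["美术画廊", "画廊"], [], is_token=0)
    print(scores_byrow)     # one score row per headword, e.g. [[0.8x, 1.0]]
    print(scores_rawshape)  # same rows with expanded variants collapsed to max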


def pos_tag_han(w, flag="by_str"):
    """
    Part-of-speech tagging.
    :return:
    """
    # print(hanlp.pretrained.pos.ALL)  # list all pretrained POS models
    if flag == "by_str":
        return w, pos(w)
    if isinstance(w, str):
        return pos([w])[0]
    elif isinstance(w, list):
        return pos(w)
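

# The call styles above, as a sketch (hypothetical demo helper; the actual
# tags depend on the pos model from my_config):
def _demo_pos_tag_han():
    w, tags = pos_tag_han("草拟")                 # by_str: returns (input, pos(input))
    one_tag = pos_tag_han("草拟", flag="by_list")  # str  -> pos(["草拟"])[0]
    tag_list = pos_tag_han(["草拟", "分解"], flag="by_list")  # list -> pos(list)
    print(w, tags, one_tag, tag_list)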


if __name__ == '__main__':
    t1 = time.time()
    # print(pos_tag_han(['环境友好的'], flag="by_list"))
    simi_score = sts([('画廊', '美术画廊'), ('美术画廊', '画廊')])
    print(simi_score)
    # '使物质分解', '破裂', '分解', '消除', '损坏', '机器或车辆出毛病', '讨论、关系或系统失败'
    # '破除', '感情失控(痛哭起来)', '感情失控', '感情失控痛哭起来',
    # han_similarity(['看图猜一电影名', '看图猜电影'], ['北京到上海的动车票', '上海到北京的动车票'])
    # a = han_similarity(['抛弃'], ['破除', '捣毁', '拆除', '破除障碍或偏见', '破除(障碍或偏见)'])
    # a = sts([('特别地', '尤其'), ('特别地', '特别')])
    # print(a)
    # a = han_similarity("effective measure", ['有效的措施'], ['有效措施'], is_token=0)
    # print(a)
    # print(time.time() - t1)
    a = pos("草拟")
    # print(pos_tag_han("范畴"))
    # a.append(pos_tag("明白"))
    print(a)
    # print(pos_tag(["明白"]))
    # print(pos_tag("明白"))
    # print(pos("知道"))
    # print(pos("明白"))
    # a = pos(["必"])
    # b = pos(["分解"])
    # c = pos(["不是"])
    # print(c)
    # b1 = [['j'], ['v'], ['v'], ['v']]
    # b2 = [['j'], ['v'], ['n']]
    # print(all([True if i not in b1 else False for i in b2]))
    # aa = ['“垃圾', '废弃物无用的东西', '乌七八糟的东西', '垃圾', '废物']
    # pos_tag_han(aa)
    # rrr = pos_tag_han("不想要的", flag="by_list")
    # print(rrr)