tjt 1 year ago
parent
commit
273e484524
2 changed files with 9 additions and 9 deletions
  1. 5 5
      heap_sort.py
  2. 4 4
      info_retrieval.py

+ 5 - 5
heap_sort.py

@@ -23,23 +23,23 @@ def heapify(arr, n, i, dim):
         heapify(arr, n, largest, dim)
 
 # 堆排序
def heap_sort(arr, topk=None, dim=1):
    """Return the first ``topk`` heap roots of ``arr``, in extraction order.

    The list is first arranged into a heap by calling ``heapify`` on every
    non-leaf node, then the root is taken ``topk`` times. The existing
    comments describe the heap as a max-heap, so the result is presumably
    in descending order of the compared key -- confirm against ``heapify``.

    Args:
        arr: List of items to rank. It is modified in place: each extracted
            root is overwritten (not swapped) with the current tail item,
            so ``arr`` must not be relied on after the call.
        topk: Number of items to return. ``None`` or any value larger than
            ``len(arr)`` means "all items".
        dim: Passed through unchanged to ``heapify``; presumably the index
            inside each item that holds the comparison key -- TODO confirm.

    Returns:
        A new list with the extracted items, first extracted first.
    """
    n = len(arr)
    # Build the heap bottom-up: n // 2 - 1 is the last non-leaf node.
    for i in range(n // 2 - 1, -1, -1):
        heapify(arr, n, i, dim)
    # The None test must come first: evaluating `None > n` raises TypeError
    # on Python 3, which made the default call heap_sort(arr) crash.
    if topk is None or topk > n:
        topk = n
    res_list = []
    # Take the root, move the current tail item into the root slot, then
    # re-heapify the shrunken prefix arr[:i].
    for i in range(n - 1, n - 1 - topk, -1):
        res_list.append(arr[0])
        arr[0] = arr[i]
        heapify(arr, i, 0, dim)

    return res_list
 
 if __name__ == "__main__":
     arr = [[0,12], [1,11], [2,13], [3,5], [4,6], [5,7]]

+ 4 - 4
info_retrieval.py

@@ -24,7 +24,7 @@ class Info_Retrieval():
         self.sqlite_path = sqlite_path
         self.sqlite_copy_path = sqlite_copy_path
     
-    def __call__(self, sentence, topK=50):
+    def __call__(self, sentence, topk=50):
         # 将搜索语句进行标准化清洗
         sentence = self.dpp.content_clear_func(sentence)
         # 将搜索语句分词
@@ -53,14 +53,14 @@ class Info_Retrieval():
                 recall_doc_list.extend(doc_id_list)
         # 计算召回文档列表中出现频率最高的文档集合
         recall_doc_set = set([ele[0] for ele in Counter(recall_doc_list).most_common(500)])
-        # return [ele[0] for ele in Counter(recall_doc_list).most_common(topK)], seg_list
+        # return [ele[0] for ele in Counter(recall_doc_list).most_common(topk)], seg_list
         # 遍历分词元素数据列表
         for term_data in term_data_list:
             # bm25算法计算倒排索引关键词对应文档得分
             scores_dict = self.BM25(scores_dict, term_data, doc_param, recall_doc_set)
         # 将分数词典进行排序并返回排序结果
-        # scores_list = heap_sort(list(scores_dict.items()), topK)
-        scores_list = sorted(list(scores_dict.items()), key=lambda x:x[1], reverse=True)[:topK]
+        # scores_list = heap_sort(list(scores_dict.items()), topk)
+        scores_list = sorted(list(scores_dict.items()), key=lambda x:x[1], reverse=True)[:topk]
         # 对检索结果进行判断并取出排序后的id
         scores_sort_list = [ele[0] for ele in scores_list]