123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132 |
- import os
- import time
- import math
- from collections import Counter
- import shutil
- import sqlite3
- from config import sqlite_path, sqlite_copy_path, log_msg
- from data_preprocessing import DataPreProcessing
- from word_segment import Word_Segment
- from heap_sort import heap_sort
- class Info_Retrieval():
- def __init__(self, data_process, logger=None, n_grams_flag=False):
- # 数据预处理实例化
- self.dpp = data_process
- # 日志配置
- self.logger = logger
- # 分词算法
- self.word_seg = Word_Segment(n_grams_flag=n_grams_flag)
- # bm25算法参数
- self.k1, self.b = 1.5, 0.75
- # sqlite数据库连接
- self.sqlite_path = sqlite_path
- self.sqlite_copy_path = sqlite_copy_path
-
- def __call__(self, sentence, topk=50):
- # 将搜索语句进行标准化清洗
- sentence = self.dpp.content_clear_func(sentence)
- # 将搜索语句分词
- seg_list, seg_init_list = self.word_seg(sentence)
- # 日志采集
- self.logger.info(log_msg.format(id="文本查重",
- type="info_retrieve分词",
- message=seg_list)) if self.logger else None
- # bm25算法分数计算词典
- scores_dict = dict()
- # "doc_data_statistics"存储文档总数和文档总长度
- doc_data = self.sqlite_fetch("doc_data_statistics")
- # 获取文档总数与平均文档长度
- doc_param = [int(doc_data[1]), int(doc_data[2]) / int(doc_data[1])]
- # 初始化召回文档列表和检索数据列表
- recall_doc_list, term_data_list = [], []
- # 遍历分词元素用于搜索相关数据
- for word in set(seg_list):
- # 通过sqlite数据库检索term相关数据
- term_data = self.sqlite_fetch(word)
- if term_data is not None:
- # 获取每个分词元素对应的文档列表和数据列表
- doc_id_list = str(term_data[1]).split('\n')
- term_data_list.append((doc_id_list, term_data[2]))
- recall_doc_list.extend(doc_id_list)
- # 计算召回文档列表中出现频率最高的文档集合
- recall_doc_set = set([ele[0] for ele in Counter(recall_doc_list).most_common(500)])
- # return [ele[0] for ele in Counter(recall_doc_list).most_common(topk)], seg_list
- # 遍历分词元素数据列表
- for term_data in term_data_list:
- # bm25算法计算倒排索引关键词对应文档得分
- scores_dict = self.BM25(scores_dict, term_data, doc_param, recall_doc_set)
- # 将分数词典进行排序并返回排序结果
- # scores_list = heap_sort(list(scores_dict.items()), topk)
- scores_list = sorted(list(scores_dict.items()), key=lambda x:x[1], reverse=True)[:topk]
- # 对检索结果进行判断并取出排序后的id
- scores_sort_list = [ele[0] for ele in scores_list]
-
- return scores_sort_list, seg_init_list
-
- # bm25算法计算倒排索引关键词对应文档得分
- def BM25(self, scores_dict, term_data, doc_param, recall_doc_set):
- # 获取文档总数与平均文档长度
- doc_count, avg_doc_length = doc_param[0], doc_param[1]
- # 获取当前term对应的文档数量
- doc_id_list = term_data[0]
- doc_freq = len(doc_id_list)
- # 计算idf值
- idf_weight = math.log((doc_count - doc_freq + 0.5) / (doc_freq + 0.5) + 1)
- term_docs = term_data[1].split('\n')
- # 遍历计算当前term对应的所有文档的得分
- for i,doc in enumerate(term_docs):
- # 获取文档id并判断是否在召回文档集合中
- doc_id = doc_id_list[i]
- if doc_id not in recall_doc_set:
- continue
- # 获取词频以及当前文档长度
- tf, dl = map(int, doc.split('\t'))
- # 计算bm25得分(idf值必须大于0)
- score = (idf_weight * self.k1 * tf) / (tf + self.k1 * (1 - self.b + self.b * dl / avg_doc_length))
- # 通过词典锁定文档id累加计分
- scores_dict[doc_id] = scores_dict.get(doc_id, 0) + score
- return scores_dict
-
- # 通过sqlite数据库检索term相关数据
- def sqlite_fetch(self, term):
- try:
- # 建立sqlite数据库链接
- sqlite_conn = sqlite3.connect(self.sqlite_path)
- # 创建游标对象cursor
- cursor = sqlite_conn.cursor()
- # 从sqlite读取数据
- cursor.execute("SELECT * FROM physics WHERE term=?", (term,))
- term_data = cursor.fetchone()
- # 关闭数据库连接
- cursor.close()
- sqlite_conn.close()
- return term_data
- except Exception as e:
- # 关闭已损坏失效的sqlite数据库连接
- cursor.close()
- sqlite_conn.close()
- # 删除损害失效的sqlite数据库
- if os.path.exists(self.sqlite_path):
- os.remove(self.sqlite_path)
- # 复制并重命名已备份sqlite数据库
- shutil.copy(self.sqlite_copy_path, self.sqlite_path)
- return self.sqlite_fetch(term)
- if __name__ == "__main__":
- sent_list = ["中华民族传统文化源远流长!下列诗句能体现“分子在不停地做无规则运动”的是( )<br/>A.大风起兮云飞扬<br/>B.柳絮飞时花满城<br/>C.满架蔷薇一院香<br/>D.秋雨梧桐叶落时"]
- data_process = DataPreProcessing()
- ir = Info_Retrieval(data_process, n_grams_flag=True)
- start = time.time()
- for sentence in sent_list:
- scores_sort_list = ir(sentence)
- print(time.time()-start)
|