import os
import time
import math
from collections import Counter
import shutil
import sqlite3
from config import sqlite_path, sqlite_copy_path, log_msg
from data_preprocessing import DataPreProcessing
from word_segment import Word_Segment
from heap_sort import heap_sort
class Info_Retrieval:
    """BM25-based information retrieval over a SQLite inverted index.

    A query sentence is cleaned, segmented into terms, each term's posting
    list is fetched from SQLite, and the recalled documents are ranked with
    the BM25 scoring function.
    """

    def __init__(self, data_process, logger=None, n_grams_flag=False):
        # Data pre-processing helper (provides content_clear_func).
        self.dpp = data_process
        # Optional logger; when None, logging calls are skipped.
        self.logger = logger
        # Word segmentation algorithm (optionally n-gram based).
        self.word_seg = Word_Segment(n_grams_flag=n_grams_flag)
        # BM25 hyper-parameters (term-frequency saturation / length norm).
        self.k1, self.b = 1.5, 0.75
        # SQLite database path and its backup copy used for recovery.
        self.sqlite_path = sqlite_path
        self.sqlite_copy_path = sqlite_copy_path

    def __call__(self, sentence, topk=50):
        """Retrieve the ids of the top *topk* documents matching *sentence*.

        Returns a tuple ``(scores_sort_list, seg_init_list)``: the ranked
        document ids and the raw segmentation of the cleaned sentence.
        """
        # Normalize / clean the query text.
        sentence = self.dpp.content_clear_func(sentence)
        # Segment the query into terms.
        seg_list, seg_init_list = self.word_seg(sentence)
        # Log the segmentation result when a logger is configured.
        if self.logger:
            self.logger.info(log_msg.format(id="文本查重",
                                            type="info_retrieve分词",
                                            message=seg_list))
        # Per-document BM25 score accumulator.
        scores_dict = dict()
        # The "doc_data_statistics" row stores corpus size and total length.
        doc_data = self.sqlite_fetch("doc_data_statistics")
        # doc_param = [document count, average document length].
        doc_param = [int(doc_data[1]), int(doc_data[2]) / int(doc_data[1])]
        # Recalled doc ids (with repetitions) and per-term posting data.
        recall_doc_list, term_data_list = [], []
        # Fetch the posting list for each distinct query term.
        for word in set(seg_list):
            term_data = self.sqlite_fetch(word)
            if term_data is not None:
                # Column 1: newline-separated doc ids;
                # column 2: matching "tf\tdoc_length" entries.
                doc_id_list = str(term_data[1]).split('\n')
                term_data_list.append((doc_id_list, term_data[2]))
                recall_doc_list.extend(doc_id_list)
        # Keep only the 500 most frequently recalled documents for scoring.
        recall_doc_set = set([ele[0] for ele in Counter(recall_doc_list).most_common(500)])
        # Accumulate BM25 contributions term by term.
        for term_data in term_data_list:
            scores_dict = self.BM25(scores_dict, term_data, doc_param, recall_doc_set)
        # Sort by score (descending) and keep the top-k entries.
        scores_list = sorted(list(scores_dict.items()), key=lambda x: x[1], reverse=True)[:topk]
        # Extract the document ids from the sorted (id, score) pairs.
        scores_sort_list = [ele[0] for ele in scores_list]
        return scores_sort_list, seg_init_list

    def BM25(self, scores_dict, term_data, doc_param, recall_doc_set):
        """Add one term's BM25 contribution to *scores_dict* and return it.

        *term_data* is ``(doc_id_list, postings_string)`` where the postings
        string holds one "tf\\tdoc_length" entry per line, aligned with
        *doc_id_list*. *doc_param* is ``[doc_count, avg_doc_length]``.
        Only documents present in *recall_doc_set* are scored.
        """
        # Corpus-level statistics.
        doc_count, avg_doc_length = doc_param[0], doc_param[1]
        # Number of documents containing this term (document frequency).
        doc_id_list = term_data[0]
        doc_freq = len(doc_id_list)
        # Smoothed idf; the trailing "+ 1" keeps the weight positive.
        idf_weight = math.log((doc_count - doc_freq + 0.5) / (doc_freq + 0.5) + 1)
        term_docs = term_data[1].split('\n')
        # Score every posting of this term.
        for i, doc in enumerate(term_docs):
            doc_id = doc_id_list[i]
            # Skip documents that did not survive the recall cut.
            if doc_id not in recall_doc_set:
                continue
            # Each posting entry is "term_frequency\tdocument_length".
            tf, dl = map(int, doc.split('\t'))
            # BM25 term score: idf * saturated tf with length normalization.
            score = (idf_weight * self.k1 * tf) / (tf + self.k1 * (1 - self.b + self.b * dl / avg_doc_length))
            # Accumulate the score under the document id.
            scores_dict[doc_id] = scores_dict.get(doc_id, 0) + score
        return scores_dict

    def sqlite_fetch(self, term):
        """Fetch the row for *term* from the "physics" table.

        On any failure the (presumably corrupted) database file is replaced
        with the backup copy and the lookup is retried.
        """
        try:
            sqlite_conn = sqlite3.connect(self.sqlite_path)
            try:
                cursor = sqlite_conn.cursor()
                try:
                    # Parameterized query avoids SQL injection via *term*.
                    cursor.execute("SELECT * FROM physics WHERE term=?", (term,))
                    return cursor.fetchone()
                finally:
                    cursor.close()
            finally:
                sqlite_conn.close()
        except Exception:
            # BUGFIX: the original except-block closed `cursor`/`sqlite_conn`
            # unconditionally, raising NameError when connect() itself failed;
            # the nested try/finally above now guarantees cleanup instead.
            # Remove the damaged database file, if it exists.
            if os.path.exists(self.sqlite_path):
                os.remove(self.sqlite_path)
            # Restore from the backup copy and retry.
            # NOTE(review): this recurses without a retry limit; if the backup
            # is also unusable this loops forever — consider capping retries.
            shutil.copy(self.sqlite_copy_path, self.sqlite_path)
            return self.sqlite_fetch(term)
if __name__ == "__main__":
    # Demo query: a multi-line physics multiple-choice question.
    # BUGFIX: the original literal spanned several physical lines inside
    # plain double quotes, which is a SyntaxError in Python; it is
    # reconstructed here as a triple-quoted string with identical content.
    sent_list = ["""中华民族传统文化源远流长!下列诗句能体现“分子在不停地做无规则运动”的是( )
A.大风起兮云飞扬
B.柳絮飞时花满城
C.满架蔷薇一院香
D.秋雨梧桐叶落时"""]
    data_process = DataPreProcessing()
    ir = Info_Retrieval(data_process, n_grams_flag=True)
    # Time the retrieval over all sample sentences.
    start = time.time()
    for sentence in sent_list:
        scores_sort_list = ir(sentence)
    print(time.time() - start)