tujintao 10 місяців тому
батько
коміт
9bf55b4ab6
2 змінених файлів з 0 додано та 89 видалено
  1. 0 40
      bert_whitening.py
  2. 0 49
      heap_sort.py

+ 0 - 40
bert_whitening.py

@@ -1,40 +0,0 @@
-import pickle
-import numpy as np
-
-from config import mongo_coll, whitening_path
-
-# 计算bert-whitening参数
-def compute_kernel_bias(vec_list):
-    """计算kernel和bias
-    vecs.shape = [num_samples, embedding_size]
-    最后的变换: y = (x + bias).dot(kernel)
-    """
-    vecs = np.array(vec_list)
-    mu = vecs.mean(axis=0, keepdims=True)
-    cov = np.cov(vecs.T)
-    u, s, vh = np.linalg.svd(cov)
-    W = np.dot(u, np.diag(1 / np.sqrt(s)))
-    return W, -mu
-
-# 计算句向量列表
-def compute_vec_list(origin_dataset):
-    vec_list = []
-    for data in origin_dataset:
-        if "sentence_vec" not in data:
-            continue
-        sentence_vec = pickle.loads(data["sentence_vec"])
-        if sentence_vec.size != 384:
-            continue
-        vec_list.append(sentence_vec)
-    
-    return vec_list
-
-if __name__ == "__main__":
-    origin_dataset = mongo_coll.find(no_cursor_timeout=True, batch_size=5)
-    # 计算句向量列表
-    vec_list = compute_vec_list(origin_dataset)
-    # 计算bert-whitening参数
-    kernel, bias = compute_kernel_bias(vec_list)
-    # 保存bert-whitening参数
-    with open(whitening_path, 'wb') as f:
-        pickle.dump([kernel, bias], f)

+ 0 - 49
heap_sort.py

@@ -1,49 +0,0 @@
-"""
-在海量数据中选取topK个数据:
-整个操作中,遍历数组需要O(n)的时间复杂度,每次调整小顶堆的时间复杂度是O(logK),加起来就是O(nlogK)的复杂度;
-如果K远小于n的话,O(nlogK)其实就接近于O(n),甚至会更快,因此也是十分高效的.
-"""
-
-# 构建大顶堆
-def heapify(arr, n, i, dim):
-    largest = i
-    left = 2*i + 1
-    right = 2*i + 2
-    # 与左节点进行比较
-    if left < n and arr[i][dim] < arr[left][dim]:
-        largest = left
-    # 与左节点比较后再与右节点进行比较
-    if right < n and arr[largest][dim] < arr[right][dim]:
-        largest = right
-    # 通过上面跟左右节点比较后,得出三个元素之间较大的下标
-    # 若较大下标不是父节点的下标,说明交换后需要重新调整大顶堆
-    if largest != i:
-        arr[i], arr[largest] = arr[largest], arr[i]
-        # 重新调整大顶堆
-        heapify(arr, n, largest, dim)
-
-# 堆排序
-def heap_sort(arr, topk=None, dim=1): 
-    n = len(arr) 
-    # 构造大顶堆,从非叶子节点开始倒序遍历,因此是l//2 -1 就是最后一个非叶子节点
-    for i in range(n//2-1, -1, -1):
-        heapify(arr, n, i, dim)
-    # 若topk不为None,则进行设定
-    topk = n if topk > n or topk is None else topk
-    res_list = []
-    # 上面的循环完成了大顶堆的构造,那么就开始把根节点跟末尾节点交换,然后重新调整大顶堆
-    for i in range(n-1, n-1-topk, -1):
-        res_list.append(arr[0])
-        arr[0] = arr[i]
-        # arr[i], arr[0] = arr[0], arr[i]
-        heapify(arr, i, 0, dim)
-    
-    return res_list
-    # return arr[n-topk:][::-1]
-
-if __name__ == "__main__":
-    arr = [[0,12], [1,11], [2,13], [3,5], [4,6], [5,7]]
-    arr = heap_sort(arr, 3)
-    print ("排序后")
-    for i in range(len(arr)):
-        print(arr[i])