Annoy最近鄰檢索技術之 “圖片檢索”

本文主要介紹一下NN檢索方式Annoy(Approximate Nearest Neighbors Oh Yeah)的應用,在前幾篇的召回文章中(1.推薦系統召回模型之YouTubeNet;2. 推薦系統召回模型之MIND用戶多興趣網絡實踐)都涉及這個技術點,一直沒有詳細的介紹。本文以圖片檢索爲應用場景,介紹一下Annoy。

1

Annoy算法原理

Annoy是Python的一個模塊,用於空間檢索近鄰的數據。檢索過程分成三步:

  • 建立索引過程

  • 近鄰查詢過程;

  • 返回最終近鄰節點;

首先先來一張2D數據分佈圖:

接下來按照步驟1,2和3進行分析。

1.1 建立索引過程

Annoy的目標是建立一個數據結構,使得查詢一個點的最近鄰點的時間複雜度是次線性。Annoy 通過建立一個二叉樹來使得每個點查找時間複雜度是O(log n)。以下圖爲例,隨機選擇兩個點,以這兩個節點爲初始中心節點,執行聚類數爲2的kmeans過程,最終產生收斂後兩個聚類中心點。這兩個聚類中心點之間連一條線段(灰色短線),建立一條垂直於這條灰線,並且通過灰線中心點的線(黑色粗線)。這條黑色粗線把數據空間分成兩部分。在多維空間的話,這條黑色粗線可以看成等距垂直超平面。

在劃分的子空間內進行不停的遞歸迭代繼續劃分,直到每個子空間最多隻剩下K個數據節點。

通過多次遞歸迭代劃分的話,最終原始數據會形成類似下圖的二叉樹結構。二叉樹底層是葉子節點記錄原始數據節點,其他中間節點記錄的是分割超平面的信息。Annoy建立這樣的二叉樹結構是希望滿足這樣的一個假設:  相似的數據節點應該在二叉樹上位置更接近,一個分割超平面不應該把相似的數據節點分割二叉樹的不同分支上。

根據上述步驟,建立多棵二叉樹樹,構成一個森林。

1.2 近鄰查詢過程

上面已完成節點索引建立過程。如何進行對一個數據點進行查找相似節點集合呢?比如下圖的紅色節點,查找的過程就是不斷看他在分割超平面的哪一邊。從二叉樹索引結構來看,就是從根節點不停的往葉子節點遍歷的過程。通過對二叉樹每個中間節點(分割超平面相關信息)和查詢數據節點進行相關計算來確定二叉樹遍歷過程是往這個中間節點左孩子節點走還是右孩子節點走。通過以上方式完成查詢過程。

查詢過程採用優先隊列機制:採用一個優先隊列來遍歷二叉樹,從根節點往下的路徑,根據查詢節點與當前分割超平面距離進行排序。

1.3 返回最終近鄰節

步驟1會構建多棵二叉樹樹,每棵樹都返回一堆近鄰點後,如何得到最終的Top N相似集合呢?首先所有樹返回近鄰點都插入到優先隊列中,求並集去重, 然後計算和查詢點距離,最終根據距離值從近距離到遠距離排序,返回Top-N近鄰節點集合。

2

圖片檢索實踐

先放一張本文檢索的效果圖:

檢索結果:最相似的 Top-9張商品圖片如下所示:

技術步驟:

  • 下載一批商品圖片,本文使用的商品圖片來源於某電商商城;

  • 下載vgg16模型;

  • 使用vgg16模型提取圖片特徵;

  • 使用Annoy技術對圖片特徵數據構建索引,及建樹;

  • 輸入一張圖片特徵數據,檢索並返回最相似的Top-9張圖片;

2.1 下載一批商品圖片

本文使用的圖片數據來源於某電商商城,下載了30個種類的圖片數據,共計5130張。下載代碼如下:

# encoding="utf-8"
from requests_html import HTMLSession
import re
import os
import time

sku_eng_list = ["Mobile-phone", "T-shirt", "Milk", "Mask", "Headset", \
"Wine", "Helmet", "Fan", "Sneaker", "Cup", \
"Glasses", "Backpack", "UAV", "Sofa", "Bicycle", \
"Cleanser", "Paper", "Bread", "Sausage", "Toilet", \
"Book", "Tire", "Clock", "Mango", "Shrimp", \
"Stroller", "Necklace", "Baby-bottle", "Yuba", "Pot"]


session = HTMLSession()

for inx, key in enumerate(["手機", "T恤", "牛奶", "口罩", "耳機", \
    "酒", "頭盔", "風扇", "運動鞋", "杯子", \
    "眼鏡", "揹包", "無人機", "沙發", "自行車", \
    "洗面奶", "抽紙", "麪包", "香腸", "馬桶", \
    "書", "輪胎", "鐘錶", "芒果", "蝦", \
    "童車", "項鍊", "奶瓶", "浴霸", "鍋"]):

    for j in range(1, 10):
        
        time.sleep(2)
        url = 'https://search.jd.com/Search?keyword=%s&wq=%s&page=%s&s=90&click=0' % \
            (key, key, str(j))

        r = session.get(url)

        for i in range(1, 20):
            try:
                contain_pic_url = str(r.html.find('#J_goodsList > ul > li:nth-child('+str(i)+') > div > div > div.gl-i-tab-content > div.tab-content-item.tab-cnt-i-selected > div.p-img > a > img'))
                src_start = re.search('src',contain_pic_url).end() + 2
                src_end = int(re.search("'",contain_pic_url[src_start:]).start())
                pic_url = 'https:'+contain_pic_url[src_start:src_start + src_end]

                os.chdir('C:\\Users\\Desktop\\figures')
                pic = session.get(pic_url)
                open(sku_eng_list[inx]+'_page_'+str(j)+'_NO_'+str(i)+'.jpg','wb').write(pic.content)

            except:
                try:
                    contain_pic_url = str(r.html.find('#J_goodsList > ul > li:nth-child('+str(i)+') > div > div.p-img > a > img'))
                    src_start = re.search('src',contain_pic_url).end() + 2
                    src_end = int(re.search("'",contain_pic_url[src_start:]).start())
                    pic_url = 'https:'+contain_pic_url[src_start:src_start + src_end]

                    os.chdir('C:\\Users\\Desktop\\figures')
                    pic = session.get(pic_url)
                    open(sku_eng_list[inx]+'_page_'+str(j)+'_NO_'+str(i)+'.jpg','wb').write(pic.content)

                except:
                    pass

    print("Download %s done !!!" % sku_eng_list[inx])

2.2 下載vgg16模型

官方下載地址

https://github.com/fchollet/deep-learning-models/releases/download/v0.1/vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5

如果無法在官方地址下載,可以從百度網盤中下載:

鏈接:https://pan.baidu.com/s/1Exa8g_q9hVmqOU9SBrIxrg
提取碼:qtsb

將 vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5 模型放入 ~/.keras/models 路徑中即可。

2.3 使用vgg16模型提取圖片特徵

該版本的vgg16模型可以將圖片轉化爲維度爲 [7, 7, 512] 的浮點型數據,將該數據“壓平”保存。

from tensorflow.keras.applications.vgg16 import VGG16
from tensorflow.keras.preprocessing import image
from tensorflow.keras.applications.vgg16 import preprocess_input
import numpy as np
from tqdm import tqdm
import pickle
import os

# 1. 加載vgg16模型
model = VGG16(weights='imagenet', include_top=False)
#print(model.summary())


# 2. 提取圖片特徵
img_path = "figures/"

img_name_list = []
img_feature_list = []

for file in tqdm(os.listdir(img_path)):
    img_name_list.append(file)
    file_path = img_path + file
    
    img = image.load_img(file_path, target_size=(224, 224))
    x = image.img_to_array(img)
    x = np.expand_dims(x, axis=0)
    x = preprocess_input(x)

    features = model.predict(x)
    img_feature_list.append(features.reshape((7*7*512,)))


# 3. 將圖片名稱和圖片特徵保存爲pkl格式
f = open("img_feature_list.pkl", 'wb')
pickle.dump(img_feature_list, f)
f.close()

g = open("img_name_list.pkl", 'wb')
pickle.dump(img_name_list, g)
g.close()

此時將得到 img_name_list.pkl 和 img_feature_list.pkl 兩個文件,分別保存圖片的名稱和圖片的特徵。

2.4 使用Annoy技術對圖片特徵數據構建索引,及建樹

建索引及建樹腳本如下所示:

# encoding:utf-8
from annoy import AnnoyIndex
import pickle
import numpy as np
np.random.seed(20200601)
import sys, time
from tqdm import tqdm

def build_ann(name_path=None, \
       vec_path=None, \
       index_to_name_dict_path=None, \
       ann_path=None, \
       dim=64, \
       n_trees=10):

    name_path = open(name_path, 'rb')
    vec_path = open(vec_path, 'rb')
    img_name_list = pickle.load(name_path)
    img_vec_list = pickle.load(vec_path)
    

    ann = AnnoyIndex(dim)
    idx = 0
    batch_size = 100 * 10000
    index_to_name_dict = {}
    

    for name, vec in tqdm(zip(img_name_list, img_vec_list)):
        ann.add_item(idx, vec)
        index_to_name_dict[idx] = name

        idx += 1
        if idx % batch_size == 0:
            print("%s00w" % (int(idx/batch_size)))

    print("Add items Done!\nStart building trees")

    ann.build(n_trees)
    print("Build Trees Done!")
    
    ann.save(ann_path)
    print("Save ann to %s Done!" % (ann_path))

    fd = open(index_to_name_dict_path, 'wb')
    pickle.dump(index_to_name_dict, fd)
    fd.close()
    print("Saving index_to_name mapping Done!")


if __name__ == '__main__':
    name_path = "img_name_list.pkl"
    vec_path = "img_feature_list.pkl"
    index_to_name_dict_path = "index_to_name_dict.pkl"
    ann_path = "img_feature_list.ann"
    dim = 25088
    n_trees = 10

    build_ann(name_path=name_path, \
        vec_path=vec_path, \
        index_to_name_dict_path=index_to_name_dict_path, \
        ann_path=ann_path, \
        dim=dim, \
        n_trees=n_trees)

本實驗構建了10棵二叉樹,此時將得到 index_to_name_dict.pkl 和 img_feature_list.ann 兩個文件,分別保存圖片索引Id與名稱的映射數據,和圖片特徵的二叉樹結構信息。

2.5 輸入一張圖片特徵數據,檢索並返回最相似的Top-9張圖片

話不多說,代碼如下:

# encoding:utf-8
from annoy import AnnoyIndex
import numpy as np
np.random.seed(20200601)
import pickle
import sys
from matplotlib import image as mpimg
from matplotlib import pyplot as plt

def load_ann(ann_path=None, index_to_name_dict_path=None, dim=64):
    ann = AnnoyIndex(dim)
    ann.load(ann_path)

    with open(index_to_name_dict_path, 'rb') as f:
        index_to_name_dict = pickle.load(f)
    return ann, index_to_name_dict


def query_ann(ann=None, index_to_name_dict=None, query_vec=None, topN=None):
    topN_item_idx_list = ann.get_nns_by_vector(query_vec, topN)

    topN_item_id_list = []

    for idx in topN_item_idx_list:
        item_id = index_to_name_dict[idx]
        topN_item_id_list.append(item_id)

    return topN_item_id_list


if __name__ == '__main__':
    index_to_name_dict_path = "index_to_name_dict.pkl"
    ann_path = "img_feature_list.ann"
    name_path = "img_name_list.pkl"
    vec_path = "img_feature_list.pkl"
    dim = 25088
    topN = 9
    
    name_path = open(name_path, 'rb')
    vec_path = open(vec_path, 'rb')
    img_name_list = pickle.load(name_path)
    img_vec_list = pickle.load(vec_path)
    
    idx = 126
    query_name = img_name_list[idx]
    query_vec = img_vec_list[idx]
    
    ann, index_to_name_dict = load_ann(ann_path=ann_path, \
        index_to_name_dict_path=index_to_name_dict_path, \
        dim=dim)

    topN_item_list = query_ann(ann=ann, \
        index_to_name_dict=index_to_name_dict, \
        query_vec=query_vec, \
        topN=topN)

    # query 商品圖片
    print("query_image: \n")
    fig, axes = plt.subplots(1, 1)
    query_image = mpimg.imread("figures/" + query_name)
    axes.imshow(query_image/255)
    axes.axis('off')
    axes.axis('off')
    axes.set_title('%s' % query_name, fontsize=8, color='r')

    # Top-9 相似商品
    fig, axes = plt.subplots(3, 3)
    for idx, img_path in enumerate(topN_item_list):

        i = idx % 3   # Get subplot row
        j = idx // 3  # Get subplot column
        image = mpimg.imread("figures/" + img_path)
        axes[i, j].imshow(image/255)
        axes[i, j].axis('off')
        axes[i, j].axis('off')

        axes[i, j].set_title('%s' % img_path, fontsize=8, color='b')

本實驗以idx=126爲例進行測試,idx取值範圍爲[0, 5129]。

參考:

https://github.com/spotify/annoy

https://blog.csdn.net/hero_fantao/article/details/70245387

歡迎關注 “python科技園” 及 添加小編 進羣交流。

文章好看點這裏

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章