I'm trying to write a program that, given a query image, identifies the similar images in a directory containing thousands of images. The query image is usually similar to, but often slightly different from, the images in the directory. This question is related to Simple and fast method to compare images for similarity.
I have a few goals:
- Using a query image, identify similar images in a directory of images
- The query image might be slightly changed relative to the images in the directory; for example, it might be cropped or have a different image quality.
- The program should be pretty fast (able to identify similar images in a few seconds at most)
I know this is a problem with a lot of research behind it. The chapter "Building a Reverse Image Search Engine: Understanding Embeddings" from "Practical Deep Learning for Cloud, Mobile, and Edge" explains some approaches to it.
I began writing a program to do this using a SIFT (scale-invariant feature transform) + bag of words approach; I don't have much experience in this area. The program works for an identical image, and reasonably well for a slightly modified image, but once the query becomes more dissimilar, it no longer finds the right image.
I have two questions:
- Is the approach I'm using for this good, and if not, is there a better approach?
- Is there anything in my program that might be causing the searches to be inaccurate for dissimilar images?
This is how the program works:
- Go through every image, get its descriptors with SIFT, and build a list of these descriptors.
- Using k-means, find the centroids of the list of descriptors. This is the "dictionary".
- Go through every image again and match its descriptors against the centroids with knnMatch and k=1. Use each match's trainIdx to build a histogram of visual words for that image (there's a toy illustration of this and the normalization step after this list).
- Normalize each image's histogram by dividing the count of each "word" by the sum of the "words".
- Use knnMatch with k=1 with the query image's descriptors and the centroids. Go through the matches and create a normalized histogram.
- Use knnMatch with k=1 on the query image's histogram, and the histograms of all of the images in the database. This creates a list of matches, ordered by similarity to the query image.
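To make the histogram-building and normalization steps concrete, here is a toy illustration with made-up data; the real program uses clusters = 800 and actual 128-dimensional SIFT descriptors:

import numpy as np
import cv2

# Toy data: a 4-word "dictionary" and 10 fake descriptors for one image.
clusters = 4
centroids = np.random.rand(clusters, 128).astype(np.float32)
descriptors = np.random.rand(10, 128).astype(np.float32)

bf = cv2.BFMatcher()
counter = np.zeros(clusters, dtype=np.float32)
for match in bf.knnMatch(descriptors, centroids, k=1):
    counter[match[0].trainIdx] += 1  # trainIdx = index of the nearest centroid

# Normalize so images with different numbers of descriptors are comparable,
# e.g. counts of [3, 1, 2, 4] become [0.3, 0.1, 0.2, 0.4].
histogram = counter / counter.sum()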
import numpy as np
import cv2
import os
from matplotlib import pyplot as plt
sift = cv2.xfeatures2d.SIFT_create()
FLANN_INDEX_KDTREE = 0
index_params = dict(algorithm = FLANN_INDEX_KDTREE, trees = 100)
search_params = dict(checks = 100)
flann = cv2.FlannBasedMatcher(index_params, search_params)
bf = cv2.BFMatcher()
img1 = cv2.imread('path', 0)  # query image (placeholder path), loaded as grayscale
db = # load database
kp1, des1 = sift.detectAndCompute(img1, None)  # SIFT descriptors of the query image
load = False  # set to True to rebuild the dictionary and the histograms
clusters = 800  # size of the visual-word dictionary
if load:
    db.query('DELETE FROM centroids')
    db.query('DELETE FROM histogram')

    # Collect the SIFT descriptors of every image in the directory.
    descriptors = []
    for file in os.listdir('path'):
        if file.endswith('.png'):
            img = cv2.imread('path/{}'.format(file), 0)
            kp, des = sift.detectAndCompute(img, None)
            if des is None:
                continue
            descriptors.extend(des)

    # Cluster all descriptors with k-means; the centroids are the "dictionary".
    descriptors = np.float32(descriptors)
    criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 5, .01)
    centroids = cv2.kmeans(descriptors, clusters, None, criteria, 1, cv2.KMEANS_PP_CENTERS)[2]
    db.insert('centroids', d = np.ndarray.dumps(centroids))

    # Quantize each image's descriptors against the dictionary and store a
    # normalized histogram of visual words per image.
    for file in os.listdir('path'):
        counter = np.zeros((clusters,), dtype=np.uint32)
        if file.endswith('.png'):
            img = cv2.imread('path/{}'.format(file), 0)
            kp, d = sift.detectAndCompute(img, None)
            if d is None:
                continue
            matches = bf.knnMatch(d, centroids, k=1)
            for match in matches:
                counter[match[0].trainIdx] += 1
            counter_sum = np.sum(counter)
            counter = [float(n)/counter_sum for n in counter]
            db.insert('histogram', frame_id = file, count=','.join(np.char.mod('%f', counter)))
# Load the stored histograms and the dictionary back from the database.
histograms_db = list(db.query('SELECT * FROM histogram'))
histograms = []
for histogram in histograms_db:
    histogram = histogram['count'].split(',')
    histograms.append(histogram)
histograms = np.array(histograms)

# Build the query image's normalized histogram of visual words.
counter = np.zeros((clusters,), dtype=np.uint32)
centroids = np.loads(db.query('SELECT * FROM centroids')[0]['d'])
matches = bf.knnMatch(des1, centroids, k=1)
for match in matches:
    counter[match[0].trainIdx] += 1
counter_sum = np.sum(counter)
counter = [float(n)/counter_sum for n in counter]

# Find the database histogram closest to the query histogram.
matches = bf.knnMatch(np.float32([counter]), np.float32(histograms), k=1)
for match in matches[0]:
    print("{} {}".format(histograms_db[match.trainIdx]['frame_id'], match.distance))
    name = histograms_db[match.trainIdx]['frame_id']
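In case the database layer above obscures the flow, here is a condensed, self-contained sketch of the same pipeline (not a drop-in replacement). The directory path, the query image path, and the image_histogram helper are placeholders I made up for illustration, and it assumes an OpenCV build where SIFT is available:

import os
import numpy as np
import cv2

IMAGE_DIR = 'path'   # directory of .png database images (placeholder)
CLUSTERS = 800       # size of the visual-word dictionary

sift = cv2.xfeatures2d.SIFT_create()  # cv2.SIFT_create() on OpenCV >= 4.4
bf = cv2.BFMatcher()

def image_histogram(descriptors, centroids):
    # Quantize one image's SIFT descriptors against the dictionary and
    # return its normalized bag-of-words histogram.
    counts = np.zeros(len(centroids), dtype=np.float32)
    for match in bf.knnMatch(descriptors, centroids, k=1):
        counts[match[0].trainIdx] += 1
    return counts / counts.sum()

# Collect the descriptors of every database image.
entries = []           # (filename, descriptors) per image
all_descriptors = []
for name in sorted(os.listdir(IMAGE_DIR)):
    if not name.endswith('.png'):
        continue
    img = cv2.imread(os.path.join(IMAGE_DIR, name), 0)
    _, des = sift.detectAndCompute(img, None)
    if des is None:
        continue
    entries.append((name, des))
    all_descriptors.append(des)

# Build the dictionary with k-means, then one histogram per database image.
data = np.vstack(all_descriptors).astype(np.float32)
criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 5, .01)
_, _, centroids = cv2.kmeans(data, CLUSTERS, None, criteria, 1, cv2.KMEANS_PP_CENTERS)
histograms = np.float32([image_histogram(des, centroids) for _, des in entries])

# Histogram for the query image, then rank database images against it.
query = cv2.imread('query.png', 0)  # query image (placeholder path)
_, query_des = sift.detectAndCompute(query, None)
query_hist = np.float32([image_histogram(query_des, centroids)])

for match in bf.knnMatch(query_hist, histograms, k=1)[0]:
    print(entries[match.trainIdx][0], match.distance)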