"""Content-based image search: VGG16 convolutional features + a Faiss L2 index."""

import os

import numpy as np
from PIL import Image
import tensorflow as tf
from tensorflow.keras.applications import VGG16
from tensorflow.keras.applications.vgg16 import preprocess_input
from tensorflow.keras.preprocessing import image
import faiss
import cv2


class ImageSearchEngine:
    """Index a directory of images and retrieve the visually closest ones.

    Each image is embedded with an ImageNet-pretrained VGG16 (classification
    head removed), flattened into one vector, and stored in an exact
    ``faiss.IndexFlatL2`` index.

    Attributes:
        model: The Keras VGG16 feature extractor.
        index: The Faiss index, or ``None`` until ``build_index`` succeeds.
        image_paths: Paths of the indexed images; position ``i`` corresponds
            to row ``i`` of the index.
        image_embeddings: ``(n_images, dim)`` float32 array after a successful
            ``build_index``; an empty list before that.
    """

    # File extensions (lower-case) picked up by build_index.
    _IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')
    # Side length in pixels of every tile drawn by visualize_results.
    _TILE_SIZE = 200

    def __init__(self, model='vgg16'):
        """Load the pretrained feature extractor.

        Args:
            model: Accepted for interface compatibility only. VGG16 is the
                sole backbone implemented and the value is currently ignored.
        """
        # Pretrained VGG16 without the dense head, used purely for features.
        self.model = VGG16(weights='imagenet', include_top=False)
        self.index = None
        self.image_paths = []
        self.image_embeddings = []

    def extract_features(self, img_path):
        """Return the flattened VGG16 feature vector of one image.

        Args:
            img_path: Path of the image file to embed.

        Returns:
            A 1-D float32 numpy array, or ``None`` when the image could not be
            loaded or processed (the error is printed, not raised, so one bad
            file does not abort indexing of a whole directory).
        """
        try:
            img = image.load_img(img_path, target_size=(224, 224))
            x = image.img_to_array(img)
            x = np.expand_dims(x, axis=0)  # add the batch dimension
            x = preprocess_input(x)
            # verbose=0: avoid one progress bar per image while indexing.
            features = self.model.predict(x, verbose=0)
            # Faiss only accepts float32, so make the dtype explicit.
            return np.asarray(features, dtype=np.float32).flatten()
        except Exception as e:
            # Deliberately broad: this is a best-effort per-file step.
            print(f"Error processing {img_path}: {str(e)}")
            return None

    def build_index(self, image_dir):
        """Walk ``image_dir`` recursively and index every readable image.

        Any previously built index is discarded first, so a failed rebuild
        never leaves an old index paired with a new (empty) path list.

        Args:
            image_dir: Root directory to scan for .png/.jpg/.jpeg files.

        Raises:
            ValueError: If no image under ``image_dir`` could be embedded.
        """
        self.index = None
        self.image_paths = []
        self.image_embeddings = []

        paths = []
        embeddings = []
        for root, _, files in os.walk(image_dir):
            for file in files:
                if not file.lower().endswith(self._IMAGE_EXTENSIONS):
                    continue
                img_path = os.path.join(root, file)
                features = self.extract_features(img_path)
                if features is not None:
                    paths.append(img_path)
                    embeddings.append(features)

        if not embeddings:
            # Without this guard the code below fails with an opaque
            # "tuple index out of range" on the empty array's shape.
            raise ValueError(f"No indexable images found in {image_dir}")

        # Faiss requires a C-contiguous float32 matrix of shape (n, dim).
        matrix = np.ascontiguousarray(np.array(embeddings), dtype=np.float32)

        dimension = matrix.shape[1]
        index = faiss.IndexFlatL2(dimension)
        index.add(matrix)

        self.image_paths = paths
        self.image_embeddings = matrix
        self.index = index
        print(f"Index built with {len(self.image_paths)} images")

    def search_similar_images(self, query_image_path, num_results=5):
        """Find the indexed images closest to a query image.

        Args:
            query_image_path: Path of the query image.
            num_results: Maximum number of neighbours to return.

        Returns:
            A list of ``{'path': str, 'score': distance}`` dicts ordered from
            most to least similar. ``score`` is the distance reported by the
            L2 index, so smaller means more similar. The list is shorter than
            ``num_results`` when the index holds fewer images, and empty when
            the query image cannot be processed or ``num_results`` is not
            positive.

        Raises:
            ValueError: If ``build_index`` has not been called successfully.
        """
        if self.index is None:
            raise ValueError("Index not built yet. Please call build_index() first.")

        if num_results <= 0:
            return []

        query_features = self.extract_features(query_image_path)
        if query_features is None:
            return []

        query = np.ascontiguousarray(
            query_features.reshape(1, -1), dtype=np.float32
        )
        distances, labels = self.index.search(query, num_results)

        results = []
        for distance, label in zip(distances[0], labels[0]):
            # Faiss pads missing neighbours with label -1. Without the lower
            # bound, -1 would silently select the LAST indexed image.
            if 0 <= label < len(self.image_paths):
                results.append({
                    'path': self.image_paths[int(label)],
                    'score': distance
                })

        return results

    def _load_tile(self, path):
        """Read ``path`` as a square BGR tile; black placeholder if unreadable."""
        size = self._TILE_SIZE
        img = cv2.imread(path)
        if img is None:
            # cv2.imread signals failure by returning None instead of raising.
            print(f"Could not read image for display: {path}")
            return np.zeros((size, size, 3), dtype=np.uint8)
        return cv2.resize(img, (size, size))

    def visualize_results(self, query_path, results):
        """Show the query image and its matches side by side in one window.

        Blocks until a key is pressed, then closes all OpenCV windows.

        Args:
            query_path: Path of the query image (drawn in the first tile).
            results: Result dicts as returned by ``search_similar_images``.
        """
        size = self._TILE_SIZE
        text_y = size - 20  # baseline of the caption inside each tile

        # One tile for the query plus one per result, laid out horizontally.
        result_img = np.zeros((size, size * (len(results) + 1), 3), dtype=np.uint8)

        result_img[:, :size] = self._load_tile(query_path)
        cv2.putText(result_img, "Query", (10, text_y),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 2)

        for position, result in enumerate(results, start=1):
            left = position * size
            result_img[:, left:left + size] = self._load_tile(result['path'])
            cv2.putText(result_img, f"Score: {result['score']:.2f}",
                        (left + 10, text_y),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 2)

        cv2.imshow("Similar Images", result_img)
        cv2.waitKey(0)
        cv2.destroyAllWindows()


# Usage example
if __name__ == "__main__":
    # Create the search engine instance.
    search_engine = ImageSearchEngine()

    # Build the index (point this at a real image directory).
    image_directory = "path/to/your/images"
    search_engine.build_index(image_directory)

    # Search for similar images.
    query_image = "path/to/query/image.jpg"
    results = search_engine.search_similar_images(query_image, num_results=5)

    # Display the results.
    search_engine.visualize_results(query_image, results)
# 2025-06-09 08:06
# 7