
Image Search with Zilliz Cloud and PyTorch

On this page, we go over a simple image search example using Zilliz Cloud. The dataset we are searching through is the Impressionist-Classifier Dataset found on Kaggle. For this example, we have re-hosted the data on a public Google Drive.

For this example, we use a 1 CU cluster and the Torchvision pre-trained ResNet50 model to generate embeddings. Let's get started!

Before you start

For this example, we use pymilvus to connect to Zilliz Cloud, torch to run the embedding model, torchvision for the model itself and the preprocessing, gdown to download the example dataset, tqdm for progress bars, and matplotlib to plot the search results.

pip install pymilvus torch gdown torchvision tqdm matplotlib

Prepare data

We are going to use gdown to grab the zip from Google Drive and then decompress it with the built-in zipfile library.

import gdown
import zipfile

url = 'https://drive.google.com/uc?id=1OYDHLEy992qu5C4C8HV5uDIkOWRTAR1_'
output = './paintings.zip'
gdown.download(url, output)

# Extract the downloaded archive into ./paintings
with zipfile.ZipFile(output, "r") as zip_ref:
    zip_ref.extractall("./paintings")
📘Notes

The dataset is 2.35 GB, so the download time depends on your network conditions.

Parameters

These are the main global parameters used throughout this example, gathered in one place so they are easy to track and update.

# 1. Set up the name of the collection to be created.
COLLECTION_NAME = 'image_search_db'

# 2. Set up the dimension of the embeddings.
DIMENSION = 2048

# 3. Set the inference parameters
BATCH_SIZE = 128
TOP_K = 3

# 4. Set up the connection parameters for your Zilliz Cloud cluster.
URI = 'YOUR_CLUSTER_ENDPOINT'
TOKEN = 'YOUR_CLUSTER_TOKEN'

Setting up Zilliz Cloud

At this point, we are going to begin setting up Zilliz Cloud. The steps are as follows:

  1. Connect to the Zilliz Cloud cluster using the provided URI.

    from pymilvus import connections

    # Connect to Zilliz Cloud
    connections.connect(
        alias='default',
        # Public endpoint obtained from Zilliz Cloud
        uri=URI,
        token=TOKEN
    )
  2. If the collection already exists, drop it.

    from pymilvus import utility

    # Remove any previous collections with the same name
    if COLLECTION_NAME in utility.list_collections():
        utility.drop_collection(COLLECTION_NAME)
  3. Create the collection that holds the ID, the file path of the image, and its embedding.

    from pymilvus import FieldSchema, CollectionSchema, DataType, Collection

    fields = [
        FieldSchema(name='id', dtype=DataType.INT64, is_primary=True, auto_id=True),
        # VARCHAR fields need a maximum length, so for this example it is set to 200 characters
        FieldSchema(name='filepath', dtype=DataType.VARCHAR, max_length=200),
        FieldSchema(name='image_embedding', dtype=DataType.FLOAT_VECTOR, dim=DIMENSION)
    ]

    schema = CollectionSchema(fields=fields)

    collection = Collection(
        name=COLLECTION_NAME,
        schema=schema,
    )
  4. Create an index on the newly created collection and load it into memory.

    index_params = {
        'index_type': 'AUTOINDEX',
        'metric_type': 'L2',
        'params': {}
    }

    collection.create_index(
        field_name='image_embedding',
        index_params=index_params
    )

    # Load the collection into memory so it can be searched
    collection.load()
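Because the `filepath` field is capped at `max_length=200`, an insert that carries a longer path is rejected. The sketch below is a hypothetical pre-insert check (`paths_exceeding_limit` and `MAX_FILEPATH_LENGTH` are illustrative names, not part of pymilvus). It measures the UTF-8 byte length, which is never smaller than the character count, so it stays conservative whichever of the two the server enforces.

```python
# Hypothetical pre-insert check (not a pymilvus API): find file paths that
# would not fit in the `filepath` VARCHAR field defined in the schema above.
MAX_FILEPATH_LENGTH = 200  # mirrors max_length=200 in the FieldSchema


def paths_exceeding_limit(paths, limit=MAX_FILEPATH_LENGTH):
    """Return the paths whose UTF-8 encoded length is larger than `limit`.

    Byte length is used as a conservative measure: it is never smaller than
    the character count, so any path that passes fits either way.
    """
    return [p for p in paths if len(p.encode("utf-8")) > limit]


# Hypothetical example paths; the second one is deliberately too long
example_paths = [
    "./paintings/paintings/example/1.jpg",
    "./paintings/paintings/" + "x" * 200 + ".jpg",
]
print(len(paths_exceeding_limit(example_paths)))  # → 1
```

Once `paths` has been collected in the next section, calling `paths_exceeding_limit(paths)` before the insert loop would surface any problem paths up front instead of partway through a batch.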

Once these steps are done, the collection is ready for inserts and searches. Any data added is indexed automatically and is available to search immediately. If the data is very fresh, searches may be slower, because brute-force search is used on data that is still being indexed.
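To make "brute-force search" concrete, here is a minimal pure-Python sketch of an exact top-k search under the L2 metric: every stored vector is compared with the query and the k closest are kept. The helper names are illustrative, not Zilliz Cloud internals. It assumes distances are reported as squared Euclidean distance, which is how Milvus documents its L2 metric (no square root is taken).

```python
import heapq


def squared_l2(a, b):
    """Squared Euclidean distance between two equal-length vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def brute_force_top_k(query, vectors, k):
    """Compare `query` against every stored vector and keep the k closest.

    Returns (row_index, distance) pairs sorted from nearest to farthest.
    """
    scored = ((i, squared_l2(query, v)) for i, v in enumerate(vectors))
    return heapq.nsmallest(k, scored, key=lambda pair: pair[1])


# Toy 2-D "embeddings" standing in for the 2048-D image vectors
stored = [[0.0, 0.0], [1.0, 1.0], [3.0, 4.0], [0.5, 0.0]]
print(brute_force_top_k([0.0, 0.0], stored, k=3))
# → [(0, 0.0), (3, 0.25), (1, 2.0)]
```

This scan costs one distance computation per stored vector, which is why an index such as AUTOINDEX is faster once the data has been indexed.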

Insert data

In this example, we use the pre-trained ResNet50 model from Torchvision. To obtain embeddings, we remove the final classification layer, so the model outputs 2048-dimensional embeddings. The preprocessing we have included here is the standard ImageNet preprocessing used with many of Torchvision's pre-trained vision models.
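The preprocessing used in step 3 below comes down to simple arithmetic: `Resize(256)` scales the shorter edge to 256 pixels while keeping the aspect ratio, `CenterCrop(224)` keeps the central 224 x 224 region, `ToTensor()` scales pixel values to [0, 1], and `Normalize` computes `(value - mean) / std` per channel. The sketch below reproduces only that arithmetic in plain Python; it is not torchvision's code, and truncating the longer edge to an integer is an assumption about the exact rounding.

```python
# Pure-Python sketch of the preprocessing arithmetic (not torchvision's own code).
MEAN = [0.485, 0.456, 0.406]  # ImageNet per-channel means
STD = [0.229, 0.224, 0.225]   # ImageNet per-channel standard deviations


def resized_size(width, height, size=256):
    """Output (width, height) of Resize(size): the shorter edge becomes `size`
    and the longer edge keeps the aspect ratio (truncated to an int, which is
    an assumption about the exact rounding)."""
    if width <= height:
        return size, int(size * height / width)
    return int(size * width / height), size


def normalize_pixel(rgb):
    """Normalize one RGB pixel that ToTensor() has already scaled to [0, 1]."""
    return [(value - m) / s for value, m, s in zip(rgb, MEAN, STD)]


print(resized_size(800, 600))  # → (341, 256); CenterCrop(224) then keeps the middle 224 x 224
print(normalize_pixel([0.485, 0.456, 0.406]))  # → [0.0, 0.0, 0.0]
```

A pixel equal to the ImageNet mean maps to zero in every channel, which is the point of the normalization: inputs are centered the way the model saw them during training.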

In the following steps, we will:

  1. Load the data.

    import glob

    # Get the filepaths of the images
    paths = glob.glob('./paintings/paintings/**/*.jpg', recursive=True)
    print(len(paths))

    # Output
    #
    # 4978
  2. Load the embedding model with its final classification layer removed.

    import torch
    from torchvision.models import resnet50, ResNet50_Weights

    # Load the pre-trained ResNet50 and remove its last (classification) layer,
    # so the model outputs 2048-dimensional embeddings
    model = resnet50(weights=ResNet50_Weights.DEFAULT)
    model = torch.nn.Sequential(*(list(model.children())[:-1]))
    # Switch to inference mode
    model.eval()
  3. Define the image preprocessing.

    from torchvision import transforms

    # Resize, center-crop, convert to a [0, 1] tensor, and normalize with the ImageNet mean and std
    preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
  4. Embed the images in batches and insert them.

    from PIL import Image
    from tqdm import tqdm

    # Embed a batch of images and insert the embeddings with their file paths
    def embed(data):
        with torch.no_grad():
            # flatten (not squeeze) keeps the batch dimension, even for a single-image batch
            output = torch.flatten(model(torch.stack(data[0])), start_dim=1)
        collection.insert([data[1], output.tolist()])

    data_batch = [[],[]]

    # Read the images into batches for embedding and insertion
    for path in tqdm(paths):
        im = Image.open(path).convert('RGB')
        data_batch[0].append(preprocess(im))
        data_batch[1].append(path)
        if len(data_batch[0]) == BATCH_SIZE:
            embed(data_batch)
            data_batch = [[],[]]

    # Embed and insert the remainder
    if len(data_batch[0]) != 0:
        embed(data_batch)

    # Flush to seal any growing segments so that they get indexed
    collection.flush()
    📘Notes

    This step is relatively time-consuming because embedding takes time. Take a sip of coffee and relax.

    PyTorch may not work well with Python 3.9 and earlier. Consider using Python 3.10 or later instead.
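The insert loop fills batches of `BATCH_SIZE` images and then handles whatever is left over. The sketch below (a hypothetical `batch_sizes` helper, plain Python) works out the batch sizes that loop produces: with the 4978 paintings and `BATCH_SIZE = 128`, that is 38 full batches plus one final batch of 114 images, so 39 insert calls in total.

```python
def batch_sizes(total, batch_size):
    """Sizes of the batches produced by the insert loop: full batches of
    `batch_size`, then one smaller batch for any remainder."""
    full, remainder = divmod(total, batch_size)
    return [batch_size] * full + ([remainder] if remainder else [])


sizes = batch_sizes(4978, 128)
print(len(sizes), sizes[-1])  # → 39 114
print(batch_sizes(129, 128))  # → [128, 1]
```

The second call shows the edge case worth guarding against: a remainder of exactly one image, where any step that drops size-1 dimensions would also drop the batch dimension.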

With all the data inserted into Zilliz Cloud, we can start performing our searches. In this example, we search for two example images. Because both images are sent in a single batch search request, the measured search time is shared across the images in the batch.

import glob

# Get the filepaths of the search images
search_paths = glob.glob('./paintings/test_paintings/**/*.jpg', recursive=True)
print(len(search_paths))

# Output
#
# 2

import time
from matplotlib import pyplot as plt

# Embed the search images
def embed(data):
    with torch.no_grad():
        ret = model(torch.stack(data))
        # flatten (rather than squeeze) keeps the batch dimension,
        # so this works for a single image as well as for several
        return torch.flatten(ret, start_dim=1).tolist()

data_batch = [[],[]]

for path in search_paths:
    im = Image.open(path).convert('RGB')
    data_batch[0].append(preprocess(im))
    data_batch[1].append(path)

embeds = embed(data_batch[0])
start = time.time()
res = collection.search(embeds, anns_field='image_embedding', param={'metric_type': 'L2', 'params': {}}, limit=TOP_K, output_fields=['filepath'])
finish = time.time()

# Show the image results: one row per search image, followed by its TOP_K hits
f, axarr = plt.subplots(len(data_batch[1]), TOP_K + 1, figsize=(20, 10), squeeze=False)

for hits_i, hits in enumerate(res):
    axarr[hits_i][0].imshow(Image.open(data_batch[1][hits_i]))
    axarr[hits_i][0].set_axis_off()
    axarr[hits_i][0].set_title('Search Time: ' + str(finish - start))
    for hit_i, hit in enumerate(hits):
        axarr[hits_i][hit_i + 1].imshow(Image.open(hit.entity.get('filepath')))
        axarr[hits_i][hit_i + 1].set_axis_off()
        axarr[hits_i][hit_i + 1].set_title('Distance: ' + str(hit.distance))

# Save the search result in a separate image file alongside your script.
plt.savefig('search_result.png')

The search result image should be similar to the following:

[Figure: image_search, the search results for the two test paintings]
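Both `embed` functions in this example turn ResNet50's output into plain lists, and the shape of that output matters. With the final layer removed, the model returns a tensor of shape `(batch, 2048, 1, 1)`. `Tensor.squeeze()` with no arguments drops every size-1 dimension, so for a single image it also drops the batch dimension, while `torch.flatten(t, start_dim=1)` keeps it. The sketch below mirrors that documented behavior on shape tuples only; the helper names are hypothetical and it does not call PyTorch.

```python
def squeeze_shape(shape):
    """Shape after Tensor.squeeze() with no arguments: every size-1 dim is dropped."""
    return tuple(d for d in shape if d != 1)


def flatten_shape(shape, start_dim=1):
    """Shape after torch.flatten(t, start_dim): dims from start_dim onward are merged."""
    merged = 1
    for d in shape[start_dim:]:
        merged *= d
    return tuple(shape[:start_dim]) + (merged,)


# ResNet50 with its final layer removed outputs (batch, 2048, 1, 1)
for batch in (2, 1):
    out = (batch, 2048, 1, 1)
    print(batch, squeeze_shape(out), flatten_shape(out))
# → 2 (2, 2048) (2, 2048)
# → 1 (2048,) (1, 2048)
```

For a batch of one, `squeeze` yields a flat list of 2048 floats instead of a list containing one 2048-float vector, which is not the shape the insert and search calls expect; flattening from dimension 1 gives the same result for any batch size.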