CLIP (Contrastive Language–Image Pre-training) was released by OpenAI about two years ago (January 2021). It is a major component, and arguably the foundation, of the “game-changer” text-to-image model Stable Diffusion (released by Stability AI in August 2022).
I developed this tutorial to show the following tasks using CLIP and the Unsplash dataset:
The data and code can be accessed at this repo.
This tutorial is based on these references.
The dataset for this tutorial is based on the open dataset from Unsplash. I did the following to make the dataset more accessible:
- `1k-compressed` (~42M) is included in this repo.
- `1k` (~8.45G) and `5k-compressed` (~213M) can be downloaded from Kaggle.
Simply put, zero-shot learning means a model can recognize things that it hasn’t explicitly seen before during training.
For example, the Unsplash photos do not have explicit labels/tags showing when (dawn/day/dusk/night) or where (indoor/outdoor) they were taken.
In a typical supervised learning setting, you would label some photos with dawn/day/dusk/night and indoor/outdoor and then train a model to label the rest of your data.
With zero-shot learning, you don’t need manual labeling, which is often costly.
In this notebook, I show how to use CLIP to do zero-shot image classification to add two additional features to the dataset:
- `taken_time` with values: `dawn`, `day`, `dusk`, or `night`
- `taken_space` with values: `indoor` or `outdoor`
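To keep the two features organized in code, one option (a sketch, not necessarily the notebook's exact structure; the `taken_space` prompt wording is an assumption) is a small mapping from feature name to candidate labels and prompt template:

```python
# Hypothetical mapping from feature name to its candidate labels and prompt template
ZERO_SHOT_FEATURES = {
    'taken_time': {
        'labels': ['dawn', 'day', 'dusk', 'night'],
        'template': 'A photo taken at {}',
    },
    'taken_space': {
        'labels': ['indoor', 'outdoor'],
        'template': 'A photo taken {}',  # assumed wording; the notebook may phrase it differently
    },
}

# Generate the text prompts for each feature
for feature, cfg in ZERO_SHOT_FEATURES.items():
    prompts = [cfg['template'].format(label) for label in cfg['labels']]
    print(feature, prompts)
```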
CLIP makes this task very simple: given an image and a text prompt, CLIP can return a cosine similarity score.
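As a quick setup sketch (assuming the openai/CLIP package is installed as `clip`; `ViT-B/32` is just one example variant), loading the model and its matching preprocess function looks like this:

```python
import torch
import clip

# Pick a device; fall back to CPU if no GPU is available
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load a CLIP variant (ViT-B/32 as an example) and its matching image preprocess function
model, preprocess = clip.load("ViT-B/32", device=device)
```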
The key steps are:
Create the text prompts for the labels, e.g., `A photo taken at dawn`, `A photo taken at day`, `A photo taken at dusk`, `A photo taken at night`, and tokenize them:
```python
labels = ['dawn', 'day', 'dusk', 'night']  # the target labels
text_prompts = ['A photo taken at ' + label for label in labels]
text = clip.tokenize(text_prompts).to(device)  # tokenize the text prompts
```
Load and preprocess each image:
```python
image = preprocess(Image.open(im)).unsqueeze(0).to(device)  # preprocess the image
```
Calculate the cosine similarity scores (logits) between the image and the 4 prompts, and use the label with the highest score as the prediction:
```python
logits_per_image, logits_per_text = model(image, text)  # cosine similarity scores (scaled logits)
probs = logits_per_image.softmax(dim=-1).cpu().numpy()  # softmax to get probabilities
choices = np.argmax(probs)  # index of the largest probability is the prediction
```
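Putting the steps together for one photo, a minimal end-to-end sketch (the image path and the final label lookup are assumptions, not the notebook's exact code) could look like this:

```python
import numpy as np
import torch
import clip
from PIL import Image

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

labels = ['dawn', 'day', 'dusk', 'night']
text = clip.tokenize(['A photo taken at ' + label for label in labels]).to(device)

im = './1k-compressed/example.jpg'  # hypothetical photo path
image = preprocess(Image.open(im)).unsqueeze(0).to(device)

with torch.no_grad():
    logits_per_image, logits_per_text = model(image, text)
    probs = logits_per_image.softmax(dim=-1).cpu().numpy()

predicted_label = labels[int(np.argmax(probs))]
print(predicted_label, probs)
```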
For example, the predicted probabilities (softmax over the similarity scores) for the following photo are `[0.488 0.1882 0.3152 0.00908]`, so it is predicted to have been taken at dawn (`0.488` is the largest). This is pretty good given the dew on the grass.
The histograms of the classification results are as follows:
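For reference, the label counts behind such histograms can be computed directly from the prediction CSV (a sketch; it assumes both predicted columns are saved in the same `photos-1k-pred.csv` file used in the verification snippet below):

```python
import pandas as pd

photos_pred = pd.read_csv('photos-1k-pred.csv')

# Count how many photos fall into each predicted class
print(photos_pred['taken_time'].value_counts())
print(photos_pred['taken_space'].value_counts())
```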
I wrote a code snippet that randomly samples and displays prediction results, which you can use to check the correctness of the labels:
```python
# Randomly verify the results: sample one photo per class (dawn, day, dusk, night)
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.image as mpimg

photos_pred = pd.read_csv('photos-1k-pred.csv')
samples = photos_pred.groupby('taken_time').sample(1)
ids = list(samples.photo_id)
labels = list(samples.taken_time)
total_photos = len(ids)
sample_photos = [f'./1k-compressed/{idx}.jpg' for idx in ids]

# Display the sampled photos horizontally with the predicted labels as titles
fig = plt.figure(figsize=(20, 80))
for i in range(total_photos):
    ax = fig.add_subplot(1, total_photos, i + 1)
    image = mpimg.imread(sample_photos[i])
    ax.imshow(image)
    ax.set_title(labels[i])
    ax.axis('off')
plt.show()
```
As always, some predictions are very good (`day` and `night` are quite accurate) and some are just OK (`dawn` and `dusk` are in some cases hard even for a human to tell).
I also wrote another notebook to show how to do this using HuggingFace Transformers.
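For reference, a minimal sketch of the same zero-shot step with Hugging Face Transformers (the checkpoint name is the standard public one and the photo path is hypothetical; neither is necessarily what that notebook uses):

```python
import torch
from PIL import Image
from transformers import CLIPModel, CLIPProcessor

model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

labels = ['dawn', 'day', 'dusk', 'night']
prompts = ['A photo taken at ' + label for label in labels]
image = Image.open('./1k-compressed/example.jpg')  # hypothetical photo path

# The processor tokenizes the prompts and preprocesses the image in one call
inputs = processor(text=prompts, images=image, return_tensors="pt", padding=True)
with torch.no_grad():
    outputs = model(**inputs)

probs = outputs.logits_per_image.softmax(dim=-1)
print(labels[int(probs.argmax())])
```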
Assuming we don’t have any tags for the photos, we cannot use lexical search, i.e., searching for literal matches of the query words/phrases (and/or their variants), to retrieve the relevant photos.
Semantic search means searching by meaning: it aims to understand the intent and contextual meaning of the search query. For example, when we search `dog in the snow`, we hope the search engine understands the meaning of the query instead of simply matching photos with tags such as `dog`, `snow`, etc. In addition, the images must also be represented by their meaning, beyond the typical pixel information.
CLIP is trained using a multimodal learning approach, which enables encoding images and texts as numerical features (tensors/vectors) with contextual meaning; e.g., the text encoder in Stable Diffusion is just the CLIP text encoder. The encoded image and text features can then be used in various downstream tasks, such as semantic image search, searching with an image, image captioning, etc.
In this notebook, I show how to conduct semantic image search, which includes the following key steps:
Encode all images (processed in batches and then combined into one large numpy array; a sketch of the surrounding batching loop follows this step):
```python
# Load all the photos in this batch from the files
photos = [Image.open(photo_file) for photo_file in photos_batch]

# Preprocess all photos by stacking the result for each photo
photos_preprocessed = torch.stack([preprocess(photo) for photo in photos]).to(device)

with torch.no_grad():
    # Encode the photo batch to compute the feature vectors and normalize them
    photos_features = model.encode_image(photos_preprocessed)
    photos_features /= photos_features.norm(dim=-1, keepdim=True)

# ... (repeat for every batch, saving each batch's features to a .npy file) ...

# Load all numpy files
features_list = [np.load(features_file) for features_file in sorted(features_path.glob("*.npy"))]

# Concatenate the features and store them in a merged file
features = np.concatenate(features_list)
np.save(features_path / "features.npy", features)
```
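For context, the surrounding batching loop that produces the per-batch `.npy` files might look roughly like this (a sketch; the batch size, `photos_path`, and file naming are assumptions, not the notebook's exact code):

```python
import math
import numpy as np

batch_size = 256  # assumed batch size
photo_files = sorted(photos_path.glob("*.jpg"))  # assumes photos_path and features_path are defined
batches = math.ceil(len(photo_files) / batch_size)

for i in range(batches):
    batch_file = features_path / f"{i:04d}.npy"
    if batch_file.exists():  # skip batches that were already processed
        continue
    photos_batch = photo_files[i * batch_size : (i + 1) * batch_size]
    # ... encode photos_batch as shown above to get normalized photos_features ...
    np.save(batch_file, photos_features.cpu().numpy())
```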
Encode the search query, such as `dog in the snow`:
```python
search_query = "dog in the snow"

with torch.no_grad():
    # Encode and normalize the search query using CLIP's text encoder
    text_encoded = model.encode_text(clip.tokenize(search_query).to(device))
    text_encoded /= text_encoded.norm(dim=-1, keepdim=True)
```
Compare the text feature to all image features via the matrix product `@`/`torch.matmul()` and find the best matches (the top N results):
```python
# photo_features holds the merged image features (loaded from features.npy as a tensor)
similarities = list((text_encoded @ photo_features.T).squeeze(0))

# Sort the photos by their similarity score
best_photos = sorted(zip(similarities, range(photo_features.shape[0])), key=lambda x: x[0], reverse=True)
print(best_photos[:3])
```
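To turn those indices into actual photos, a hedged sketch (assuming a `photo_ids` list saved in the same order as the feature rows) is:

```python
# photo_ids is an assumed list of photo IDs stored in the same order as the feature rows
top_n = 3
for similarity, idx in best_photos[:top_n]:
    photo_id = photo_ids[idx]
    print(f"{photo_id}: similarity {float(similarity):.3f}")
    # e.g., display the photo from the local folder:
    # display(Image.open(f"./1k-compressed/{photo_id}.jpg"))
```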
Let’s look at some results:
Search `dog` and the photos are all dogs (remember there are no tags):
Search `dog in the snow` and the context `snow` is successfully captured:
Search `music`, an abstract word, and all music-related photos are returned! There are only two music-related photos in the dataset; I went through all 1,000 photos to make sure (browsing beautiful photos is a pleasure ^_^).
Search `listen to music` and see how the photo with a band is no longer shown. Amazing.
Feel free to try more search queries and let me know if you find any additional interesting examples.
By default, CLIP resizes the shortest edge of the image to 224 pixels and does a 224x224 center crop (see the code). It’s useful to understand how to manipulate images using Python.
In this notebook, I show the basic image manipulations with Pillow that appear in CLIP’s preprocessing step, such as resizing and center cropping.
I also demonstrate how to “de-process” and display the image preprocessed by CLIP.
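For reference, CLIP's default preprocessing is roughly equivalent to the following torchvision pipeline (a sketch based on the openai/CLIP source; the normalization constants are the ones published in that repo):

```python
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
from torchvision.transforms import InterpolationMode

preprocess = Compose([
    Resize(224, interpolation=InterpolationMode.BICUBIC),  # resize the shortest edge to 224
    CenterCrop(224),                                        # center crop to 224x224
    lambda image: image.convert("RGB"),                     # ensure 3-channel RGB
    ToTensor(),
    Normalize((0.48145466, 0.4578275, 0.40821073),          # CLIP's published mean/std
              (0.26862954, 0.26130258, 0.27577711)),
])
```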
In another notebook, I show how to use CLIP (a multimodal model) and EfficientNet (a CV model) to extract image features and use those features for traditional classification and clustering (with PCA for dimensionality reduction). The task is to predict the camera ISO level used to take each photo from its image features, which may not be an ideal task given that so many other settings, such as aperture and shutter speed, are also involved. A few hundred images may also not be enough to train a good model, so focus on how to do these tasks rather than on the performance.
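A rough sketch of that workflow with scikit-learn (the model choices, file names, and `iso_levels.npy` labels here are illustrative assumptions, not the notebook's exact code):

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

# features: (n_photos, feature_dim) CLIP image features; iso_levels: discretized ISO labels (assumed)
features = np.load("features.npy")
iso_levels = np.load("iso_levels.npy")  # hypothetical label file

X_train, X_test, y_train, y_test = train_test_split(
    features, iso_levels, test_size=0.2, random_state=42
)

# Traditional classification on top of the extracted features
clf = LogisticRegression(max_iter=1000).fit(X_train, y_train)
print("accuracy:", clf.score(X_test, y_test))

# Clustering with PCA for dimensionality reduction
reduced = PCA(n_components=50).fit_transform(features)
clusters = KMeans(n_clusters=4, n_init=10).fit_predict(reduced)
```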
As a side note, I ran some tests of CLIP encoding on the 1k original Unsplash photos on my MacBook Pro (M1, 16GB, 2020 model).
MPS was slower than, or at best on par with, CPU in these tests.
| Batch Size | MPS | CPU |
|---|---|---|
| 100 | 5m 12.2s | 4m 51.4s |
| 500 | 4m 54.5s | 4m 51.7s |
| 1000 | 4m 54.1s | 4m 57.1s |
See additional related discussions here, such as this comment by @bikcrum: “… you will only observe the difference when inputs are large enough … This is because loading small data to memory and using GPU for calculation is overkill, so the CPU takes advantage, but if you have large data, GPU can compute it efficiently and take over the CPU …”
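For completeness, switching between the two backends only requires the standard PyTorch device selection (a sketch):

```python
import torch
import clip

# Use Apple's Metal Performance Shaders backend when available, otherwise fall back to CPU
device = "mps" if torch.backends.mps.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)
```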
PS. The first image for this post is generated via Midjourney using the prompt “millions of random words and pictures connected with strings, flying in the mystery universe”.