Fine-Tuning Multimodal Embedding Models with Marqtune
Summary
This guide will walk you through fine-tuning a pre-trained OpenCLIP model on a multi-modal training dataset. We will then evaluate the tuned model and compare it with an equivalent evaluation of the pre-trained model to demonstrate the improvement in performance. The tuned model can subsequently be used in a Marqo index to provide more relevant results for queries.
By completing the steps in this walkthrough you will learn how to use the Marqtune Python client to:
- Set up datasets in Marqtune
- Fine-tune a pre-trained model with a training dataset
- Evaluate models with an evaluation dataset
- Download a tuned model
This walkthrough is also available in article format for those who want to see how it works alongside the Marqtune UI.
1. Setup
You will need:
- A Python (3.11+) environment with pip set up
- A Marqo API key with access to Marqtune
- (Recommended) a Python virtualenv to run this walkthrough in
- (Recommended) IPython to run it interactively, though you can simply run it non-interactively by copying all the code snippets below into a single Python script.
- Finally, the Marqtune Python client:
pip install marqtune
Then, in IPython (or in a new Python script file), make the necessary imports and set up the Marqtune Python client. Note that all the Python snippets in this guide are designed to be copied and pasted unchanged (though you are encouraged to experiment, of course). The only exception is the api_key value below, which you must replace with your own key.
from marqtune.client import Client
from marqtune.enums import DatasetType, ModelType, InstanceType
from urllib.request import urlopen
import gzip
import json
import uuid
import re
# suffix is used just to make the dataset and model names unique
suffix = str(uuid.uuid4())[:8]
print(f"Using suffix={suffix} for this walkthrough")
# Change this to your API Key:
api_key = "<YOUR_MARQO_API_KEY>"
marqtune_client = Client(url="https://marqtune.marqo.ai", api_key=api_key)
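Hardcoding the key is fine for a quick experiment, but if you plan to share or commit the script you may prefer to read it from an environment variable. A minimal sketch, assuming a variable named MARQO_API_KEY (a name chosen here for illustration, not one that Marqtune requires):

```python
import os


def get_api_key(env_var: str = "MARQO_API_KEY") -> str:
    """Return the API key stored in an environment variable (hypothetical name)."""
    key = os.environ.get(env_var, "").strip()
    if not key:
        # Fail early with a clear message instead of sending an empty key
        raise RuntimeError(f"Set the {env_var} environment variable to your Marqo API key")
    return key
```

With this in place, the client could be created with Client(url="https://marqtune.marqo.ai", api_key=get_api_key()) instead of editing the key into the script.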
2. Dataset Creation
We will now create two datasets, one for training and another for evaluation. The datasets will be sourced from a couple of CSV files. The data in these CSV files consists of shopping data generated from a subset of Marqo-GS-10M which is described in more detail in our Open Source GCL repository.
Both CSV files have the same format: the larger one (100,000 rows) will be used to train a model, and the smaller one (25,000 rows) will be used to evaluate it.
The datasets are multi-modal, consisting of both text and images. The images are represented by URLs, which Marqtune uses to download them.
print("Downloading data files:")
base_path = (
"https://marqo-gcl-public.s3.us-west-2.amazonaws.com/marqtune_test/datasets/v1"
)
training_data = "gs_100k_training.csv"
eval_data = "gs_25k_eval.csv"
# Download each gzipped CSV, decompress it and save it locally (files are closed properly)
for file_name in (training_data, eval_data):
    with urlopen(f"{base_path}/{file_name}.gz") as response:
        with open(file_name, "w", encoding="utf-8", newline="") as out:
            out.write(gzip.open(response, "rb").read().decode("utf-8"))
Before we can create datasets in Marqtune, we need to identify the columns in the CSVs and their types by defining a data schema. We will reuse the same data schema for both the training and evaluation datasets, though this is not strictly necessary.
data_schema = {
"query": "text",
"title": "text",
"image": "image_pointer",
"score": "score",
}
After defining the data schema we can create the two datasets. Note that creating a dataset takes a few minutes because it involves several steps:
- The CSV file has to be uploaded
- Some simple validations have to pass (e.g. the data schema is validated against each row in the CSV input)
- The URLs in the image_pointer columns are used to download the image files into the dataset
# Create the training dataset.
training_dataset_name = f"{training_data}-{suffix}"
print(f"Creating training dataset ({training_dataset_name}):")
training_dataset = marqtune_client.create_dataset(
dataset_name=training_dataset_name,
file_path=training_data,
dataset_type=DatasetType.TRAINING,
data_schema=data_schema,
query_columns=["query"],
result_columns=["title", "image"],
# setting wait_for_completion=True will make this a blocking call and will also print logs interactively
wait_for_completion=True,
)
# Similarly we create the Evaluation dataset.
eval_dataset_name = f"{eval_data}-{suffix}"
print(f"Creating evaluation dataset ({eval_dataset_name}):")
eval_dataset = marqtune_client.create_dataset(
dataset_name=eval_dataset_name,
file_path=eval_data,
dataset_type=DatasetType.EVALUATION,
data_schema=data_schema,
query_columns=["query"],
result_columns=["title", "image"],
wait_for_completion=True,
)
Note that these datasets (and all other resources created in this walkthrough) can be viewed in the Marqtune UI.
3. Model Tuning
Now we're ready to train a model. As the base pre-trained OpenCLIP model we've chosen ViT-B-32 - laion2b_s34b_b79k, a good model to start with because it gives good performance with low latency and memory usage, and we know it runs successfully on InstanceType.BASIC. We previously published a guide to help you choose the right model for your use case. Note that some of the models mentioned there require more GPU memory than ViT-B-32, so you will need to use InstanceType.PERFORMANCE to train with them.
The training_params dictionary defines the training hyperparameters. We've chosen a minimal set to get you started: primarily, the left/right keys define which columns of the input CSV we're training on. You can experiment with these parameters yourself; refer to the Training Parameters documentation for details on these and the other parameters available for training.
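Because the left/right keys must name columns in the CSV, it can be worth checking the parameters against data_schema before starting a long training job. The helper below is a hypothetical local check; it assumes that each *Keys list pairs one-to-one with its *Weights list (as in the example parameters) and does not reproduce whatever validation Marqtune performs server-side.

```python
def check_params(params, data_schema):
    """Hypothetical pre-flight check of left/right keys, weights and weightKey.

    Returns a list of problems found (empty if none).
    """
    problems = []
    for side in ("left", "right"):
        keys = params.get(f"{side}Keys", [])
        weights = params.get(f"{side}Weights", [])
        # Assumption: one weight per key
        if len(keys) != len(weights):
            problems.append(f"{side}Keys and {side}Weights differ in length")
        for key in keys:
            if key not in data_schema:
                problems.append(f"{side}Keys entry {key!r} is not in the data schema")
    weight_key = params.get("weightKey")
    if weight_key is not None and data_schema.get(weight_key) != "score":
        problems.append(f"weightKey {weight_key!r} is not a score column")
    return problems
```

Once training_params is defined, print(check_params(training_params, data_schema)) should print an empty list for the parameters used in this guide.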
# Set up training hyperparameters:
training_params = {
"leftKeys": ["query"],
"leftWeights": [1],
"rightKeys": ["image", "title"],
"rightWeights": [0.9, 0.1],
"weightKey": "score",
"epochs": 5,
}
base_model = "Marqo/ViT-B-32.laion2b_s34b_b79k"
model_name = f"{training_data}-model-{suffix}"
print(f"Training a new model ({model_name}):")
tuned_model = marqtune_client.train_model(
dataset_id=training_dataset.dataset_id,
model_name=model_name,
instance_type=InstanceType.BASIC,
base_model=base_model,
hyperparameters=training_params,
wait_for_completion=True,
)
This training will take a while to complete, though you may choose to run it faster on more powerful hardware by passing instance_type=InstanceType.PERFORMANCE.
It's also worth noting that once training has been successfully kicked off in Marqtune, it will continue until completion no matter what happens to your local client session. On start, the logs will show the new model id that identifies your model. Copy this id so that if your local console disconnects during training, you can still resume the rest of this guide after loading the completed model:
tuned_model = marqtune_client.model('<model id>')
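If you expect to continue the walkthrough in a later session, you could also keep the id in a small local file rather than on the clipboard. A minimal sketch with hypothetical helper names and file name; note that it only helps once you have the id (either after train_model returns, via tuned_model.model_id, or by pasting the id from the logs), since a session that drops mid-training never reaches the save call.

```python
import json
from pathlib import Path

# Hypothetical local file used only to remember the model id between sessions
STATE_FILE = Path("marqtune_walkthrough_state.json")


def save_model_id(model_id, path=STATE_FILE):
    """Record the tuned model id so a later session can reload it."""
    path.write_text(json.dumps({"model_id": model_id}), encoding="utf-8")


def load_model_id(path=STATE_FILE):
    """Return the saved model id, or None if nothing has been saved yet."""
    if not path.exists():
        return None
    return json.loads(path.read_text(encoding="utf-8")).get("model_id")
```

For example, save_model_id(tuned_model.model_id) after training, and later tuned_model = marqtune_client.model(load_model_id()).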
4. Evaluating Models
Once we've successfully tuned the model, we want to quantify its performance against the baseline set by the original base model. To do this, we have Marqtune use the evaluation dataset to run an evaluation on the original base model, establishing a baseline, and then a second evaluation with the same dataset on the last checkpoint of our freshly tuned model.
Finally, we will print out the results of each evaluation which should show the tuned model returning better performance numbers than the base model.
eval_params = {
"leftKeys": ["query"],
"leftWeights": [1],
"rightKeys": ["image", "title"],
"rightWeights": [0.9, 0.1],
"weightKey": "score",
}
print("Evaluating the base model:")
base_model_eval = marqtune_client.evaluate(
dataset_id=eval_dataset.dataset_id,
model=base_model,
hyperparameters=eval_params,
wait_for_completion=True,
)
print("Evaluating the tuned model:")
tuned_model_id = tuned_model.model_id
tuned_checkpoint = tuned_model.describe()["checkpoints"][-1]
tuned_model_eval = marqtune_client.evaluate(
dataset_id=eval_dataset.dataset_id,
model=f"{tuned_model_id}/{tuned_checkpoint}",
hyperparameters=eval_params,
wait_for_completion=True,
)
# convenience function to inspect evaluation logs and extract the results
def print_eval_results(description, evaluation):
    # assumes the metrics appear as a Python dict literal in the last 10 log entries
    regexp = re.compile(r"\{'mAP@1000': .*'mRBP9': .*\}")
    results = next(
        (
            json.loads(match.group().replace("'", '"'))
            for log in evaluation.logs()[-10:]
            if (match := regexp.search(log["message"]))
        ),
        None,
    )
    print(description)
    print(json.dumps(results, indent=4))
print_eval_results("Evaluation results from base model:", base_model_eval)
print_eval_results("Evaluation results from tuned model:", tuned_model_eval)
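print_eval_results depends on the evaluation logs containing a Python-style dict literal with the metrics. To see the parsing step in isolation, here is a self-contained variant that uses ast.literal_eval instead of rewriting quotes for json.loads. The sample log line is hypothetical: its shape is inferred from the regular expression in print_eval_results, its numbers are copied from the base-model column of the results table in this section, and the real Marqtune log format may differ.

```python
import ast
import re

# Same idea as the pattern in print_eval_results: a dict from mAP@1000 through mRBP9
METRICS_RE = re.compile(r"\{'mAP@1000': .*'mRBP9': .*\}")


def extract_metrics(messages):
    """Return the first metrics dict found in the given log messages, else None."""
    for message in messages:
        match = METRICS_RE.search(message)
        if match:
            # literal_eval parses the Python dict repr safely, without quote rewriting
            return ast.literal_eval(match.group())
    return None


# Hypothetical log line, for illustration only
sample = "Evaluation finished: {'mAP@1000': 0.23614, 'NDCG@10': 0.27879, 'mRBP9': 0.04509028006250037}"
print(extract_metrics(["starting evaluation", sample]))
```

ast.literal_eval also copes with values such as True or None, which would break the quote-replacement approach.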
Again, we've chosen a minimal set of hyperparameters for the evaluation tasks, and you can read about these in the Evaluation Parameters documentation.
Due to the inherent stochasticity of training and evaluation the results you see will likely be different from our measurements, but you should see improvements similar to the measurements below (higher numbers are better):
Metric | Base Model | Tuned Model |
---|---|---|
mAP@1000 | 0.23614 | 0.25182 |
mrr@1000 | 0.26416 | 0.28572 |
NDCG@10 | 0.27879 | 0.30076 |
mERR | 0.2321309837009569 | 0.2417180692721506 |
mRBP7 | 0.08388403333492037 | 0.08949556069806149 |
mRBP8 | 0.06776364809797988 | 0.07389784006765014 |
mRBP9 | 0.04509028006250037 | 0.05054958376233629 |
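Absolute differences between these metrics are small, so it can help to look at the relative change instead. The snippet below simply recomputes the relative gain for each row of the table above (the values are copied from it; your own runs will differ):

```python
# (base model, tuned model) values copied from the table above
reported = {
    "mAP@1000": (0.23614, 0.25182),
    "mrr@1000": (0.26416, 0.28572),
    "NDCG@10": (0.27879, 0.30076),
    "mERR": (0.2321309837009569, 0.2417180692721506),
    "mRBP7": (0.08388403333492037, 0.08949556069806149),
    "mRBP8": (0.06776364809797988, 0.07389784006765014),
    "mRBP9": (0.04509028006250037, 0.05054958376233629),
}

for metric, (base, tuned) in reported.items():
    relative_gain = (tuned - base) / base * 100  # percent change relative to the base model
    print(f"{metric}: {base:.5f} -> {tuned:.5f} ({relative_gain:+.1f}%)")
```

In these measurements every metric improves, with relative gains ranging from about 4% (mERR) to about 12% (mRBP9).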
Picking out one of the above metrics, NDCG@10 (Normalized Discounted Cumulative Gain, which measures ranking and retrieval quality by comparing the model's top 10 retrievals with the ground truth), we can see that the tuned model outperformed the base model (0.30076 vs 0.27879). The other metrics show consistent improvements as well, so we can conclude that the tuned model performs better than the base model for the domain represented by the evaluation dataset. You can refer to our blog post on Generalised Contrastive Learning for Multimodal Retrieval and Ranking for more information, as well as an explanation of each of the metrics above.
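To make NDCG@10 more concrete, here is a minimal sketch of one common formulation: DCG@k = sum over ranks i = 1..k of rel_i / log2(i + 1), normalized by the DCG of the ideal ordering of the ground-truth relevance grades. Some implementations use 2^rel_i - 1 as the gain instead, and this is not necessarily the exact variant Marqtune computes.

```python
import math


def dcg_at_k(relevances, k):
    """Discounted cumulative gain of the first k relevance grades (rank 1 first)."""
    return sum(rel / math.log2(rank + 1) for rank, rel in enumerate(relevances[:k], start=1))


def ndcg_at_k(retrieved_relevances, all_relevances, k=10):
    """NDCG@k: DCG of the model's ranking divided by DCG of the ideal ranking.

    retrieved_relevances: grades of the results in the order the model returned them.
    all_relevances: grades of every relevant item for the query (the ground truth).
    """
    ideal_dcg = dcg_at_k(sorted(all_relevances, reverse=True), k)
    if ideal_dcg == 0:
        return 0.0  # no relevant items for this query
    return dcg_at_k(retrieved_relevances, k) / ideal_dcg


# Hypothetical example: the only relevant item (grade 1) is ranked second
print(ndcg_at_k([0, 1, 0], [1]))
```

The example prints about 0.63: ranking the relevant item second instead of first costs a discount of 1 / log2(3). The reported NDCG@10 is this per-query score averaged over all queries in the evaluation dataset.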
5. Download and Cleanup
At this point, you can download the model to your local disk:
tuned_model.download()
From here you can choose to create a Marqo index with this custom model.
Finally, you can optionally clean up the resources you generated:
training_dataset.delete()
eval_dataset.delete()
tuned_model.delete()
base_model_eval.delete()
tuned_model_eval.delete()
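If one of these calls fails (for example because a resource was already removed in the Marqtune UI), the remaining lines will not run. A small hypothetical wrapper that assumes only the delete() method used above keeps going and reports what failed:

```python
def delete_all(resources):
    """Call .delete() on each resource, continuing past failures.

    Returns a list of (resource, exception) pairs for anything that failed.
    """
    failures = []
    for resource in resources:
        try:
            resource.delete()
        except Exception as exc:  # keep going so one failure doesn't block the rest
            failures.append((resource, exc))
    return failures
```

For example: failures = delete_all([training_dataset, eval_dataset, tuned_model, base_model_eval, tuned_model_eval]), then inspect failures if it is not empty.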