Quick Start - Marqtune
This guide gets you up and running with Marqtune quickly, covering just the basics. If you're looking for something more detailed and comprehensive, visit our Getting Started guide.
If you have any questions or need help, visit our Community and ask in the get-help channel.
Marqo API Key
To run Marqtune, you will need a Marqo API Key. To get one:
- Sign up for Marqo Cloud, if you haven't already.
- Navigate to API Keys. You can either create a new API key or use the default one. Copy this key; you'll need it when we start programming later.
For a full walkthrough on how to find your API key, visit our article.
Installation
To install Marqtune, run:
pip install marqtune
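To confirm the install worked, you can ask Python's standard importlib.metadata module which version of the marqtune package is present. This is a small sketch using only the standard library; installed_version is a hypothetical helper for this guide and does not call the Marqtune API.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_version(package="marqtune"):
    """Return the installed version string of `package`, or None if it isn't installed."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None


if __name__ == "__main__":
    found = installed_version()
    if found is None:
        print("marqtune is not installed; run: pip install marqtune")
    else:
        print(f"marqtune {found} is installed")
```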
Dataset
We will use a CSV file, which you can download here: Marqtune WCLIP Pairs.
Place it in a quick-start folder inside the directory you're working in, since the code below reads it from quick-start/marqtune_wclip_pairs.csv.
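Before uploading, it can help to peek at the file. The sketch below uses Python's standard csv module to report the column names and the number of data rows. It assumes the file has a single header row (the dataset schema in step 3 relies on these headings); summarize_csv is a hypothetical helper, not part of marqtune.

```python
import csv
from pathlib import Path


def summarize_csv(lines):
    """Return (header, data_row_count) for an iterable of CSV lines, e.g. an open file."""
    reader = csv.reader(lines)
    header = next(reader, [])  # the first row holds the column names
    row_count = sum(1 for row in reader if row)  # skip completely blank lines
    return header, row_count


if __name__ == "__main__":
    # Path used by the code in this guide; change it if you saved the file elsewhere.
    csv_path = Path("quick-start/marqtune_wclip_pairs.csv")
    if csv_path.exists():
        with csv_path.open(newline="", encoding="utf-8") as f:
            header, rows = summarize_csv(f)
        print(f"Columns: {header}")
        print(f"Data rows: {rows}")
    else:
        print(f"{csv_path} not found; check where you saved the download")
```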
What to Expect
Once you've created your Marqo Cloud account, you'll land in the Marqo Cloud console.
Notice the Marqtune section in the left-hand side navigation bar; this is where we'll be fine-tuning our embedding models! In the next section, we'll explain how you can do this.
Fine-Tuning with Marqtune
We'll now show you how to fine-tune a base model with Marqtune.
1. Initial Imports and API Key
Create a Python script and add the following:
from marqtune import Client
from marqtune.enums import DatasetType, ModelType, InstanceType
# Define Marqo Cloud API Key. For information visit: https://marqo.ai/blog/finding-my-marqo-api-key
api_key = "your_api_key"
Replace "your_api_key" with the key you copied earlier. To obtain your API Key, visit our article.
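Hard-coding the key is fine for a quick experiment, but if you'd rather keep it out of your script, one common option is to read it from an environment variable. A minimal sketch: MARQO_API_KEY is a variable name chosen for this example (we are not claiming Marqtune reads it automatically), and load_api_key is a hypothetical helper.

```python
import os


def load_api_key(env_var="MARQO_API_KEY"):
    """Read the Marqo API key from an environment variable.

    MARQO_API_KEY is an example name chosen here, not a documented Marqtune setting.
    """
    api_key = os.environ.get(env_var)
    if not api_key:
        raise RuntimeError(f"Set the {env_var} environment variable to your Marqo API key")
    return api_key
```

You would then write `api_key = load_api_key()` in place of the hard-coded string and pass it to the client as before.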
2. Set up Marqtune Client
Next, we set up the Marqtune client:
marqtune_client = Client("https://marqtune.marqo.ai", api_key=api_key)
# Specify the path to your CSV. We will use a small dataset
input_data_path = "quick-start/marqtune_wclip_pairs.csv"
# Specify model name
model_name = "quick_start_marqtune"
3. Creating Dataset
We now use our CSV file to create a dataset in Marqtune.
# Define the dataset schema. These keys MUST match the column headings in your CSV.
dataset_schema = {
    "text-1": "text",
    "text-2": "text",
    "image-1": "image_pointer",
    "score": "score",
}
# Creating dataset
print(f"Creating dataset with name: {model_name}_dataset")
dataset = marqtune_client.create_dataset(
model_name + "_dataset",
input_data_path,
data_schema=dataset_schema,
dataset_type=DatasetType.TRAINING,
)
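Because the schema keys must match the CSV headings, a mismatch is easy to catch locally before calling create_dataset. A small sketch using only the standard library; missing_schema_columns is a hypothetical helper for this guide, not part of the marqtune package, and it compares names exactly (it does not strip stray whitespace).

```python
import csv
from pathlib import Path


def missing_schema_columns(schema, header):
    """Return schema keys absent from the CSV header, in schema order."""
    present = set(header)
    return [column for column in schema if column not in present]


if __name__ == "__main__":
    # Copy of the schema defined above, repeated so this snippet runs on its own.
    dataset_schema = {
        "text-1": "text",
        "text-2": "text",
        "image-1": "image_pointer",
        "score": "score",
    }
    csv_path = Path("quick-start/marqtune_wclip_pairs.csv")
    if not csv_path.exists():
        print(f"{csv_path} not found")
    else:
        with csv_path.open(newline="", encoding="utf-8") as f:
            header = next(csv.reader(f), [])  # first row = column headings
        missing = missing_schema_columns(dataset_schema, header)
        if missing:
            print(f"Schema columns missing from the CSV: {missing}")
        else:
            print("Every schema column appears in the CSV header")
```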
4. Fine-Tuning Base Model
First, we define the base model and checkpoint:
# Define the base model and the checkpoint to fine-tune from
base_model = "ViT-B-32"
base_checkpoint = "laion400m_e31"
Next, we specify the training task parameters:
# Define training task parameters
train_task_params = {
    "warmup": 0,
    "epochs": 5,
    "lr": 2e-05,
    "precision": "amp",
    "workers": 2,
    "batchSize": 256,
    "wd": 0.02,
    "weightedLoss": "ce",
    "rightKeys": ["text-2"],
    "leftWeights": [1],
    "rightWeights": [1],
    "leftKeys": ["text-1"],
    "weightKey": "score",
}
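In these parameters, leftKeys and rightKeys name dataset columns ("text-1" and "text-2"), leftWeights and rightWeights hold one weight per key, and weightKey points at the "score" column. Below is a minimal local sanity check, assuming (not confirmed by this guide) that every key must be a schema column and that weights pair with keys by position; check_key_weights is a hypothetical helper.

```python
def check_key_weights(params, schema):
    """Return a list of problems found in the left/right key and weight settings."""
    problems = []
    for side in ("left", "right"):
        keys = params.get(f"{side}Keys", [])
        weights = params.get(f"{side}Weights", [])
        # Assumption: each key is paired with the weight at the same position.
        if len(keys) != len(weights):
            problems.append(
                f"{side}Keys has {len(keys)} entries but {side}Weights has {len(weights)}"
            )
        # Assumption: keys refer to columns declared in the dataset schema.
        for key in keys:
            if key not in schema:
                problems.append(f"{side}Keys entry {key!r} is not a dataset schema column")
    weight_key = params.get("weightKey")
    if weight_key is not None and weight_key not in schema:
        problems.append(f"weightKey {weight_key!r} is not a dataset schema column")
    return problems
```

For the parameters above, `check_key_weights(train_task_params, dataset_schema)` returns an empty list.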
# Fine-tune the base model
tuned_model = marqtune_client.train_model(
dataset.dataset_id,
model_name,
base_model,
base_checkpoint,
ModelType.OPEN_CLIP,
instance_type=InstanceType.BASIC,
hyperparameters=train_task_params,
)
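How long training takes depends partly on how many optimizer steps it runs. Under the usual mini-batch assumptions (one step per batch of batchSize rows, a final partial batch kept, no gradient accumulation, batchSize being the global batch size), the count is ceil(rows / batchSize) × epochs. We haven't confirmed how Marqtune batches data internally, so treat this as a rough estimate. It is also the unit "warmup" would be counted in if, as in OpenCLIP's training script, warmup is a number of steps.

```python
import math


def estimate_training_steps(num_rows, batch_size, epochs):
    """Rough optimizer-step count: one step per (possibly partial) batch, per epoch."""
    if num_rows < 0 or batch_size <= 0 or epochs < 0:
        raise ValueError("num_rows and epochs must be >= 0 and batch_size > 0")
    steps_per_epoch = math.ceil(num_rows / batch_size)
    return steps_per_epoch * epochs
```

For example, a hypothetical 1,000-row file with batchSize 256 gives 4 steps per epoch, so `estimate_training_steps(1000, 256, 5)` returns 20.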
Awesome, the model has now been fine-tuned, so we can download it.
# Download the model in '.pt' format
print("Downloading model and logs")
marqtune_client.model(model_id=tuned_model.model_id).download()
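This guide doesn't say what the downloaded file is called or exactly where it is written. Assuming download() saves a '.pt' file somewhere under your working directory, a small standard-library helper can list candidate files, newest first; find_checkpoints is hypothetical and only inspects the local filesystem.

```python
from pathlib import Path


def find_checkpoints(directory="."):
    """List .pt files under `directory` (recursively), newest first by modification time."""
    files = [p for p in Path(directory).rglob("*.pt") if p.is_file()]
    return sorted(files, key=lambda p: p.stat().st_mtime, reverse=True)
```

For example, `find_checkpoints()[0]` would be the most recently modified '.pt' file under the current directory, if there is one.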
Next Steps
In this quick start guide, we've seen how you can fine-tune a base embedding model with Marqtune. If you want to see a more complex fine-tuning example and learn how to run evaluations with Marqtune, visit our more comprehensive Getting Started guide.