Train Model

Trains a model from a selected base model and a selected training dataset. The base model can be an open_clip model or a Marqtuned model.


POST /models/train

Body Parameters

| Name | Type | Default value | Description |
|---|---|---|---|
| datasetId | UUID | "" | Required - ID of a training dataset that has already been created. |
| modelName | String | "" | Required - Name of the model that this task will create. |
| baseModel | String | "" | Required - Name or ID of the model to fine-tune. A model name must come from the open_clip library; a model ID can be any Marqtuned model in your account. |
| baseCheckpoint | String | "" | Required - Checkpoint of the model to fine-tune. Must be an open_clip checkpoint, or an epoch of a Marqtuned model. |
| modelType | String | "" | Required - Type of model being fine-tuned: open_clip or marqtuned. |
| maxTrainingTime | Float | 86400 | Optional - Maximum time to run the training task, in seconds. The default of 86400 is 24 hours. The training task is terminated automatically when the maximum time is reached. |
| hyperparameters | Dictionary | "" | Required - Training task parameters. See the Training parameters guide for details. |
| instanceType | String | marqtune.basic | Required - Instance type used for the training: marqtune.basic or marqtune.performance. More details can be found in the Getting Started with Marqtune guide. |
| waitForCompletion | Boolean | True | Optional [py-marqtune client only] - Instructs the client to wait and keep polling until the operation is completed. |
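
The parameter rules above can be sketched as a small body builder. This is a minimal illustration, not part of py-marqtune: `build_train_body`, `REQUIRED` and `DEFAULTS` are hypothetical names, and the sketch assumes that an empty value counts as missing and that the two parameters with listed defaults may be omitted.

```python
import json

# Parameters the table marks as required and lists with an empty default.
REQUIRED = (
    "datasetId", "modelName", "baseModel",
    "baseCheckpoint", "modelType", "hyperparameters",
)
# Defaults listed in the table (assumed to apply when the field is omitted).
DEFAULTS = {"maxTrainingTime": 86400, "instanceType": "marqtune.basic"}


def build_train_body(**params):
    """Hypothetical helper: return a request body dict or raise ValueError."""
    missing = [name for name in REQUIRED if not params.get(name)]
    if missing:
        raise ValueError("Missing required parameters: " + ", ".join(missing))
    # Caller-supplied values override the defaults.
    return {**DEFAULTS, **params}


body = build_train_body(
    datasetId="dataset_id",
    modelName="test_model",
    baseModel="ViT-B-32",
    baseCheckpoint="laion2b_s34b_b79k",
    modelType="open_clip",
    hyperparameters={
        "leftKeys": ["query"],
        "rightKeys": ["my_image", "my_text"],
        "leftWeights": [1],
        "rightWeights": [0.9, 0.1],
    },
)
print(json.dumps(body, indent=2))
```

Checking the body locally only catches missing fields; the service still validates the dataset, base model, checkpoint and hyperparameter keys, as the 400 responses on this page show.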

Example: Training a model

from marqtune.client import Client
from marqtune.enums import ModelType, InstanceType

url = "https://marqtune.marqo.ai"
api_key = "{api_key}"
marqtune_client = Client(url=url, api_key=api_key)

# Fine-tune an open_clip ViT-B-32 checkpoint on an existing training dataset.
marqtune_client.train_model(
    dataset_id="dataset_id",
    model_name="test_model",
    base_model="ViT-B-32",
    base_checkpoint="laion2b_s34b_b79k",
    model_type=ModelType.OPEN_CLIP,
    max_training_time=600,  # seconds
    instance_type=InstanceType.BASIC,
    hyperparameters={
        "leftKeys": ["query"],
        "rightKeys": ["my_image", "my_text"],
        "leftWeights": [1],
        "rightWeights": [0.9, 0.1],
    },
    wait_for_completion=True,  # wait and poll until the task completes
)

# Train a model with cURL.
curl -X POST 'https://marqtune.marqo.ai/models/train' \
-H 'Content-Type: application/json' \
-H 'x-api-key: {api_key}' \
-d '{
    "datasetId": "dataset_id",
    "modelName": "test_model",
    "baseModel": "ViT-B-32",
    "baseCheckpoint": "laion2b_s34b_b79k",
    "modelType": "open_clip",
    "maxTrainingTime": 600,
    "hyperparameters": {"leftKeys": ["query"], "rightKeys": ["my_image", "my_text"], "leftWeights": [1], "rightWeights": [0.9, 0.1]},
    "instanceType": "marqtune.basic"
}'
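
For callers who use neither py-marqtune nor cURL, the same request can be sketched with the Python standard library. This is an illustration only: `make_train_request` and `API_URL` are hypothetical names, and the request is built but not sent so the sketch runs without network access.

```python
import json
import urllib.request

API_URL = "https://marqtune.marqo.ai/models/train"


def make_train_request(api_key, body):
    """Hypothetical helper: build (but do not send) the POST request."""
    return urllib.request.Request(
        API_URL,
        data=json.dumps(body).encode("utf-8"),
        headers={"Content-Type": "application/json", "x-api-key": api_key},
        method="POST",
    )


request = make_train_request(
    "{api_key}",
    {
        "datasetId": "dataset_id",
        "modelName": "test_model",
        "baseModel": "ViT-B-32",
        "baseCheckpoint": "laion2b_s34b_b79k",
        "modelType": "open_clip",
        "maxTrainingTime": 600,
        "hyperparameters": {
            "leftKeys": ["query"],
            "rightKeys": ["my_image", "my_text"],
            "leftWeights": [1],
            "rightWeights": [0.9, 0.1],
        },
        "instanceType": "marqtune.basic",
    },
)

# To send it, uncomment the lines below:
# with urllib.request.urlopen(request) as response:
#     print(response.status, response.read().decode("utf-8"))
```

Note that `urllib.request.urlopen` raises `urllib.error.HTTPError` for 4xx and 5xx statuses, so the error responses on this page would surface as exceptions rather than as return values.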

Response: 202 Accepted

The training task has been initialised and will now be executed.

{
    "statusCode": 202,
    "body": {
        "modelId": "model_id"
    }
}

Response: 400 (Invalid dataset)

The dataset is not of type 'training'.

{
    "statusCode": 400,
    "body": {
      "message": "Dataset must be of type 'training'"
    }
}

Response: 400 (Invalid base model)

No model with the given base model ID was found.

{
    "statusCode": 400,
    "body": {
      "message": "Model with id {base_model} not found"
    }
}

Response: 400 (Invalid checkpoint)

The checkpoint is not valid. The message lists the available checkpoints.

{
    "statusCode": 400,
    "body": {
      "message": "Invalid checkpoint. Available checkpoints: {checkpoints}"
    }
}

Response: 400 (Invalid hyperparameters)

A left or right key in the hyperparameters is not present in the data schema of the dataset.

{
    "statusCode": 400,
    "body": {
      "message": "Invalid <left|right> key: <hyperparameter key> not found in the data schema"
    }
}

Response: 400 (Invalid weight key)

A weight key in the hyperparameters is not present in the data schema of the dataset.

{
    "statusCode": 400,
    "body": {
      "message": "Invalid weight key: <weight_key> not found in the data schema"
    }
}

Response: 400 (Invalid Request)

Request path or method is invalid.

{
    "statusCode": 400,
    "body": {
      "message": "Invalid request method"
    }
}

Response: 401 (Unauthorised)

Unauthorised. Check your API key and try again.

{
  "message": "Unauthorized."
}

Response: 500 (Internal server error)

Internal server error. Check your API key and try again.

{
  "message": "Internal server error."
}
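
The responses on this page come in two shapes: the 202 and 400 examples wrap their content in "body", while the 401 and 500 examples carry "message" at the top level. A client can normalise both, as in the minimal sketch below. `TrainRequestError` and `parse_train_response` are hypothetical names, and the sketch assumes the HTTP status is available separately from the JSON text.

```python
import json


class TrainRequestError(Exception):
    """Hypothetical error type for any non-202 response."""

    def __init__(self, status, message):
        super().__init__(f"{status}: {message}")
        self.status = status
        self.message = message


def parse_train_response(http_status, raw_body):
    """Return the modelId on 202, otherwise raise TrainRequestError."""
    payload = json.loads(raw_body)
    # Use the wrapped "body" when present, else the top-level object.
    inner = payload.get("body", payload)
    if http_status == 202:
        return inner["modelId"]
    raise TrainRequestError(http_status, inner.get("message", "unknown error"))


# Usage with the sample responses from this page.
model_id = parse_train_response(
    202, '{"statusCode": 202, "body": {"modelId": "model_id"}}'
)
print(model_id)

try:
    parse_train_response(401, '{"message": "Unauthorized."}')
except TrainRequestError as error:
    print(error)
```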