# Training Parameters

## Required Training Parameters

| Name | Type | Description |
| --- | --- | --- |
| `leftKeys` | `list[str]` | List of column names for the left features, typically text fields. |
| `leftWeights` | `list[float]` | List of weights for the left features, one per entry in `leftKeys`. |
| `rightKeys` | `list[str]` | List of column names for the right features. The first field should be an image, and the remaining fields should be text. |
| `rightWeights` | `list[float]` | List of weights for the right features, one per entry in `rightKeys`. |
| `weightKey` | `str` | Name of the CSV column containing per-row sample weights. |
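
To make the table concrete, the sketch below shows how these parameters might be filled in for a dataset whose left side is a text query and whose right side is an image plus a title. The column names (`query`, `image_url`, `title`, `weight`) and the `config` variable are illustrative assumptions, not part of this API.

```python
# Hypothetical example of the required parameters. The CSV column names
# below are assumptions for illustration only.
config = {
    "leftKeys": ["query"],                # text column(s) for the left features
    "leftWeights": [1.0],                 # one weight per entry in leftKeys
    "rightKeys": ["image_url", "title"],  # image column first, then text columns
    "rightWeights": [0.9, 0.1],           # one weight per entry in rightKeys
    "weightKey": "weight",                # CSV column holding per-row sample weights
}
```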

## Optional Training Parameters

| Name | Type | Default value | Description |
| --- | --- | --- | --- |
| `accumFreq` | `int` | `1` | Accumulate gradients and update the model every `accumFreq` steps. |
| `batchSize` | `int` | `256` | Batch size per GPU. |
| `beta1` | `float` | `0.9` | Adam optimizer's beta1 parameter. |
| `beta2` | `float` | `0.98` or `0.999` | Adam optimizer's beta2 parameter. Defaults to `0.98` if the model name contains `vit`, otherwise `0.999`. |
| `cocaCaptionLossWeight` | `float` | `2.0` | Weight assigned to the caption loss in CoCa. |
| `cocaContrastiveLossWeight` | `float` | `1.0` | Weight assigned to the contrastive loss in CoCa. |
| `contextLength` | `int` | `77` | Maximum number of tokens in the input text to train with. |
| `dynamicBatchSize` | `bool` | `True` | Whether to use a dynamic batch size. If `True`, the batch size is adjusted to fit GPU memory. |
| `epochs` | `int` | `5` | Number of epochs to train for. |
| `epochsCooldown` | `int` | `""` | When a scheduler with cooldown is used, perform the cooldown over the last `epochsCooldown` epochs (from epoch `epochs - epochsCooldown` onward). |
| `eps` | `float` | `1.0e-6` or `1.0e-8` | Adam optimizer's epsilon value. Defaults to `1.0e-6` if the model name contains `vit`, otherwise `1.0e-8`. |
| `forceCustomText` | `bool` | `False` | Force use of a custom text model. |
| `forcePatchDropout` | `float` | `""` | Override the patch dropout rate during training. |
| `forceQuickGELU` | `bool` | `False` | Force the use of the QuickGELU activation. |
| `frozenRight` | `bool` | `False` | Whether to freeze the right tower by disabling gradient updates to it. |
| `gatherWithGrad` | `bool` | `True` | Enable full distributed gradients for the feature-gather step. |
| `gradCheckpointing` | `bool` | `True` | Enable gradient checkpointing. |
| `gradClipNorm` | `float` | `""` | Gradient clipping norm. |
| `imageMean` | `list[float]` | `""` | Override the default image mean values. |
| `imageStd` | `list[float]` | `""` | Override the default image standard deviations. |
| `localLoss` | `bool` | `False` | Calculate the loss with local features, avoiding materializing the full global similarity matrix on every GPU. |
| `lockImage` | `bool` | `False` | Lock the image tower by disabling gradients. |
| `lockImageFreezeBnStats` | `bool` | `False` | Freeze BatchNorm running stats in locked image-tower layers. |
| `lockImageUnlockedGroups` | `int` | `0` | Leave the last n image-tower groups unlocked. |
| `lockText` | `bool` | `False` | Lock the text tower by disabling gradients. |
| `lockTextFreezeLayerNorm` | `bool` | `False` | Freeze LayerNorm running stats in locked text-tower layers. |
| `lockTextUnlockedLayers` | `int` | `0` | Leave the last n text-tower layers unlocked. |
| `logitBias` | `float` | `""` | Initial value of the logit bias. |
| `logitScale` | `float` | `""` | Initial value of the logit scale. |
| `lr` | `float` | `5.0e-4` | Learning rate. |
| `lrCooldownPower` | `float` | `1.0` | Power for the polynomial cooldown schedule. |
| `lrScheduler` | `str` | `cosine` | Learning rate scheduler. One of `cosine`, `const`, `const-cooldown`. |
| `poolingMethod` | `str` | `None` | Pooling method; the default is model-dependent. One of `mean`, `cls`, `max`, `cls_last_hidden_state`. |
| `precision` | `str` | `amp` | Floating-point precision to use. One of `amp`, `amp_bf16`, `amp_bfloat16`, `bf16`, `fp16`, `fp32`. |
| `saveFrequency` | `int` | `1` | How often to save checkpoints. |
| `seed` | `int` | `0` | Random seed for reproducible training. |
| `skipScheduler` | `bool` | `False` | Skip the learning rate decay. |
| `sqInterval` | `int` | `2` | Number of regular iterations between same-query iterations. |
| `sqLogitScale` | `float` | `1.5` | Logit scale for the same-query batch. |
| `trace` | `bool` | `False` | Trace the model for inference or evaluation only. |
| `useBnSync` | `bool` | `False` | Whether to use synchronized batch normalization. |
| `valFrequency` | `int` | `1` | How often to run evaluation on validation data. |
| `warmup` | `int` | `10000` | Number of warmup steps. |
| `wd` | `float` | `0.2` | Weight decay. |
| `weightedLoss` | `str` | `ce` | Loss function used for weighted training. One of `ce` or `siglip`. |
| `workers` | `int` | `1` | Number of data loader workers per GPU. |
| `zeroshotFrequency` | `int` | `2` | How often to run zero-shot evaluation. |
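
Optional parameters extend the same configuration; anything omitted keeps the default listed above. A minimal sketch, reusing the hypothetical `config` dictionary from the required-parameters example:

```python
# Hypothetical example of overriding a few optional parameters.
# Omitted parameters fall back to the defaults in the table above.
config.update({
    "batchSize": 512,         # per-GPU batch size
    "epochs": 10,
    "lr": 1.0e-4,
    "lrScheduler": "cosine",  # one of: cosine, const, const-cooldown
    "warmup": 2000,           # warmup steps before decay begins
    "precision": "amp_bf16",  # mixed-precision mode
    "wd": 0.1,                # weight decay
})
```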
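
For reference, `lr`, `warmup`, and `lrScheduler: cosine` together typically describe a linear warmup followed by a cosine decay. The function below is a generic sketch of that schedule, not this trainer's exact implementation:

```python
import math

def cosine_lr(step: int, base_lr: float, warmup: int, total_steps: int) -> float:
    """Generic cosine schedule with linear warmup (a sketch; the trainer's
    exact formula may differ)."""
    if step < warmup:
        # Linear warmup from near 0 up to base_lr over the first `warmup` steps.
        return base_lr * (step + 1) / warmup
    # Cosine decay from base_lr down to 0 over the remaining steps.
    progress = (step - warmup) / max(1, total_steps - warmup)
    return 0.5 * base_lr * (1.0 + math.cos(math.pi * progress))
```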