Uncovering the Triplet Loss Function
What is it? Why/when is it used?
At a high level, the triplet loss is an objective function designed to pull the embeddings of examples from the same class together and to push apart the embeddings of examples from different classes, so that they are separated by at least a certain "margin" distance. Ultimately, with this loss function, it should be possible to distinguish classes by their embeddings alone. Initially and still today, it is most commonly applied to facial recognition, where several pictures of the same person passed through a CNN should produce embeddings within a certain L1 or L2 distance of each other.
After its initial adoption, though, some researchers realized that with the triplet loss they could also force visual models to learn about common actions by pairing together images of the same action taken from scenes with drastically different visual features. This is most notably the case in Google's excellent "Time-Contrastive Networks" paper. Since then, other papers have also broached the use of the triplet loss function in various tasks; however, because the cross-entropy loss is so effective and triplet batch generation is difficult, the triplet loss function remains under-appreciated.
How is the triplet loss function defined?
Mathematical expression:
loss = max(0, margin + distance_func(embed_anchor, embed_positive) - distance_func(embed_anchor, embed_negative)), where:
- margin is typically the desired minimum distance between the normalized layer embeddings of positive and negative examples
- distance_func is normally either the L1 or L2 distance function
- embed_<example_type> are the unit-normalized (i.e. the sum of the squares of all activations equals 1) embeddings from one of the final layers of a CNN
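As a concrete reference, here is a minimal NumPy sketch of that expression (the function and variable names below are illustrative and do not come from any particular library):

import numpy as np

def l2_distance(a, b):
    # Euclidean (L2) distance between two embedding vectors
    return np.sqrt(np.sum((a - b) ** 2))

def triplet_loss(embed_anchor, embed_positive, embed_negative, margin=0.2, distance_func=l2_distance):
    # Hinge on the margin: the loss is zero once the negative is at least
    # 'margin' farther from the anchor than the positive is
    return max(0.0, margin
               + distance_func(embed_anchor, embed_positive)
               - distance_func(embed_anchor, embed_negative))

# Toy usage with random, unit-normalized 128-dimensional embeddings
rng = np.random.default_rng(0)
anchor, positive, negative = (v / np.linalg.norm(v) for v in rng.normal(size=(3, 128)))
print(triplet_loss(anchor, positive, negative, margin=0.2))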
Practical explanation:
Initially, in an untrained model, the distance between the anchor and negative embeddings will be roughly the same as the distance between the anchor and positive embeddings, so the two distance terms in the equation above roughly cancel and the resulting loss value will be close to the margin. As training minimizes the loss over time, the positive examples' embeddings get pulled toward the anchor while the negative examples' embeddings get pushed away, hopefully into their own well-separated cluster in the embedding space.
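As a quick numeric illustration (the distances here are made up, not measured):

margin = 0.2

# Untrained model: the anchor is roughly equidistant from positive and negative
loss_start = max(0.0, margin + 1.00 - 1.02)   # ~0.18, i.e. close to the margin
# Well-trained model: positive pulled close, negative pushed beyond the margin
loss_end = max(0.0, margin + 0.10 - 0.90)     # 0.0, this triplet no longer contributes
print(loss_start, loss_end)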
How to generate triplet batches?
Unlike a normal loss function, where the model input is expected to be a single item, the triplet loss objective function requires model inputs to be groups of three items. Accordingly, in many DL frameworks, you will simply pass the triplet to the model as three sequential input items, where the first element in the sequence is the anchor, the second element is a different item from the same class as the anchor, and the third element is an item from another class. In other words, a batch of 12 triplets will be presented as a batch of 36 input items (i.e. 12 sets of 3 sequential triplet items) to the model.
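Below is a small sketch of how such a flattened batch might be assembled (the helper name and image shape are illustrative, not tied to any specific framework):

import numpy as np

def flatten_triplet_batch(triplets):
    # Turn a list of (anchor, positive, negative) tuples into one sequential
    # input batch laid out as [a0, p0, n0, a1, p1, n1, ...]
    batch = []
    for anchor, positive, negative in triplets:
        batch.extend([anchor, positive, negative])
    return np.stack(batch)

# 12 triplets of 32x32 RGB images become a single (36, 32, 32, 3) input batch
triplets = [tuple(np.zeros((32, 32, 3), dtype=np.float32) for _ in range(3))
            for _ in range(12)]
print(flatten_triplet_batch(triplets).shape)   # (36, 32, 32, 3)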
Triplets can be generated either online or offline. In the offline case, triplet batches are pre-computed and stored on disk to avoid the computational cost of assembling them at training time. This approach can consume a lot of disk space and is normally not optimal for training efficiency or for the final performance of the trained model. More commonly, triplet batches are generated "online" during the training process. This approach normally leverages the current model's weights to ensure that the triplet generator continues to produce a high percentage of triplets that provide a strong error signal to the model.
In the near-universal online case, there are a variety of approaches for selecting triplets. In the most basic one, the generator always yields triplets of the same difficulty. However, this rigid strategy can be problematic. If it always picks easy triplets, the majority of them will quickly become too easy and fail to provide any kind of learning signal to the DNN. Conversely, if the triplets are too hard, learning can become unstable, especially in the early stages.
As an alternative, you can gradually/dynamically increase the difficulty level of the triplets being passed to the model. This is a sort of curriculum approach to generating triplets and generally results in more efficient training because the difficulty of the training examples is adapted to the model's progress. Here are the difficulty levels typically used with this dynamic approach (a selection sketch follows the list):
- Level 1
  - examples where distance(embed_anchor, embed_negative) - distance(embed_anchor, embed_positive) is greater than 0 but less than the full margin value
  - therefore, these triplets will not be so hard that they cause instability in learning; however, they will still be hard enough for the model to get some learning feedback
- Level 2
  - examples where distance(embed_anchor, embed_negative) - distance(embed_anchor, embed_positive) is less than the full margin value
  - these include the triplets in level 1 but also feature harder triplets where the difference between embedding distances can be negative
  - this is a more difficult super-set of level 1
- Level 3
  - examples where distance(embed_anchor, embed_negative) - distance(embed_anchor, embed_positive) is less than 0
  - these are some of the hardest triplets because the negative embeddings are closer than the positive embeddings to the anchor embeddings
- Level 4
  - for each anchor, select the hardest positive (i.e. the furthest positive embedding) and the hardest negative (i.e. the closest negative embedding)
  - these are the absolute hardest possible triplets and should probably only be used toward the middle and end of the training process to realize the best possible final model
  - this selection technique comes directly from the "In Defense of the Triplet Loss" paper
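To make the levels concrete, here is an illustrative (and deliberately unoptimized) NumPy sketch of selecting triplet index tuples at a given level from a set of current-model embeddings; none of these names come from the 'dl_utilities' repo:

import numpy as np

def select_triplets(embeddings, labels, margin, level):
    # Pairwise L2 distances between all (unit-normalized) embeddings
    dists = np.linalg.norm(embeddings[:, None, :] - embeddings[None, :, :], axis=-1)
    n_examples = len(labels)
    triplets = []
    for a in range(n_examples):
        pos_idx = np.where((labels == labels[a]) & (np.arange(n_examples) != a))[0]
        neg_idx = np.where(labels != labels[a])[0]
        if len(pos_idx) == 0 or len(neg_idx) == 0:
            continue
        if level == 4:
            # Level 4: hardest positive (furthest) and hardest negative (closest) per anchor
            triplets.append((a,
                             pos_idx[np.argmax(dists[a, pos_idx])],
                             neg_idx[np.argmin(dists[a, neg_idx])]))
            continue
        for p in pos_idx:
            for n in neg_idx:
                gap = dists[a, n] - dists[a, p]
                if (level == 1 and 0 < gap < margin) or \
                   (level == 2 and gap < margin) or \
                   (level == 3 and gap < 0):
                    triplets.append((a, p, n))
    return triplets

A real online generator would apply this kind of filter to embeddings recomputed with the current model weights every few iterations, rather than to a single static set.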
How to use an implemented generator?
Code (similar to the 'train_triplet_model' function in 'dl_utilites/trip_utils'):
# Assumed import for the optimizer used below (standalone Keras API); all other helpers come from the 'dl_utilities' repo
from keras.optimizers import SGD

# Build and convert model (to one that combines triplet loss and cross-entropy loss)
model = scratch_model_func() # Wrapper function with no parameters to get desired model having
# two layers named 'final_embeddings' and 'predictions', respectively
# opt - cross-entropy loss function optimizer
# margin - value for margin to be used in loss function and in generating triplets
# expected_loss - realistic desired/expected total/final loss (given the percent_trip_loss loss value below)
# total/final loss = (trip_loss_val * percent_trip_loss) + (xentropy_loss_val * (1-percent_trip_loss))
# percent_trip_loss - portion of total/final loss derived by triplet loss value
# is_L2 - use L2 distance function (as opposed to L1 distance) for triplet loss function
model = convert_model_to_trip_model(model, SGD(lr=0.1), margin,
expected_loss, percent_trip_loss,
is_L2)
# Get list of lists where each one contains the image indices for each category
train_images, train_labels = train_data
class_indices = break_down_class_indices(train_labels)
# Start generator for triplet selections
if not trip_worker.is_ready_for_start():
trip_worker.stop_activity()
# Dynamic margin flag - adjusts current margin automatically if it is too hard or too easy
# Warm_up_examples_per_anchor - max # of triplets from an anchor to avoid overuse of a subset of images
trip_worker.start_activity(train_data, class_indices, scratch_model_func,
batch_size, margin, dynamic_margin, is_L2,
image_aug, warm_up_examples_per_anchor,
hardness=2.99, skip_percentage=0.4)
....
# multiply_factor - reduces the number of steps per epoch (and thus training time) on earlier/easier iterations
for hardness in hardness_vals:
hist = model.fit_generator(trip_worker,
steps_per_epoch=
(train_images.shape[0] // (batch_size * multiply_factor)),
epochs=(iter + step_size),
initial_epoch=iter,
callbacks=callbacks)
trip_worker.set_new_hardness_value(hardness) # manually change the difficulty of triplets
...
trip_worker.stop_activity()
Explanation:
The code above is a basic outline of how to correctly use the custom triplet generator object from the 'dl_utilities' repo. The code comes predominantly from the 'train_triplet_model' function in the 'dl_utilites/trip_loss/trip_utils' module.
Before running this code, and in general to perform triplet loss training, you will have to create a triplet generator object (as done in the main training script for my Udacity capstone project). Because TensorFlow cannot/should not be initialized twice in a single process, the main training thread must create the multi-process triplet generator before initializing TensorFlow itself. The main training thread can then re-use that generator (via its 'start_activity' and 'stop_activity' member functions) for every model it trains afterwards.
In the code example here, a normal cross-entropy model is created and then converted into a model that factors the triplet loss into its objective function (using the 'convert_model_to_trip_model' function). After that, the triplet generator is started and passed to the training function (e.g. the Keras 'fit_generator' call). As training progresses, the difficulty of the triplet batches produced by the 'trip_worker' generator is gradually increased, stepping through each of the difficulty levels (from above) via the 'set_new_hardness_value' member function.