Accelerating AI: Implementing Multi-GPU Distributed Training for Personalized Recommendations

Susrutha Gongalla
- New York, NY

Stitch Fix uses a cutting-edge multi-tiered recommender system stack to personalize styling recommendations at scale. This stack comprises several critical components, including feature generation, scoring, ranking, and inventory optimization techniques.

Our scoring module is based on the Client Time Series Model (CTSM), an award-winning, novel sequence-based model that uses temporally masked encoders. CTSM is built using PyTorch and was initially trained on a single Graphics Processing Unit (GPU) instance. Since we first put this model into production last year, we have launched several updates that improved its performance. Many of these improvements involved adding new features or increasing the time window of our training data. As a result, model training time increased significantly, making it harder for us to iterate quickly and get feedback on new ideas for improving the model. We needed a way to reduce training time.

This blog post describes the steps we took to overcome this challenge and our journey to implement multi-GPU distributed training for CTSM. By sharding the training data across multiple GPUs and training on multiple mini-batches in parallel, we aimed to significantly reduce training time. We present empirical results showing the reduction in training time we observed as we scaled up from 1 to N GPUs, and share some future directions we are considering in our continued effort to speed up model training.

Model Training Workflow

The scores generated by CTSM are used by multiple downstream services to gain insight into which items a client is likely to purchase. The model is retrained at a regular cadence so that it uses the most up-to-date information about each client when making predictions and its performance does not degrade. We use configuration-driven machine learning pipelines to set up a Directed Acyclic Graph (DAG) that automatically retrains the model, checks whether the model performance is above a set threshold, and deploys the new model to production. Below is a high-level overview of the entire workflow:

Diagram illustrating the training workflow of CTSM
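The retrain, evaluate, and deploy gate in this DAG can be sketched in a few lines of Python. This is a minimal illustration under stated assumptions, not our pipeline code: the function `retrain_and_maybe_deploy`, the stand-in callables, the metric value, and the assumption that a higher metric is better are all hypothetical.

```python
from dataclasses import dataclass
from typing import Callable, List


@dataclass
class PipelineResult:
    metric: float
    deployed: bool


def retrain_and_maybe_deploy(
    train: Callable[[], object],
    evaluate: Callable[[object], float],
    deploy: Callable[[object], None],
    threshold: float,
) -> PipelineResult:
    """Retrain, evaluate, and deploy only if the metric is above the threshold.

    Assumes a higher metric is better. If the check fails, nothing is deployed
    and the model currently in production keeps serving.
    """
    model = train()
    metric = evaluate(model)
    if metric > threshold:
        deploy(model)
        return PipelineResult(metric=metric, deployed=True)
    return PipelineResult(metric=metric, deployed=False)


# Hypothetical stand-ins: a real DAG would launch a training job, run an
# offline evaluation, and push the model artifact to production.
deployed_models: List[str] = []
result = retrain_and_maybe_deploy(
    train=lambda: "ctsm-candidate",
    evaluate=lambda model: 0.82,
    deploy=deployed_models.append,
    threshold=0.80,
)
print(result)  # PipelineResult(metric=0.82, deployed=True)
```

Keeping the gate as a separate, pure step makes it easy to test the deploy decision independently of the (expensive) training job.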

Distributed Model Training for CTSM

PyTorch Lightning Trainer

As we worked to parallelize model training, we had to make an important design decision: our code needed to be device-agnostic and scalable across multiple GPUs. Our goal was to minimize maintenance overhead for our team by keeping the multi-processing code needed for multi-GPU training as small as possible. After considering various options, we chose to adopt the PyTorch Lightning framework for model training.

PyTorch Lightning is an open-source framework that provides a higher-level interface to PyTorch, making it easier to write code for training deep learning models. It includes built-in support for distributed model training using different strategies such as Data Parallel, Distributed Data Parallel, and more. By adopting this framework, we were able to reduce the amount of boilerplate training code we had to write and streamline the code for training CTSM on multiple GPUs.

However, it’s important to note that there are potential drawbacks to adopting this framework. For example, it requires your training code, such as the forward pass and optimization logic, to be structured in a specific way (in our case, we needed to refactor the CTSM code to adopt it). It can also be challenging to customize the training loop or implement novel training strategies that Lightning does not already support. Still, we found that the benefits of PyTorch Lightning outweighed these drawbacks for our project. Ultimately, the decision to use PyTorch Lightning should be based on a careful evaluation of the specific use case and the available resources and expertise.

Data Parallel (DP) vs. Distributed Data Parallel (DDP)

Data Parallel and Distributed Data Parallel are the two most commonly used strategies for distributed model training with data sharding. There are key differences between these two approaches that we discuss below. Both strategies use common elements such as a data loader that is responsible for generating mini-batches to train the model, and the model itself with parameters to train using this training data.

Data Parallel (DP)

In the Data Parallel approach, the data loader and model are initialized as part of a base module, and we assume that we have N GPUs available for model training. During each training step, DP performs the following operations:

  1. The mini-batch is split into N segments of roughly equal size, and each segment is copied over to its corresponding GPU.
  2. The model object is replicated and copied over to each GPU.
  3. Each GPU executes a forward pass using its chunk of the mini-batch and backpropagation to compute gradients.
  4. These gradients are accumulated across all the GPUs and used to update the model object in the base module.

These steps are repeated for training each mini-batch until a set stopping criterion, such as the maximum number of epochs to train on, is reached. Steps 1 and 2 are both copy operations, whereas step 4 is a synchronization operation that gathers gradients computed across all the GPUs for each parameter.

Given below is a visual representation of the steps executed by DP (steps in blue represent the operations performed during each training step).

Diagram illustrating the steps executed when using Data Parallel strategy for model training
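To make the four DP steps concrete, here is a minimal CPU-only simulation in plain Python. It assumes a toy one-parameter linear model (`y_hat = w * x`) with a mean squared error loss instead of CTSM, and uses lists to stand in for GPUs; the helper names (`grad_sum`, `split`, `dp_step`) are hypothetical. In PyTorch, this scatter, replicate, and gather logic is handled by `torch.nn.DataParallel`.

```python
from typing import List, Sequence


def grad_sum(w: float, xs: Sequence[float], ys: Sequence[float]) -> float:
    """Sum over samples of d/dw (w * x - y) ** 2 for the toy model y_hat = w * x."""
    return sum(2.0 * (w * x - y) * x for x, y in zip(xs, ys))


def split(seq: Sequence[float], n: int) -> List[List[float]]:
    """Split a mini-batch into n contiguous chunks of roughly equal size."""
    size, extra = divmod(len(seq), n)
    chunks, start = [], 0
    for i in range(n):
        end = start + size + (1 if i < extra else 0)
        chunks.append(list(seq[start:end]))
        start = end
    return chunks


def dp_step(w: float, xs: Sequence[float], ys: Sequence[float],
            n_gpus: int, lr: float) -> float:
    """One Data Parallel training step, simulated sequentially on the CPU."""
    # Step 1: scatter the mini-batch into one chunk per GPU.
    x_chunks, y_chunks = split(xs, n_gpus), split(ys, n_gpus)
    # Step 2: replicate the base model's parameter onto every GPU.
    replicas = [w] * n_gpus
    # Step 3: each GPU runs forward + backward on its own chunk.
    partial = [grad_sum(r, xc, yc)
               for r, xc, yc in zip(replicas, x_chunks, y_chunks)]
    # Step 4: gather the gradients onto the base module and update it there.
    grad = sum(partial) / len(xs)  # gradient of the mean loss over the mini-batch
    return w - lr * grad


xs = [1.0, 2.0, 3.0, 4.0, 5.0]
ys = [2.0 * x for x in xs]  # toy targets: the true weight is 2.0
w = dp_step(0.0, xs, ys, n_gpus=2, lr=0.01)
print(round(w, 4))  # 0.44
```

Because the gathered gradient is normalized by the full mini-batch size, one DP step gives the same update as a single-device step on the whole mini-batch; what DP adds is the cost of re-copying the chunks and the model on every step.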

Distributed Data Parallel (DDP)

DDP takes a slightly different approach. Similar to DP, the data loader and model are initialized as part of a base module, but it performs an additional step where it copies over the initialized model to each GPU. This model copy operation is only done once during initialization. DDP performs the following operations during each training step:

  1. Generate N new mini-batches (one for each GPU), and copy over each mini-batch to its corresponding GPU.
  2. Each GPU executes a forward pass on its own mini-batch and backpropagation to compute gradients.
  3. The gradients are averaged across all GPUs, and the averaged gradients are sent to every GPU.
  4. Each GPU uses the averaged gradients it receives to update the parameters of its local copy of the model.

DDP is significantly faster than DP because each training step involves only one copy operation (step 1); the model is not re-replicated on every step. Step 3 is a communication operation (an all-reduce) that waits for all GPUs to compute gradients for each parameter and returns the averaged gradients to every GPU. Step 4 keeps the model identical across all GPUs, since they all apply the same averaged gradients to update their parameters.

Given below is a visual representation of the steps performed by DDP (steps in blue represent operations performed in each training step).

Diagram illustrating the steps executed when using Distributed Data Parallel strategy for model training
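A toy one-parameter linear model can also illustrate DDP. In this plain-Python sketch, the helper names are hypothetical and a simple average stands in for the all-reduce that `torch.nn.parallel.DistributedDataParallel` performs across processes. The model is copied once, each replica computes gradients on its own mini-batch, and every replica applies the same averaged gradient, so the copies stay identical without being re-copied each step.

```python
from typing import List, Sequence, Tuple

Batch = Tuple[Sequence[float], Sequence[float]]  # (inputs, targets)


def local_grad(w: float, batch: Batch) -> float:
    """Gradient of the mean loss (w * x - y) ** 2 on one replica's mini-batch."""
    xs, ys = batch
    return sum(2.0 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def all_reduce_mean(values: List[float]) -> List[float]:
    """Average a value across replicas and give every replica the result."""
    avg = sum(values) / len(values)
    return [avg] * len(values)


def ddp_step(replicas: List[float], batches: List[Batch], lr: float) -> List[float]:
    """One DDP training step across len(replicas) simulated GPUs."""
    # Step 2: each replica runs forward + backward on its own mini-batch.
    grads = [local_grad(w, b) for w, b in zip(replicas, batches)]
    # Step 3: average the gradients and send the result to every replica.
    avg_grads = all_reduce_mean(grads)
    # Step 4: each replica updates its local copy with the same averaged gradient.
    return [w - lr * g for w, g in zip(replicas, avg_grads)]


# Initialization: the model is copied to every GPU exactly once.
replicas = [0.0, 0.0]
# Step 1: each replica gets its own mini-batch (reused here for simplicity).
batches = [([1.0, 2.0], [2.0, 4.0]), ([3.0, 4.0], [6.0, 8.0])]
for _ in range(3):
    replicas = ddp_step(replicas, batches, lr=0.01)
print(replicas[0] == replicas[1])  # True
```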

Based on our understanding of how DP and DDP parallelize model training, we expected DDP to be faster and more efficient for distributed model training. PyTorch documentation also recommends DDP over DP: DP uses multi-threading within a single process, which can suffer from performance issues due to Python’s Global Interpreter Lock, whereas DDP uses multi-processing, with a separate process for each GPU.

Therefore, we adopted the DDP strategy for training CTSM on multiple GPUs using the PyTorch Lightning trainer, as discussed above.

DDP training of CTSM

Executing the model training job on multiple GPUs

We use the Flotilla framework to submit containerized jobs and execute CTSM training on our EKS cluster. Because each job specifies resource requests for CPU, memory, and the number of GPUs it needs, an EC2 instance with multiple GPUs can be shared with other model training jobs. This lets us make the most of each EC2 instance in the cluster, which leads to better resource utilization and significant cost savings.

To test the scalability of CTSM training with DDP, we trained using 1 to 8 GPUs on a single g5.48xlarge instance. As we increased the number of GPUs, training time per epoch went down as expected. The plot below shows the average training time per epoch relative to the training time on 1 GPU (e.g., training on 2 GPUs took 0.55 times as long as training on 1 GPU). Comparing these observed scaling factors with the ideal ones, we see that DDP does exceptionally well at parallelizing model training and comes very close to the ideal numbers as we scale up the number of GPUs.

Training time benchmarked against that on 1 GPU
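Scaling factors like the ones in this plot can be turned into speedup and parallel efficiency with a little arithmetic. The sketch below assumes the plotted value for N GPUs is the per-epoch time on N GPUs divided by the per-epoch time on 1 GPU; only the 2-GPU value (0.55) is taken from the text, and the function name `scaling_report` is hypothetical.

```python
from typing import Dict


def scaling_report(relative_time: Dict[int, float]) -> Dict[int, Dict[str, float]]:
    """Convert per-epoch time relative to 1 GPU into speedup and efficiency.

    relative_time[n] = (epoch time on n GPUs) / (epoch time on 1 GPU).
    Ideal linear scaling would give relative_time[n] == 1 / n.
    """
    report = {}
    for n, rel in sorted(relative_time.items()):
        speedup = 1.0 / rel
        report[n] = {
            "ideal_relative_time": 1.0 / n,
            "speedup": speedup,
            "efficiency": speedup / n,  # 1.0 would be perfect linear scaling
        }
    return report


# Only the 2-GPU value (0.55) comes from the benchmark described above.
for n, row in scaling_report({1: 1.0, 2: 0.55}).items():
    print(n, round(row["speedup"], 2), round(row["efficiency"], 2))
# 1 1.0 1.0
# 2 1.82 0.91
```

An efficiency of about 0.91 on 2 GPUs means roughly 91% of the second GPU's capacity translates into reduced epoch time.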

As the number of GPUs used for training increases, so does the communication overhead within DDP. This overhead comes from the need to synchronize and average gradients across all GPUs involved in training. The following plot shows the average share of each epoch consumed by communication overhead. For instance, when training CTSM on 2 GPUs, approximately 9% of the training time is spent synchronizing gradients between the 2 GPUs. This overhead grows to nearly 23% with 8 GPUs, as coordinating across more GPUs takes more time.

Communication overhead represented as percentage of total training time per number of GPUs
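A back-of-the-envelope split of the 2-GPU epoch time shows how the communication overhead relates to the gap from ideal scaling. This sketch assumes the reported overhead is time that does not overlap with computation; in practice DDP overlaps some gradient communication with the backward pass, so treat it as a rough model. `split_epoch_time` is a hypothetical helper, and the inputs (0.55x and ~9%) are the figures quoted above.

```python
from typing import Tuple


def split_epoch_time(relative_time: float, comm_fraction: float) -> Tuple[float, float]:
    """Split a relative epoch time into compute and communication parts.

    Assumes comm_fraction is the share of the epoch spent waiting on gradient
    synchronization, i.e. time that does not overlap with computation.
    """
    comm = relative_time * comm_fraction
    compute = relative_time - comm
    return compute, comm


# 2 GPUs: 0.55x the 1-GPU epoch time, with ~9% of it spent on communication.
compute, comm = split_epoch_time(0.55, 0.09)
print(round(compute, 4), round(comm, 4))  # 0.5005 0.0495
```

Under this simple model the compute portion (about 0.50x) is essentially the ideal 1/2, which suggests that most of the remaining gap from perfect scaling on 2 GPUs is communication.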

Given that model training costs increase as the number of GPUs used for training increases, we carefully considered the number of GPUs we wanted to use. After weighing the trade-offs between training time and cost, we decided to train the model across 2 GPUs. This increased our training costs to ~2 times the cost of training on a single GPU, but it reduced the model training time to 0.55x that of the training time on a single GPU. We believe that this reduced training time will allow us to iterate faster and improve our key business metrics, which will outweigh the additional training costs incurred.

Model performance monitoring

We use PyTorch’s TensorBoard utilities to monitor model performance and the distributions of model weights and biases along with their gradients. The graphs below show the loss computed per mini-batch over the course of training for CTSM trained on 2 GPUs, with orange and blue representing the two GPUs. The x-axis shows the number of training steps, and the y-axis shows the corresponding loss on training and validation data (labeled ‘training_loss’ and ‘test_loss’ respectively in the graphs below). Because each GPU trains on different mini-batches, the exact loss values vary slightly, but after smoothing, both GPUs follow a similar trend in which the loss decreases over time.

Plot showing the progress of loss on training data as we train the model on more epochs
Plot showing the progress of loss on test data as we train the model on more epochs
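The smoothing applied to these curves is an exponential moving average over the logged loss values. The sketch below implements a debiased EMA, which to our understanding is the scheme TensorBoard's smoothing slider uses; the function name, the default weight, and the toy loss values are illustrative assumptions, not our logged data.

```python
from typing import List, Sequence


def ema_smooth(values: Sequence[float], weight: float = 0.6) -> List[float]:
    """Debiased exponential moving average; weight must be in [0, 1).

    A higher weight gives a smoother curve.
    """
    smoothed, last = [], 0.0
    for step, value in enumerate(values, start=1):
        last = weight * last + (1.0 - weight) * value
        # Divide out the bias toward the 0.0 starting value in early steps.
        smoothed.append(last / (1.0 - weight ** step))
    return smoothed


# Made-up per-step losses for two GPUs that see different mini-batches.
gpu0 = [1.00, 0.90, 0.95, 0.70, 0.75, 0.60]
gpu1 = [1.05, 0.85, 0.90, 0.75, 0.65, 0.62]
for name, losses in (("gpu0", gpu0), ("gpu1", gpu1)):
    print(name, [round(v, 3) for v in ema_smooth(losses)])
```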

Productionizing the model

We use a backtesting process to assess CTSM performance on a held-out validation dataset. When we refactored CTSM to use DDP, we observed a slight drop in backtesting performance caused by a change in the effective batch size. When training on one GPU, the optimizer updated the weights and biases after each mini-batch. With two GPUs under DDP, each update uses the average of the gradients from two mini-batches (one per GPU). The optimizer is therefore called only once per two mini-batches, which effectively doubles the batch size used for training. Because batch size is an important hyperparameter that requires careful tuning, this change hurt model performance.
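The effective-batch-size effect can be checked numerically. In this sketch (a toy linear model with a mean squared error loss; the data and the `mean_grad` helper are hypothetical), averaging the gradients of two equal-sized per-GPU mini-batches gives the same gradient as one mini-batch twice as large, so each optimizer step under DDP on 2 GPUs sees twice as many samples.

```python
from typing import Sequence


def mean_grad(w: float, xs: Sequence[float], ys: Sequence[float]) -> float:
    """Gradient of the mean squared error of y_hat = w * x over one batch."""
    return sum(2.0 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


w = 0.5
# Two GPUs, each holding its own mini-batch of size B = 2 (toy data).
batch_a = ([1.0, 2.0], [2.1, 3.9])
batch_b = ([3.0, 4.0], [6.2, 7.8])

# DDP: average the per-GPU gradients, then take ONE optimizer step.
ddp_grad = (mean_grad(w, *batch_a) + mean_grad(w, *batch_b)) / 2

# Single GPU on one combined mini-batch of size 2 * B.
combined = (batch_a[0] + batch_b[0], batch_a[1] + batch_b[1])
big_batch_grad = mean_grad(w, *combined)

print(abs(ddp_grad - big_batch_grad) < 1e-9)  # True
```

The equivalence relies on equal per-GPU batch sizes; with uneven batches, a plain average of per-GPU means would weight samples differently.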

To address this, we ran hyperparameter optimization and tweaked the learning rate scheduler to recover model performance. We then trained the model using DDP on 2 GPUs with the new hyperparameters and verified through backtesting that the change did not hurt model performance. Next, we launched a do-no-harm online experiment to further test the hypothesis that the change did not negatively affect any of our business metrics. The experiment was successful, and we launched this model in production for all our traffic earlier this year.
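We arrived at our new hyperparameters through tuning, but a common starting point when the effective batch size grows by a factor of k is the linear scaling rule: multiply the learning rate by k, usually with a short warmup. The sketch below shows that heuristic, not the hyperparameters we used for CTSM; the function names and example values are hypothetical, and any result still needs validation, for example through backtesting.

```python
def scaled_lr(base_lr: float, base_batch: int, per_gpu_batch: int, n_gpus: int) -> float:
    """Linear scaling rule: grow the learning rate with the effective batch size."""
    effective_batch = per_gpu_batch * n_gpus  # samples per optimizer step under DDP
    return base_lr * effective_batch / base_batch


def warmup_lr(step: int, target_lr: float, warmup_steps: int) -> float:
    """Ramp the learning rate linearly up to target_lr over warmup_steps steps."""
    if warmup_steps <= 0 or step >= warmup_steps:
        return target_lr
    return target_lr * (step + 1) / warmup_steps


# Hypothetical values: a learning rate tuned for batch 256 on 1 GPU,
# now training with 256 samples per GPU on 2 GPUs.
target = scaled_lr(base_lr=1e-3, base_batch=256, per_gpu_batch=256, n_gpus=2)
print(target)  # 0.002
print([warmup_lr(s, target, warmup_steps=4) for s in range(5)])
```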

Learnings and Next Steps

Clear design requirements, particularly the ability to scale up training across more GPUs while reducing maintenance overhead on the team, have already paid off. We recently launched an experiment that requires training the model on significantly more data. Our design enabled us to easily scale up training from 2 to 4 GPUs with minimal code changes.

While we’ve only utilized multiple GPUs on a single physical instance so far, we recognize that DDP can also parallelize training across multiple physical instances, which would require setting up a GPU cluster or additional infrastructure provisioning. We plan to explore this option in the future as we increase our training data and potentially require parallelization beyond a single EC2 instance.
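When training spans multiple instances, each DDP process is identified by its rank. Launchers such as `torchrun` expose `RANK`, `LOCAL_RANK`, and `WORLD_SIZE` environment variables to each process; the sketch below shows the conventional layout, assuming one process per GPU and the same number of GPUs on every instance. `rank_layout` and the 2 × 4 configuration are hypothetical.

```python
from typing import List, Tuple


def rank_layout(n_nodes: int, gpus_per_node: int) -> List[Tuple[int, int, int]]:
    """Return (node_rank, local_rank, global_rank) for every training process.

    Assumes one process per GPU and the same number of GPUs on every node.
    """
    world_size = n_nodes * gpus_per_node
    layout = []
    for node_rank in range(n_nodes):
        for local_rank in range(gpus_per_node):
            # Global ranks are numbered node by node.
            global_rank = node_rank * gpus_per_node + local_rank
            layout.append((node_rank, local_rank, global_rank))
    assert len(layout) == world_size
    return layout


# Two hypothetical 4-GPU instances -> WORLD_SIZE = 8 processes.
for node_rank, local_rank, global_rank in rank_layout(2, 4):
    print(f"node {node_rank}: LOCAL_RANK={local_rank}, RANK={global_rank}")
```

`LOCAL_RANK` selects the GPU within an instance, while the global `RANK` identifies the process in the collective communication group that performs the gradient all-reduce.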

Currently, we’re using PyTorch 1.x, but we’re considering migrating to PyTorch 2.0 to pre-compile the model graph before training. While this has the potential to reduce training time, early experiments with PyTorch 2.0 have shown mixed results in terms of potential speedup for our training jobs. This is because pre-compiling the graph prevents DDP from applying optimizations in its communication step that make it exceptionally fast. This issue is explained in detail in this informative blog.

Additionally, we’re exploring automatic mixed precision to switch to float16 precision wherever possible without degrading model performance (e.g., in fully connected layers).
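Part of the reason mixed precision is applied only where it does not hurt model performance is float16's limited precision and range. The sketch below uses Python's `struct` half-precision format (`'e'`) to round values to float16 and back. PyTorch's AMP (`torch.autocast`) runs only selected ops in float16, and `torch.cuda.amp.GradScaler` scales the loss so that small gradients do not underflow; the helper `to_float16` is hypothetical.

```python
import struct


def to_float16(x: float) -> float:
    """Round a Python float to IEEE 754 half precision and back."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


# float16 has a 10-bit mantissa: a tiny update to a weight near 1.0 vanishes.
print(to_float16(1.0 + 2**-12))  # 1.0
# Very small values (e.g. gradients) underflow to zero; the smallest
# positive float16 value is 2**-24 (about 6e-8).
print(to_float16(1e-8))  # 0.0
# Values that the format can represent survive unchanged.
print(to_float16(0.5))  # 0.5
```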

Overall, we’ll continue to explore various techniques to speed up model training as we increase our training data or network complexity. This will also allow us to maintain our developer velocity, retrain the model at a regular cadence, and ensure that we are using the most updated model artifact for making personalized recommendations to our customers.

