We are excited to announce Diamond, an open-source Python solver for certain kinds of generalized linear models. This post covers the specifics of Diamond; its sister post covers the mathematics behind it. If you just want to use the package, check out the GitHub page.
Announcing Diamond: Scaling Mixed-Effects Models
Multilevel/hierarchical modeling is an important tool in a data scientist’s toolkit. At Stitch Fix, we often build models to predict binary outcomes, such as whether a particular client will love the fit of a shirt. There are many ways to approach this problem; one of the simplest is logistic regression. The features in the model might include the brand of the shirt, the client’s height, and the shirt’s material. But what if there are 10,000 brands, some observed thousands of times and others only a handful of times? A plain dummy-variable encoding estimates each brand’s effect separately, so it throws away information when the fit of different brands is correlated: a rarely observed brand gets a noisy estimate even if similar brands have plenty of data. One approach we like is a Generalized Linear Mixed-effects Model (GLMM), which is particularly well suited to grouped data. In this example, the groups are clothing brands.
GLMMs let you pool information across brands while still learning an individual effect for each brand. They split the model into fixed and random effects. The fixed effects are similar to what you would find in a traditional logistic regression model, while the random effects allow the regression relationship to vary by brand. One advantage of GLMMs is that they learn how different brands are from each other. Brands that are very similar to the overall average will have small random effect estimates. Because the random effects are regularized, brands with few observations will also have small random effect estimates and be treated more like the overall average. In contrast, brands that differ substantially from the average, with lots of data to support that difference, will get large random effect estimates.
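The shrinkage described above is easiest to see in the simplest case: a Gaussian (not logistic) outcome with one random intercept per brand and known variances. There, a brand’s estimated effect is its raw deviation from the overall mean multiplied by \(n / (n + \sigma^2 / \tau^2)\), where \(n\) is the brand’s number of observations, \(\sigma^2\) the within-brand variance, and \(\tau^2\) the between-brand variance. The sketch below assumes that simplified model; the function name and all numbers are hypothetical.

```python
def shrunken_effect(brand_mean: float, n_obs: int, overall_mean: float,
                    sigma2: float, tau2: float) -> float:
    """Posterior mean of a brand's random intercept in a Gaussian
    random-intercept model with known variances (illustrative assumption;
    the model in this post is logistic).

    sigma2: residual (within-brand) variance
    tau2:   between-brand variance of the random intercepts
    """
    # Shrinkage weight: near 1 with lots of data, near 0 with little data.
    weight = n_obs / (n_obs + sigma2 / tau2)
    return weight * (brand_mean - overall_mean)


# Two hypothetical brands with the same raw deviation (+0.5) from the mean.
overall = 3.0
sigma2, tau2 = 1.0, 0.1
print(shrunken_effect(3.5, 5, overall, sigma2, tau2))     # ≈ 0.167 (few observations)
print(shrunken_effect(3.5, 5000, overall, sigma2, tau2))  # ≈ 0.499 (many observations)
```

Both brands have the same raw deviation, but the sparsely observed one is pulled most of the way back toward the average.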
To fit this GLMM, we work with an R-style formula that looks something like this:
loved_fit ~ 1 + height + (1 | client_id) + (1 + height | brand)
We would fit this model on (client, item) pairs, where the client rated the fit of the item, and we know various things about each client (including her height) and each item (its brand). In this model formula, the fixed effects measure
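- 1: the average fit rating
- height: the average relationship between fit and height. This could reveal patterns such as taller clients giving lower average fit feedback than shorter clients. (Because this term is linear, capturing a pattern where both very short and very tall clients give lower feedback than clients of average height would require an additional term, such as squared height.)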
The first random effects term, (1 | client_id), captures the variation between clients: some clients give consistently high feedback, while others give consistently low feedback. This is known as “user bias” in some machine learning contexts, and it’s important to account for it.
The second random effects term, (1 + height | brand), allows the relationship between height and fit to vary by brand. The random intercept here captures average feedback by brand: just like clients, some brands get consistently higher feedback than others. The random slope, the height part of (1 + height | brand), allows us to learn patterns like
- brand X fits tall clients better than short clients (positive random slope)
- brand Y fits short clients better than tall clients (negative random slope)
- brand Z fits clients of all heights equally well (random slope of 0)
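To make the formula concrete, here is a minimal sketch of how its terms combine into a predicted probability for one (client, item) pair, assuming a logistic link and standardized height. The function `predicted_love_prob`, the coefficient dictionaries, and every value in them are hypothetical illustrations, not Diamond’s API or fitted estimates.

```python
import math


def predicted_love_prob(height, client_id, brand, fixed, client_re, brand_re):
    """P(loved_fit = 1) for one (client, item) pair under
    loved_fit ~ 1 + height + (1 | client_id) + (1 + height | brand)
    with a logistic link. Unseen clients/brands fall back to a zero effect."""
    # Fixed effects: shared intercept and shared height slope.
    eta = fixed["intercept"] + fixed["height"] * height
    # (1 | client_id): a client-specific intercept ("user bias").
    eta += client_re.get(client_id, 0.0)
    # (1 + height | brand): a brand-specific intercept and height slope.
    b_int, b_slope = brand_re.get(brand, (0.0, 0.0))
    eta += b_int + b_slope * height
    return 1.0 / (1.0 + math.exp(-eta))


# Hypothetical coefficients; height is standardized (0 = average height).
fixed = {"intercept": 0.5, "height": 0.0}
client_re = {"c1": 0.3}
brand_re = {
    "X": (0.0, 0.8),   # positive slope: fits tall clients better
    "Y": (0.0, -0.8),  # negative slope: fits short clients better
    "Z": (0.2, 0.0),   # zero slope: fits all heights equally well
}

for b in ("X", "Y", "Z"):
    # A tall client, 1.5 standard deviations above average height.
    print(b, round(predicted_love_prob(1.5, "c1", b, fixed, client_re, brand_re), 3))
# prints X 0.881, Y 0.401, Z 0.731 (one per line)
```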
Fitting a GLMM is generally a two-step procedure:
- First, the covariance matrix for each random effects term is estimated. In our example, the covariance matrix for the first term is simply the variance of the client bias terms, a scalar. The covariance matrix of the second term is 2×2: it measures how much brands vary in their intercepts, how much they vary in their height slopes, and how those two effects covary.
- Second, these covariance matrices are used to compute estimates for all of the model coefficients, fixed and random alike.
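For the example formula, the covariance matrices estimated in the first step have the shapes sketched below. The variance components are hypothetical and NumPy is assumed; the sketch builds the scalar client variance and the 2×2 brand covariance from standard deviations and a correlation, then draws brand (intercept, slope) pairs from \(N(0, \Sigma_{\text{brand}})\) to show what the matrix implies.

```python
import numpy as np

# Hypothetical variance components (not estimates from real data).
sd_client = 0.5                          # SD of client intercepts: (1 | client_id)
sd_brand_int, sd_brand_slope = 0.7, 0.3  # SDs for (1 + height | brand)
rho = -0.4                               # correlation of brand intercept and slope

# First term: a 1x1 "matrix", i.e. just the variance of client biases.
cov_client = np.array([[sd_client ** 2]])

# Second term: 2x2, with intercept and slope variances on the diagonal
# and their covariance off the diagonal.
cov_brand = np.array([
    [sd_brand_int ** 2, rho * sd_brand_int * sd_brand_slope],
    [rho * sd_brand_int * sd_brand_slope, sd_brand_slope ** 2],
])

# Draw (intercept, slope) pairs for 1,000 hypothetical brands.
rng = np.random.default_rng(0)
brand_effects = rng.multivariate_normal(mean=[0.0, 0.0], cov=cov_brand, size=1000)
print(brand_effects.shape)                  # (1000, 2)
print(np.cov(brand_effects, rowvar=False))  # close to cov_brand
```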
These two steps are repeated iteratively. A number of packages in various languages perform both steps. In R, the ones we have used include LME4 and mbest (which uses a method of moments approach). In Python, statsmodels implements linear mixed effects models. LME4 has also been ported to Julia.
Our journey with GLMMs started with the R package LME4, which can fit models with complicated random effects structures using maximum likelihood methods. This worked well until we wanted to fit more complex models on larger datasets, and fitting them started to take hours or even days. At that point, we switched to mbest, which uses a method of moments approximation to speed up model fitting. At the time of this blog post, however, mbest can only fit nested mixed effects models, whereas client and brand effects like the ones above are crossed. We needed a solution that would scale with both model complexity and data volume. We considered downsampling our data, simplifying our models, using LinkedIn’s Spark-based Photon ML, and abandoning GLMMs entirely in favor of simple \(L_2\)-regularized logistic regression. We weren’t happy with any of these options, so we developed Diamond.
Our solution, which we are open-sourcing as the Python package Diamond, is to compute the covariance matrix once, fix it, and then solve the sub-problem of estimating the model coefficients using the fixed covariance matrix. This results in a convex problem that is amenable to the toolkit of convex optimization. The sub-problem can be recast as an \(L_2\)-regularized logistic regression problem with a matrix of penalties rather than the usual scalar \(\lambda\). See the sister post for the derivation.
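The fixed-covariance sub-problem can be sketched as follows. Under the usual assumption that each brand’s random effects \(b\) are Gaussian with covariance \(\Sigma\), the penalty on that brand’s (intercept, slope) pair is \(\tfrac{1}{2} b^T \Sigma^{-1} b\) and the fixed effects are unpenalized, so the objective over all coefficients \(\beta\) is \(-\log L(\beta) + \tfrac{1}{2}\beta^T \Lambda \beta\) with a block-diagonal penalty matrix \(\Lambda\). This is a dense NumPy Newton’s-method sketch on small simulated data (client effects omitted); `fit_penalized_logistic` is a hypothetical name, and the code is not Diamond’s implementation or API.

```python
import numpy as np


def fit_penalized_logistic(X, y, penalty, n_iter=25, tol=1e-10):
    """Minimize -loglik(beta) + 0.5 * beta @ penalty @ beta by Newton's method.

    X: (n, p) design matrix; y: (n,) array of 0/1 labels;
    penalty: (p, p) symmetric PSD matrix (zero rows/cols = unpenalized).
    """
    beta = np.zeros(X.shape[1])
    for _ in range(n_iter):
        p = 1.0 / (1.0 + np.exp(-X @ beta))       # fitted probabilities
        grad = X.T @ (p - y) + penalty @ beta     # gradient of the objective
        w = p * (1.0 - p)                         # logistic weights
        hess = X.T @ (X * w[:, None]) + penalty   # Hessian: X^T W X + penalty
        step = np.linalg.solve(hess, grad)
        beta -= step
        if np.max(np.abs(step)) < tol:
            break
    return beta


rng = np.random.default_rng(42)
n, n_brands = 5000, 20
height = rng.standard_normal(n)          # standardized height
brand = rng.integers(n_brands, size=n)   # brand index for each row

# Hypothetical "known" covariance for (1 + height | brand).
cov_brand = np.array([[0.5, 0.05], [0.05, 0.2]])

# Simulate labels from loved_fit ~ 1 + height + (1 + height | brand).
true_brand = rng.multivariate_normal([0.0, 0.0], cov_brand, size=n_brands)
eta = 0.3 - 0.2 * height + true_brand[brand, 0] + true_brand[brand, 1] * height
y = (rng.random(n) < 1.0 / (1.0 + np.exp(-eta))).astype(float)

# Design: [1, height, then (indicator, indicator * height) for each brand].
onehot = np.eye(n_brands)[brand]
brand_cols = np.empty((n, 2 * n_brands))
brand_cols[:, 0::2] = onehot
brand_cols[:, 1::2] = onehot * height[:, None]
X = np.hstack([np.ones((n, 1)), height[:, None], brand_cols])

# Penalty: zero for the two fixed effects, inverse covariance per brand block.
penalty = np.zeros((X.shape[1], X.shape[1]))
penalty[2:, 2:] = np.kron(np.eye(n_brands), np.linalg.inv(cov_brand))

beta = fit_penalized_logistic(X, y, penalty)
print(beta[:2])  # estimated fixed effects (intercept, height slope)
```

Newton steps typically converge in a handful of iterations on a problem like this; a robust solver would add a line search and use sparse matrices instead of the dense design built above.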
Types of Problems Solved by Diamond
As noted above, Diamond is useful when you want to find the coefficients of a GLMM and know the covariance structure (or can make a reasonable guess at it) a priori. We have also extended Diamond to solve cumulative logistic regression problems with a known covariance structure. Cumulative logistic regression is designed for ordinal responses \(Y \in \{1, 2, \dots, J\}\), and the model we use is
\[\operatorname{logit}\big(\Pr(Y_i \leq j)\big) = \alpha_j + \beta^T x_i, \qquad j = 1, \dots, J - 1,\]
where the cutpoints satisfy \(\alpha_1 < \alpha_2 < \dots < \alpha_{J-1}\).
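Given cutpoints and coefficients, the category probabilities follow from differences of the cumulative probabilities, \(\Pr(Y_i = j) = \Pr(Y_i \le j) - \Pr(Y_i \le j - 1)\). The sketch below computes them for one observation; the function name and values are hypothetical. Note the sign convention in the model above: a larger \(\beta^T x_i\) raises \(\Pr(Y_i \le j)\) and so shifts probability toward lower categories (some texts write \(\alpha_j - \beta^T x_i\) instead).

```python
import math


def category_probs(x, beta, alphas):
    """P(Y = j) for j = 1..J under logit P(Y <= j) = alpha_j + beta^T x.

    alphas: increasing cutpoints alpha_1 < ... < alpha_{J-1}.
    """
    eta = sum(b * xi for b, xi in zip(beta, x))
    # Cumulative probabilities P(Y <= j); P(Y <= J) = 1 by definition.
    cum = [1.0 / (1.0 + math.exp(-(a + eta))) for a in alphas] + [1.0]
    # Category probabilities are successive differences of the cumulative ones.
    return [cum[0]] + [cum[j] - cum[j - 1] for j in range(1, len(cum))]


probs = category_probs(x=[1.0], beta=[0.5], alphas=[-1.0, 0.0, 1.0])
print(probs, sum(probs))  # four probabilities summing to 1
```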
Determining Covariance Structure
The covariance matrices for the random effects can be determined in a number of ways:
- Full Bayesian methods, which learn the covariance by MCMC sampling. Here’s an example using PyMC3.
- A brute-force approach, such as a grid search over the variance hyperparameters combined with cross-validation. This may work for simple covariance matrices, but it becomes impractical for models with many random slopes, forcing you to give up the richer covariance structure those models allow.
- Empirical Bayes methods, i.e., estimating the variances from the data, as mbest and LME4 do. There is also some exciting research in this area from the Stanford statistics department.
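As a toy illustration of the empirical Bayes idea, the classical one-way ANOVA estimator matches observed mean squares to their expectations in a balanced Gaussian random-intercept model. This is a textbook method-of-moments estimator, not the algorithm mbest or LME4 implements, and it does not handle logistic outcomes, unequal group sizes, or random slopes; the function name is hypothetical.

```python
from statistics import mean


def anova_between_variance(groups):
    """Method-of-moments (one-way ANOVA) estimates for a balanced model
    y_ij = mu + b_i + e_ij,  b_i ~ N(0, tau^2),  e_ij ~ N(0, sigma^2).

    groups: list of equal-length lists of observations, one list per group.
    Returns (tau2_hat, sigma2_hat).
    """
    k, n = len(groups), len(groups[0])
    if any(len(g) != n for g in groups):
        raise ValueError("this simple estimator assumes equal group sizes")
    group_means = [mean(g) for g in groups]
    grand_mean = mean(group_means)
    # Mean square between groups and mean square within groups.
    msb = n * sum((m - grand_mean) ** 2 for m in group_means) / (k - 1)
    msw = sum((y - m) ** 2 for g, m in zip(groups, group_means) for y in g) / (k * (n - 1))
    # E[MSB] = sigma^2 + n * tau^2 and E[MSW] = sigma^2; solve, truncating at 0.
    tau2_hat = max(0.0, (msb - msw) / n)
    return tau2_hat, msw


print(anova_between_variance([[1, 2, 3], [4, 5, 6], [7, 8, 9]]))  # ≈ (8.667, 1.0)
```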
In practice, we tend to use mbest to find the covariance matrix once (perhaps on a subset of the data) and then use that estimate of the covariance to determine the model coefficients daily using Diamond. We then update the covariance matrix on a slower cadence, such as weekly or monthly, depending on the stationarity of the problem.
Relative Speed of Diamond vs LME4 and mbest
We simulated data for a slightly simplified version of the model above (without the client random effect), with over 1 million random effect levels (brands) and 10 million rows of training data, using the formula:
loved_fit ~ 1 + height + (1 + height | brand)
We were able to fit the model in Diamond in ~20 minutes using a known covariance matrix. This model could not be fit in LME4 or mbest in a reasonable amount of time (we tried very hard). Diamond is so fast because it uses the Hessian of the likelihood function, with some special tweaks. See the sister post for the full details.
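The sister post explains Diamond’s specific tweaks; one generic property that helps any Hessian-based method on this problem is sparsity. Each row of the data belongs to exactly one brand, so in the random-effects block of the logistic Hessian \(X^T W X\), every entry linking two different brands is zero. The small NumPy sketch below (hypothetical sizes, constant weights) checks this; it illustrates the structure only and is not a description of how Diamond stores or solves its Hessian.

```python
import numpy as np

rng = np.random.default_rng(1)
n, n_brands = 2000, 50
height = rng.standard_normal(n)
brand = rng.integers(n_brands, size=n)

# Random-effect columns for (1 + height | brand): for brand b, column 2b is
# the brand indicator and column 2b + 1 is indicator * height.
onehot = np.eye(n_brands)[brand]
Z = np.empty((n, 2 * n_brands))
Z[:, 0::2] = onehot
Z[:, 1::2] = onehot * height[:, None]

# Logistic weights p(1 - p) at a current coefficient guess (here p = 0.5).
w = np.full(n, 0.25)
H_re = Z.T @ (Z * w[:, None])  # random-effects block of X^T W X

# Entries linking two different brands are exactly zero, so the block is
# block-diagonal with one 2x2 block per brand.
mask = np.kron(np.eye(n_brands), np.ones((2, 2))).astype(bool)
print(np.count_nonzero(H_re[~mask]))  # 0
```

With 1 million brands, that block has 2 million rows and roughly \(4 \times 10^{12}\) entries, but at most 4 million of them can be nonzero. The few fixed-effect columns add dense rows and columns, and a crossed client term would add client-brand entries, but the matrix remains very sparse.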
Diamond is hosted on GitHub and can be forked or cloned and installed locally. It uses Cython for a few of the linear algebra steps. There is an examples folder with a few Jupyter notebooks demonstrating basic use cases. We also have a Docker image where you can try it out and run the example notebooks. See the README in the GitHub repo for more information. Happy model fitting!