Path: blob/master/recsys/factorization_machine/factorization_machine.ipynb
Factorization Machine (FM)
Factorization Machine type algorithms are a combination of linear regression and matrix factorization. The cool idea behind this type of algorithm is that it aims to model interactions between features (a.k.a. attributes, explanatory variables) using factorized parameters. By doing so, it has the ability to estimate all interactions between features even with extremely sparse data.
Introduction
Normally, when we think of linear regression, we would think of the following formula:

$$\hat{y}(x) = w_0 + \sum_{i=1}^{n} w_i x_i$$

Where:

- $w_0$ is the bias term, a.k.a. intercept.
- $w_i$ are weights corresponding to each feature $x_i$, here we assume we have $n$ total features.
This formula's advantage is that it can be computed in linear time, $O(n)$. The drawback, however, is that it does not handle feature interactions. To capture interactions, we could introduce a weight $w_{ij}$ for each feature combination. This is sometimes referred to as a second order polynomial. The resulting model is shown below:

$$\hat{y}(x) = w_0 + \sum_{i=1}^{n} w_i x_i + \sum_{i=1}^{n} \sum_{j=i+1}^{n} w_{ij} x_i x_j$$
Compared to our previous model, this formulation has the advantage that it can capture feature interactions, at least for two features at a time. But we have now ended up with an $O(n^2)$ complexity, which means that to train the model, we now require a lot more time and memory. Another issue is that when we have categorical variables with high cardinality, after one-hot encoding them, we would end up with a lot of columns that are sparse, making it harder to actually capture their interactions (not enough data).
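To make that sparsity point concrete, here is a tiny, made-up example (the column names and values are purely illustrative) of how one-hot encoding two categorical columns quickly produces many mostly-zero columns, so most feature pairs rarely co-occur:

```python
import pandas as pd

# toy data: two categorical columns (illustrative names/values)
df = pd.DataFrame({'user_id': ['u1', 'u2', 'u3', 'u4'],
                   'item_id': ['i10', 'i11', 'i10', 'i12']})

# one-hot encoding explodes these into many sparse 0/1 columns,
# so a dedicated weight w_ij per column pair sees very little data
encoded = pd.get_dummies(df, columns=['user_id', 'item_id'])
print(encoded)
```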
To solve this complexity issue, Factorization Machines take inspiration from matrix factorization and model the feature interactions using latent factors. Every feature $x_i$ has a corresponding latent factor $v_i$, and two features' interactions are modelled as $\langle v_i, v_j \rangle$, where $\langle \cdot, \cdot \rangle$ refers to the dot product of the two features' latent vectors. If we assume each latent factor is of size $k$ (this is a hyperparameter that we can tune), then:

$$\langle v_i, v_j \rangle = \sum_{f=1}^{k} v_{i,f} \, v_{j,f}$$
This leads to our new equation:

$$\hat{y}(x) = w_0 + \sum_{i=1}^{n} w_i x_i + \sum_{i=1}^{n} \sum_{j=i+1}^{n} \langle v_i, v_j \rangle x_i x_j$$
This is an improvement over our previous model (where we modeled each pair of interaction terms with a weight $w_{ij}$) as the number of parameters is reduced from $n^2$ to $n \times k$, since $k \ll n$, which also helps mitigate overfitting issues. The naive way of formulating factorization machine results in a complexity of $O(k n^2)$, because all pairwise interactions have to be computed, but we can reformulate it to make it run in $O(k n)$:
$$
\begin{aligned}
\sum_{i=1}^{n} \sum_{j=i+1}^{n} \langle v_i, v_j \rangle x_i x_j
&= \frac{1}{2} \sum_{i=1}^{n} \sum_{j=1}^{n} \langle v_i, v_j \rangle x_i x_j - \frac{1}{2} \sum_{i=1}^{n} \langle v_i, v_i \rangle x_i x_i \\
&= \frac{1}{2} \left( \sum_{i=1}^{n} \sum_{j=1}^{n} \sum_{f=1}^{k} v_{i,f} \, v_{j,f} \, x_i x_j - \sum_{i=1}^{n} \sum_{f=1}^{k} v_{i,f}^2 x_i^2 \right) \\
&= \frac{1}{2} \sum_{f=1}^{k} \left( \left( \sum_{i=1}^{n} v_{i,f} \, x_i \right)^2 - \sum_{i=1}^{n} v_{i,f}^2 x_i^2 \right)
\end{aligned}
$$

Note, summing over the distinct pairs is the same as summing over all pairs minus the self-interactions, divided by two. This is the reason why the value 1/2 is introduced from the beginning of the derivation.
This reformulated equation has a linear complexity in both $k$ and $n$, i.e. its computation is in $O(k n)$. Substituting this new equation into the existing factorization machine formula, we end up with:

$$\hat{y}(x) = w_0 + \sum_{i=1}^{n} w_i x_i + \frac{1}{2} \sum_{f=1}^{k} \left( \left( \sum_{i=1}^{n} v_{i,f} \, x_i \right)^2 - \sum_{i=1}^{n} v_{i,f}^2 x_i^2 \right)$$
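As a quick sanity check on the reformulation, the snippet below (a minimal sketch with toy sizes, not part of the original notebook) compares the naive pairwise sum against the linear-time expression; the two values should agree up to floating point error:

```python
import numpy as np

rng = np.random.default_rng(1)
n, k = 6, 3                   # toy number of features and latent factor size
x = rng.normal(size=n)        # a single dense feature vector
V = rng.normal(size=(n, k))   # latent factor matrix, one row v_i per feature

# naive O(k n^2): sum over all pairs i < j of <v_i, v_j> * x_i * x_j
naive = 0.0
for i in range(n):
    for j in range(i + 1, n):
        naive += (V[i] @ V[j]) * x[i] * x[j]

# reformulated O(k n): 1/2 * sum_f [ (sum_i v_if x_i)^2 - sum_i v_if^2 x_i^2 ]
linear = 0.5 * np.sum((V.T @ x) ** 2 - (V.T ** 2) @ (x ** 2))

print(naive, linear)
```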
In a machine learning setting, factorization machine can be applied to different supervised prediction tasks:
- Regression: in this case $\hat{y}(x)$ can be used directly as the prediction by minimizing the mean squared error between the model prediction and target value, e.g. $\frac{1}{N} \sum_{i=1}^{N} \big( y_i - \hat{y}(x_i) \big)^2$.
- Classification: if we were to use it in a binary classification setting, we could then minimize the log loss, $-\ln \sigma\big( \hat{y}(x) \cdot y \big)$, where $\sigma$ is the sigmoid/logistic function and $y \in \{-1, 1\}$.
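A small illustration of that classification loss (the raw scores and labels below are made up for this example) might look like:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

# hypothetical raw model outputs y_hat(x) and labels encoded as {-1, 1}
scores = np.array([2.1, -0.3, 0.8])
labels = np.array([1, -1, -1])

# log loss for {-1, 1} labels: -ln(sigmoid(score * label))
log_loss = -np.log(sigmoid(scores * labels))
print(log_loss.mean())
```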
To train factorization machine, we can use gradient descent based optimization techniques. The parameters to be learned are $w_0$, $\mathbf{w}$, and $V$. The gradient of the model's prediction with respect to each of them is:

$$
\frac{\partial \hat{y}(x)}{\partial \theta} =
\begin{cases}
1, & \text{if } \theta \text{ is } w_0 \\
x_i, & \text{if } \theta \text{ is } w_i \\
x_i \sum_{j=1}^{n} v_{j,f} \, x_j - v_{i,f} \, x_i^2, & \text{if } \theta \text{ is } v_{i,f}
\end{cases}
$$
Notice that $\sum_{j=1}^{n} v_{j,f} \, x_j$ does not depend on $i$, thus it can be computed independently.

The last formula above can also be written as $x_i \big( \sum_{j=1}^{n} v_{j,f} \, x_j - v_{i,f} \, x_i \big)$.
In practice, we would throw in some L2 regularization to prevent overfitting.
As the next section contains an implementation of the algorithm from scratch, the gradient of the log loss is also provided here for completeness. The predicted value $\hat{y}(x)$ is replaced with $x$ to make the notation cleaner.

$$
\frac{\partial}{\partial x} \big( -\ln \sigma(x y) \big)
= -\frac{1}{\sigma(x y)} \, \sigma(x y) \big( 1 - \sigma(x y) \big) \, y
= \big( \sigma(x y) - 1 \big) \, y
$$
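Putting these pieces together, below is a minimal, illustrative single-sample gradient descent step for the binary classification setting, using the prediction formula and gradients above (the function names, learning rate, and the omission of L2 regularization are choices made for this sketch, not the notebook's actual implementation):

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def fm_predict(x, w0, w, V):
    # linear-time FM prediction for a single dense sample x
    linear = w0 + w @ x
    interaction = 0.5 * np.sum((V.T @ x) ** 2 - (V.T ** 2) @ (x ** 2))
    return linear + interaction

def fm_sgd_step(x, y, w0, w, V, learning_rate=0.01):
    # one SGD update for a single sample (x, y) with y in {-1, 1}
    score = fm_predict(x, w0, w, V)
    # gradient of the log loss w.r.t. the raw score: (sigmoid(score * y) - 1) * y
    dloss = (sigmoid(score * y) - 1.0) * y

    w0 -= learning_rate * dloss                  # d score / d w0 = 1
    w -= learning_rate * dloss * x               # d score / d w_i = x_i
    # d score / d v_if = x_i * (sum_j v_jf x_j) - v_if * x_i^2,
    # where sum_j v_jf x_j does not depend on i and is computed once
    summed = V.T @ x                             # shape (k,)
    grad_V = np.outer(x, summed) - V * (x ** 2)[:, None]
    V -= learning_rate * dloss * grad_V
    return w0, w, V
```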
Advantages: We'll now wrap up the theoretical section of factorization machine with some of its advantages:
We can observe from the model equation that it can be computed in linear time.
By leveraging ideas from matrix factorization, we can estimate higher order interaction effects even under very sparse data.
Compared to traditional matrix factorization methods, which are restricted to modeling a user-item matrix, we can leverage other user- or item-specific features, making factorization machine more flexible.
Implementation
For the implementation of factorization machine, we'll use for loop based code, as I personally find it easier to comprehend for the gradient update section. There are different ways to speed up for loop based code in Python, such as Cython or Numba; here, we'll be using Numba.
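To give a flavor of that approach, here is an illustrative Numba-jitted, loop-based version of the linear-time prediction formula for a single dense sample (the function and argument names are mine, and this is only a sketch of the idea rather than the notebook's actual implementation):

```python
import numpy as np
from numba import njit

@njit
def fm_predict_loop(x, w0, w, V):
    # plain for-loop version of: w0 + sum_i w_i x_i
    #   + 1/2 sum_f [ (sum_i v_if x_i)^2 - sum_i v_if^2 x_i^2 ]
    n, k = V.shape
    score = w0
    for i in range(n):
        score += w[i] * x[i]
    for f in range(k):
        summed = 0.0
        summed_squared = 0.0
        for i in range(n):
            term = V[i, f] * x[i]
            summed += term
            summed_squared += term * term
        score += 0.5 * (summed * summed - summed_squared)
    return score
```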
There are various open-sourced implementations floating around the web; here are the links to some of them:

I personally haven't tested which one is more efficient; feel free to grab one of them and see if it helps solve your problem.