# Shapley.jl
This is a package for computing Shapley values of machine learning predictions.
## Installation

The package can be added with

```julia
Pkg.add("Shapley")
```

or via `] add Shapley` in the REPL.
## Introduction
Shapley values can be used with machine learning methods to estimate the contribution of features to a prediction for a particular model, sometimes called "feature importance". The basic idea is to compute the contribution of a feature to a particular data point averaged over all possible combinations of features. For feature $i$, the Shapley value can be computed as
\[\phi_{i} = \sum_{S\subseteq N\backslash \{i\}} \frac{|S|!(|N| - |S| - 1)!}{|N|!}\left(v_{x}(S\cup\{i\}) - v_{x}(S)\right)\]
where $N$ is the set of all features and $v_{x}$ is a "value function" for data point $x$. The coefficients in the sum can be thought of as normalization. For machine learning, we define the following value function
\[v_{x}(S) = \int d\xi\, \left( p(\xi_{i\notin S}) \hat{f}(\xi) - p(\xi)\hat{f}(\xi) \right)\]
where $\hat{f}$ is the prediction function associated with our machine learning model and the components of $\xi$ not integrated over are fixed to those of the data point $x$. The integration in the first term is over the probability distributions of all features not in $S$, while the second term is the expectation value of $\hat{f}$ over the entire input distribution. The value function can therefore be thought of as the difference between the expectation value of $\hat{f}$ with only the chosen features fixed and the overall expectation value of $\hat{f}$.
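For example, with just two features $N = \{1, 2\}$, the sum defining the Shapley value of feature $1$ expands to

\[\phi_{1} = \frac{1}{2}\left(v_{x}(\{1\}) - v_{x}(\emptyset)\right) + \frac{1}{2}\left(v_{x}(\{1,2\}) - v_{x}(\{2\})\right)\]

that is, the change in the value function from adding feature $1$, averaged over the ways the remaining features can already be present.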
The above definitions give the Shapley values a fairly straightforward intuitive meaning (this is perhaps easier to see if one ignores the combinatoric normalization coefficients, though these can still be important to consider in how they weight the various terms).
The appeal of the Shapley values is that they have several properties which other methods for describing "feature importance" lack, in particular
- Efficiency: The Shapley values satisfy the sum rule $\sum_{i}\phi_{i} = v_{x}(N)$ for each data point $x$ (verified for the two-feature example after this list). This ensures that, for example, multiple features cannot simultaneously have large positive Shapley values for the same data point.
- Symmetry: If $v_{x}(S\cup\{i\}) = v_{x}(S\cup\{j\})$ for all $S$ such that $i, j \notin S$ then $\phi_{i} = \phi_{j}$. This ensures that features with equivalent "predictive power" have identical Shapley values.
- Null Contribution: If a feature makes no contribution to the prediction for $x$, that is $v_{x}(S\cup\{i\}) = v_{x}(S)$ for all $S$ such that $i\notin S$, then $\phi_{i}=0$. This ensures that features with "small" Shapley values are irrelevant.
- Additivity: For a model which gives predictions that can be expressed as the sum of two models, i.e. $\hat{f}(X)=\hat{f}_{1}(X) + \hat{f}_{2}(X)$, the Shapley value for $\hat{f}$ is the sum of the Shapley values for $\hat{f}_{1}$ and $\hat{f}_{2}$. This means that, for example, the Shapley value of a random forest is the mean of Shapley values of the constituent trees.
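For the two-feature example above, efficiency is easy to check explicitly: adding the expansions of $\phi_{1}$ and $\phi_{2}$, the mixed terms cancel, leaving

\[\phi_{1} + \phi_{2} = v_{x}(\{1,2\}) - v_{x}(\emptyset) = v_{x}(N)\]

since $v_{x}(\emptyset) = 0$ by the definition of the value function.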
Since Shapley values must be computed for individual data points, they do not directly provide a global notion of "feature importance". This per-point resolution is, however, a significant advantage of the approach: if a model has very different Shapley values for different data points, that is telling us something about the underlying model, in accordance with the properties listed above. It is common practice to compute statistics of the Shapley values, such as the mean absolute value, root mean square, or standard deviation, to describe the overall feature importance.
From the above discussion, it should be apparent that it is usually impossible to compute the Shapley values exactly in practice: for one, doing so would require knowing the probability distribution of $X$, and for large numbers of features the sum contains a very large number of terms, each requiring evaluations of $\hat{f}$.
## Basic Usage
This package exports the `shapley` function, which can be used to compute the Shapley values of all points in a data set. For example, suppose we have trained some model using MLJ:
```julia
using Shapley, MLJ, DataFrames
import RDatasets
using MLJDecisionTreeInterface: RandomForestRegressor

boston = RDatasets.dataset("MASS", "Boston")
y, X = unpack(boston, ==(:MedV), col -> true)

m = machine(RandomForestRegressor(), X, y)
fit!(m)
```
Here `X` is a table in which each row is a single data point. We can compute a `Vector` of Shapley values for the dataset with

```julia
ϕ = shapley(x -> predict(m, x), Shapley.MonteCarlo(512), X, :LStat)
```
The argument `Shapley.MonteCarlo(512)` specifies that we wish to compute the Shapley value using the Monte Carlo method with `512` iterations. In this case, `X` is used both as the dataset for which to compute the Shapley values and as an empirical probability distribution for $X$, which the Monte Carlo algorithm needs to make the estimate.
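To build intuition for what the Monte Carlo estimator does, the following is a minimal, illustrative sketch of a common sampling scheme for a single data point and feature. This is not the package's actual implementation, and the function name `montecarlo_shapley` is purely hypothetical:

```julia
using Random: randperm

# Illustrative sketch only: estimate the Shapley value of feature `j` for the
# point `x`, where `f` maps a matrix whose rows are data points to a vector of
# predictions, and the rows of the matrix `X` serve as an empirical distribution.
function montecarlo_shapley(f, X, x, j; iterations=512)
    nfeatures = size(X, 2)
    ϕ = 0.0
    for _ in 1:iterations
        z = X[rand(1:size(X, 1)), :]   # draw a random background point
        perm = randperm(nfeatures)     # draw a random feature ordering
        pos = findfirst(==(j), perm)
        # take features ordered before `j` from `x` and the rest from `z`;
        # the two samples differ only in whether feature `j` comes from `x`
        with_j, without_j = copy(z), copy(z)
        for k in perm[1:pos]
            with_j[k] = x[k]
        end
        for k in perm[1:pos-1]
            without_j[k] = x[k]
        end
        ϕ += f(reshape(with_j, 1, :))[1] - f(reshape(without_j, 1, :))[1]
    end
    ϕ / iterations
end
```

Averaging the difference between predictions that do and do not include feature $j$, over random orderings and background points, approximates the combinatorially weighted sum in the definition above.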
Because there are many machine learning algorithms which can only efficiently compute predictions in large batches (i.e. there is a high overhead for each evaluation), Shapley.jl will call the prediction function on the largest datasets possible, rather than making separate calls for each data point.
In the above example, `ϕ` is a `Vector{Float64}` of Shapley values of the feature `:LStat` for the rows of `X` (the feature can be specified by either a `Symbol` name or an `Integer` feature index of `X`). One can then compute some helpful statistics of the Shapley values, for example
```julia
using Statistics

mean(abs.(ϕ)), std(ϕ)
```
To obtain a table of statistics describing all Shapley values, one can do

```julia
df = DataFrame(Shapley.summary(x -> predict(m, x), Shapley.MonteCarlo(512), X))
```

The function `Shapley.summary` returns a table in the form of a `Vector` of `NamedTuple`s, which can be converted to a `DataFrame` for convenience.
See the section on Algorithms and Parallelism for a description of available algorithms and how they should be specified.
## Compatibility
For broad compatibility across the Julia package ecosystem, all Shapley.jl functions which use prediction functions require those functions to take arguments in one of the following forms:

- An `AbstractMatrix`, the rows of which are individual data points.
- A Tables compatible table object, e.g. a `DataFrame` or a `NamedTuple` of `AbstractVector`s of common length.

and to produce outputs in one of the following forms:

- An `AbstractVector` of `Number`s.
- An `AbstractVector` of `Distributions.Sampleable` objects.
Thanks to the flexibility of the Tables interface, this means that the vast majority of machine learning methods from existing packages will be compatible with Shapley.jl "out of the box", as in the above example.
If your prediction function does not satisfy the above requirements, we recommend wrapping the model in MLJ using the MLJ model interface. Alternatively, you can put any compatibility operations directly into the prediction function that you provide as the first argument to `shapley`.
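As a hypothetical illustration of the second option, suppose your model's prediction function, here called `mymodel_predict` (a stand-in name), only accepts a plain matrix. The conversion can be done inside the prediction function itself:

```julia
using Tables

# convert whatever table Shapley.jl passes in to the matrix the model expects
pred(X) = mymodel_predict(Tables.matrix(X))

ϕ = shapley(pred, Shapley.MonteCarlo(512), X, 1)
```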
## Probabilistic Models and Classifiers
Shapley values are most straightforwardly computed for regressions which return real numbers as predictions; however, they can also be computed for classifications. In these cases, it is customary to compute a separate Shapley value for each class, so the total number of Shapley values for an individual data point is the number of features multiplied by the number of classes.
Shapley.jl expects models to return classification results as discrete probability distributions using the Distributions `Sampleable` interface. For classifiers, rather than a `Vector{Float64}` of Shapley values, a table will be returned, the columns of which are the elements of the support of the distribution. The values returned by the classification prediction must be `AbstractVector`s of `Distributions.Sampleable` objects, with all objects having identical support.
This works "out of the box" (i.e. without need for modification) with MLJ. See this example of computing Shapley values for a classifier with MLJ.
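For instance, a minimal sketch along those lines (the iris dataset and `DecisionTreeClassifier` here are illustrative choices, assuming MLJDecisionTreeInterface is installed, and are not taken from the example referenced above):

```julia
using Shapley, MLJ, DataFrames
import RDatasets
using MLJDecisionTreeInterface: DecisionTreeClassifier

iris = RDatasets.dataset("datasets", "iris")
y, X = unpack(iris, ==(:Species), col -> true)

m = machine(DecisionTreeClassifier(), X, y)
fit!(m)

# MLJ's `predict` returns a vector of discrete distributions satisfying the
# `Sampleable` interface, so `shapley` returns a table with one column per class
ϕ = shapley(x -> predict(m, x), Shapley.MonteCarlo(512), X, :PetalLength)
```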
Not all algorithms currently support classification outputs. For these algorithms `supports_classification(algo)` will return `false`. To get classification output from these algorithms, one can provide a `predict` function which returns a vector giving the probability of a particular class for each data point.
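Continuing the classifier sketch above, one hypothetical way to do this with an MLJ model is to extract a single class probability from the distribution-valued predictions:

```julia
# reduce the distribution-valued predictions to plain numbers by taking the
# probability of one class, here "virginica" from the iris example above
pred(x) = pdf.(predict(m, x), "virginica")

ϕ = shapley(pred, Shapley.MonteCarlo(512), X, :PetalLength)
```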
`Shapley.summary` does not currently support classification. In these cases it's best to start from `shapley` instead.