This page was generated from doc/source/methods/mmddrift.ipynb.

Maximum Mean Discrepancy

Overview

The Maximum Mean Discrepancy (MMD) detector is a kernel-based method for multivariate two-sample testing. The MMD is a distance-based measure between two distributions \(p\) and \(q\), based on their mean embeddings \(\mu_{p}\) and \(\mu_{q}\) in a reproducing kernel Hilbert space \(F\):

\begin{align} MMD^2(F, p, q) & = \| \mu_{p} - \mu_{q} \|^2_{F} \end{align}
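Writing each mean embedding as \(\mu_{p} = \mathbb{E}_{x \sim p}[k(x, \cdot)]\) for a kernel \(k\) and expanding the squared norm expresses it through kernel evaluations only, with \(x, x' \sim p\) and \(y, y' \sim q\) drawn independently. Given samples \(\{x_i\}_{i=1}^{m}\) from \(p\) and \(\{y_j\}_{j=1}^{n}\) from \(q\), replacing each expectation by an average over distinct pairs gives the standard unbiased estimator \(\widehat{MMD}^2_{u}\):

\begin{align} \| \mu_{p} - \mu_{q} \|^2_{F} & = \mathbb{E}_{x, x'}[k(x, x')] - 2\, \mathbb{E}_{x, y}[k(x, y)] + \mathbb{E}_{y, y'}[k(y, y')] \\ \widehat{MMD}^2_{u} & = \frac{1}{m(m-1)} \sum_{i \neq i'} k(x_i, x_{i'}) - \frac{2}{mn} \sum_{i=1}^{m} \sum_{j=1}^{n} k(x_i, y_j) + \frac{1}{n(n-1)} \sum_{j \neq j'} k(y_j, y_{j'}) \end{align}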

We can compute an unbiased estimate of \(MMD^2\) from samples of the two distributions by applying the kernel trick. By default we use a radial basis function (RBF) kernel, but users are free to pass their own kernel of choice to the detector. We obtain a \(p\)-value via a permutation test on the values of \(MMD^2\).
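A minimal NumPy sketch of these two steps is shown below, assuming a single fixed RBF bandwidth `sigma` and vector-valued inputs. The function names (`rbf_kernel`, `mmd2_unbiased`, `permutation_test`) are hypothetical and this is not the detector's internal implementation. It computes the unbiased \(MMD^2\) estimate and a permutation-test \(p\)-value by repeatedly reshuffling the pooled samples and recomputing the statistic.

```python
import numpy as np


def rbf_kernel(x: np.ndarray, y: np.ndarray, sigma: float = 1.0) -> np.ndarray:
    """Kernel matrix with entries exp(-||x_i - y_j||^2 / (2 * sigma^2))."""
    # squared Euclidean distances via broadcasting, shape (len(x), len(y))
    sq_dist = np.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)
    return np.exp(-sq_dist / (2.0 * sigma ** 2))


def mmd2_unbiased(x: np.ndarray, y: np.ndarray, sigma: float = 1.0) -> float:
    """Unbiased estimate of MMD^2 between samples x of shape (m, d) and y of shape (n, d)."""
    m, n = len(x), len(y)
    k_xx = rbf_kernel(x, x, sigma)
    k_yy = rbf_kernel(y, y, sigma)
    k_xy = rbf_kernel(x, y, sigma)
    # exclude the diagonal terms k(x_i, x_i) and k(y_j, y_j) to keep the estimate unbiased
    term_xx = (k_xx.sum() - np.trace(k_xx)) / (m * (m - 1))
    term_yy = (k_yy.sum() - np.trace(k_yy)) / (n * (n - 1))
    term_xy = 2.0 * k_xy.sum() / (m * n)
    return float(term_xx + term_yy - term_xy)


def permutation_test(x: np.ndarray, y: np.ndarray, sigma: float = 1.0,
                     n_permutations: int = 100, seed: int = 0):
    """Return the observed MMD^2 and its permutation-test p-value."""
    rng = np.random.default_rng(seed)
    observed = mmd2_unbiased(x, y, sigma)
    pooled = np.concatenate([x, y], axis=0)
    m = len(x)
    n_extreme = 0
    for _ in range(n_permutations):
        # reshuffle the pooled data and split it back into groups of the original sizes
        idx = rng.permutation(len(pooled))
        if mmd2_unbiased(pooled[idx[:m]], pooled[idx[m:]], sigma) >= observed:
            n_extreme += 1
    return observed, n_extreme / n_permutations


rng = np.random.default_rng(42)
x_ref = rng.normal(0.0, 1.0, size=(100, 2))
x_same = rng.normal(0.0, 1.0, size=(100, 2))   # same distribution as x_ref
x_shift = rng.normal(1.0, 1.0, size=(100, 2))  # mean shifted by 1 in each dimension
mmd2_same, p_same = permutation_test(x_ref, x_same)
mmd2_shift, p_shift = permutation_test(x_ref, x_shift)
```

The shifted batch should yield a much larger \(MMD^2\) and a small \(p\)-value, while the \(p\)-value for the undrifted batch is roughly uniform between 0 and 1.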

For high-dimensional data, we typically want to reduce the dimensionality before computing the permutation test. Following suggestions in Failing Loudly: An Empirical Study of Methods for Detecting Dataset Shift, we incorporate Untrained AutoEncoders (UAE), black-box shift detection using the classifier's softmax outputs (BBSDs) and PCA as out-of-the-box preprocessing methods. Preprocessing methods which do not rely on the classifier will usually pick up drift in the input data, while BBSDs focuses on label shift.
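To illustrate the dimensionality reduction step, here is a minimal sketch of a PCA projection written directly with NumPy's SVD, assuming the projection is fitted on the reference data only and then applied to both the reference data and the test batch. `fit_pca` and `pca_transform` are hypothetical helpers, not the library's built-in PCA preprocessing.

```python
import numpy as np


def fit_pca(x_ref: np.ndarray, n_components: int):
    """Fit PCA on the reference data only; returns the mean and a (d, n_components) projection."""
    mean = x_ref.mean(axis=0)
    # rows of vt are the principal directions of the centred reference data
    _, _, vt = np.linalg.svd(x_ref - mean, full_matrices=False)
    return mean, vt[:n_components].T


def pca_transform(x: np.ndarray, mean: np.ndarray, components: np.ndarray) -> np.ndarray:
    """Project data onto the fitted principal components."""
    return (x - mean) @ components


rng = np.random.default_rng(0)
x_ref = rng.normal(size=(200, 50))            # 50-dimensional reference data
x_test = rng.normal(loc=0.5, size=(100, 50))  # test batch with a mean shift
mean, components = fit_pca(x_ref, n_components=2)
z_ref = pca_transform(x_ref, mean, components)    # shape (200, 2)
z_test = pca_transform(x_test, mean, components)  # shape (100, 2)
```

Fitting the projection on the reference data only keeps it fixed across test batches, so a difference between the projected samples reflects a change in the data rather than a change in the projection. The MMD test is then run on the low-dimensional `z_ref` and `z_test`.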

Detecting input data drift (covariate shift) \(\Delta p(x)\) for text data requires a custom preprocessing step. We can pick up changes in the semantics of the input by extracting (contextual) embeddings and detecting drift on those. Strictly speaking, we are then no longer detecting \(\Delta p(x)\), since the whole training procedure of the (pre)trained embeddings (objective function, training data, etc.) affects the embeddings we extract. The library contains functionality to leverage pre-trained embeddings from HuggingFace's transformers package, but also allows you to easily use your own embeddings of choice. Both options are illustrated with examples in the Text drift detection on IMDB movie reviews notebook.

Usage

Initialize

Parameters:

  • p_val: p-value threshold used to determine the significance of the permutation test.

  • X_ref: Data used as the reference distribution.

  • preprocess_X_ref: Whether to apply the (optional) preprocessing step to the reference data at initialization and store the preprocessed data. Depending on the preprocessing step, this can reduce the computation time of the predict step significantly, especially when the reference dataset is large. Defaults to True.

  • update_X_ref: Reference data can optionally be updated to the last N instances seen by the detector or via reservoir sampling with size N. For the former, the parameter equals {‘last’: N} while for reservoir sampling {‘reservoir_sampling’: N} is passed.

  • preprocess_fn: Function to preprocess the data before computing the data drift metrics. Typically a dimensionality reduction technique.

  • preprocess_kwargs: Keyword arguments for preprocess_fn. The built-in UAE, BBSDs or text-specific preprocessing steps are passed here as well. See the notebooks for image and text data for concrete, detailed examples, and below for a brief example.

  • kernel: Kernel function used when computing the MMD. Defaults to a Gaussian kernel.

  • kernel_kwargs: Keyword arguments for the kernel function. For the Gaussian kernel this is the kernel bandwidth sigma. We can also sum over a number of different kernel bandwidths, in which case sigma becomes an array of bandwidth values. If sigma is not specified, the detector infers it by computing the pairwise distances between the instances in the two samples and setting sigma to the median distance.

  • n_permutations: Number of permutations used in the permutation test.

  • chunk_size: Used to optionally compute the MMD between the two samples in chunks using dask to avoid potential out-of-memory errors. In terms of speed, the optimal chunk size is application and hardware dependent, so it is often worth testing a few different values, including None. None means that the computation is done in memory with NumPy.

  • data_type: Optionally specify the data type, which is added to the metadata, e.g. 'tabular' or 'image'.
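The two update_X_ref strategies can be sketched as follows, assuming the reference data is a NumPy array with instances along the first axis. `update_last` and `update_reservoir` are hypothetical helpers; the reservoir variant uses the classic Algorithm R, and the detector's own bookkeeping may differ.

```python
import numpy as np


def update_last(x_ref: np.ndarray, x_new: np.ndarray, n: int) -> np.ndarray:
    """{'last': n}: keep only the n most recent instances."""
    return np.concatenate([x_ref, x_new], axis=0)[-n:]


def update_reservoir(x_ref: np.ndarray, x_new: np.ndarray, n: int, n_seen: int,
                     rng: np.random.Generator):
    """{'reservoir_sampling': n}: Algorithm R with a reservoir of size n.

    Assumes x_ref holds at most n instances and n_seen counts every instance
    streamed through the reservoir so far. Returns the updated reservoir and count.
    """
    reservoir = list(x_ref)
    for x in x_new:
        n_seen += 1
        if len(reservoir) < n:
            reservoir.append(x)  # fill the reservoir first
        else:
            # keep the new instance with probability n / n_seen, evicting a random slot
            j = rng.integers(0, n_seen)
            if j < n:
                reservoir[j] = x
    return np.stack(reservoir), n_seen


rng = np.random.default_rng(0)
x_ref = np.zeros((3, 2))
x_new = np.ones((10, 2))
last = update_last(x_ref, x_new, n=5)  # five rows of ones
res, n_seen = update_reservoir(x_ref, x_new, n=5, n_seen=3, rng=rng)
```

The 'last' strategy forgets older data entirely, while reservoir sampling keeps a uniform random sample of everything seen so far, so the reference set drifts more slowly.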

Initialized drift detector example:

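import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Conv2D, Dense, Flatten, InputLayer

from alibi_detect.cd import MMDDrift
from alibi_detect.cd.preprocess import UAE  # Untrained AutoEncoder
from alibi_detect.utils.kernels import gaussian_kernel  # assumed module path of the built-in Gaussian kernel

# X_ref: reference data of shape (N, 32, 32, 3), assumed to be defined already

# encoder used by the UAE to map 32x32x3 images to a 32-dimensional embedding
encoder_net = tf.keras.Sequential(
    [
        InputLayer(input_shape=(32, 32, 3)),
        Conv2D(64, 4, strides=2, padding='same', activation=tf.nn.relu),
        Conv2D(128, 4, strides=2, padding='same', activation=tf.nn.relu),
        Conv2D(512, 4, strides=2, padding='same', activation=tf.nn.relu),
        Flatten(),
        Dense(32,)
    ]
)
uae = UAE(encoder_net=encoder_net)

cd = MMDDrift(
    p_val=.05,
    X_ref=X_ref,
    preprocess_X_ref=True,  # embed X_ref once at initialization
    preprocess_kwargs={'model': uae, 'batch_size': 128},
    kernel=gaussian_kernel,
    kernel_kwargs={'sigma': np.array([.5, 1., 5.])},  # sum over three bandwidths
    chunk_size=1000,  # compute the MMD in chunks with dask
    n_permutations=1000
)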
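The kernel_kwargs above pass three bandwidths at once. The sketch below shows what summing Gaussian kernels over an array of bandwidths computes, together with a median-distance fallback for sigma. The helper names, the \(\exp(-\|x - y\|^2 / (2\sigma^2))\) parameterisation and the use of pooled pairwise distances for the median are assumptions for illustration, not the library's exact implementation.

```python
import numpy as np


def pairwise_sq_dists(x: np.ndarray, y: np.ndarray) -> np.ndarray:
    """Squared Euclidean distances, shape (len(x), len(y))."""
    return np.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)


def median_sigma(x: np.ndarray, y: np.ndarray) -> float:
    """Median heuristic: median pairwise distance over the pooled samples."""
    z = np.concatenate([x, y], axis=0)
    dist = np.sqrt(pairwise_sq_dists(z, z))
    off_diag = ~np.eye(len(z), dtype=bool)  # ignore the zero self-distances
    return float(np.median(dist[off_diag]))


def multi_bandwidth_gaussian(x: np.ndarray, y: np.ndarray, sigma) -> np.ndarray:
    """Sum of Gaussian kernels exp(-||x - y||^2 / (2 * s^2)) over all bandwidths s in sigma."""
    sigma = np.atleast_1d(np.asarray(sigma, dtype=float))
    sq = pairwise_sq_dists(x, y)[..., None]  # add a bandwidth axis: (len(x), len(y), 1)
    return np.exp(-sq / (2.0 * sigma ** 2)).sum(axis=-1)


rng = np.random.default_rng(1)
x = rng.normal(size=(4, 3))
y = rng.normal(size=(6, 3))
k_xy = multi_bandwidth_gaussian(x, y, sigma=np.array([.5, 1., 5.]))  # shape (4, 6)
k_xx = multi_bandwidth_gaussian(x, x, sigma=np.array([.5, 1., 5.]))  # diagonal equals 3
sigma_median = median_sigma(x, y)
```

Summing over several bandwidths makes the test less sensitive to a single badly chosen sigma: small bandwidths react to local differences and large ones to global shifts.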

Detect Drift

We detect data drift by simply calling predict on a batch of instances X. Setting return_p_val to True returns the p-value and the threshold of the permutation test, while setting return_distance to True returns the maximum mean discrepancy metric and its threshold.

The prediction takes the form of a dictionary with meta and data keys. meta contains the detector's metadata, while data is itself a dictionary containing the actual predictions under the following keys:

  • is_drift: 1 if the sample tested has drifted from the reference data and 0 otherwise.

  • p_val: contains the p-value if return_p_val equals True.

  • threshold: p-value threshold if return_p_val equals True.

  • distance: MMD metric between the reference data and the new batch if return_distance equals True.

  • distance_threshold: MMD metric value from the permutation test which corresponds to the p-value threshold, if return_distance equals True.
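How the returned keys relate to each other can be sketched as follows, assuming the permutation test yields an observed \(MMD^2\) value and an array of \(MMD^2\) values recomputed on shuffled data. `permutation_decision` is a hypothetical helper that mirrors the keys listed above; it is not the detector's predict method, and the quantile convention used for distance_threshold is an assumption.

```python
import numpy as np


def permutation_decision(distance: float, permuted: np.ndarray, p_val: float = 0.05) -> dict:
    """Build a drift decision from an observed MMD^2 and its permutation null distribution."""
    permuted = np.asarray(permuted, dtype=float)
    # p-value: share of permuted statistics at least as large as the observed one
    p = float(np.mean(permuted >= distance))
    # MMD^2 value that only a fraction p_val of the permuted statistics exceed
    distance_threshold = float(np.quantile(permuted, 1.0 - p_val))
    return {
        'is_drift': int(p < p_val),
        'p_val': p,
        'threshold': p_val,
        'distance': distance,
        'distance_threshold': distance_threshold,
    }


permuted = np.linspace(0.0, 1.0, 101)  # stand-in for MMD^2 values on shuffled data
drift = permutation_decision(0.985, permuted)     # p = 2/101, flagged as drift
no_drift = permutation_decision(0.505, permuted)  # p = 50/101, no drift
```

Comparing the p-value with threshold and comparing distance with distance_threshold are two views of the same decision: the observed MMD lies beyond distance_threshold roughly when its p-value falls below threshold.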

preds_drift = cd.predict(X, return_p_val=True, return_distance=True)

Saving and loading

The drift detectors can be saved and loaded in the same way as other detectors when using the built-in preprocessing steps (alibi_detect.cd.preprocess.UAE and alibi_detect.cd.preprocess.HiddenOutput) or no preprocessing at all:

from alibi_detect.utils.saving import save_detector, load_detector

filepath = 'my_path'
save_detector(cd, filepath)
cd = load_detector(filepath)

When the detector uses a custom preprocessing step, it can be passed to load_detector via preprocess_kwargs as follows:

cd = load_detector(filepath, **{'preprocess_kwargs': preprocess_kwargs})