This page was generated from examples/cd_online_camelyon.ipynb.

Online drift detection for the Camelyon17 medical imaging dataset

This notebook demonstrates a typical workflow for applying online drift detectors to streams of image data. Readers unfamiliar with how the online drift detectors in alibi_detect operate may want to start with the more introductory example, Online Drift Detection on the Wine Quality Dataset.

This notebook requires the wilds, torch and torchvision packages, which can be installed via pip:

[ ]:
!pip install wilds torch torchvision
[1]:
from typing import Tuple, Generator, Callable, Optional
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
import torchvision.transforms as transforms
from wilds.common.data_loaders import get_train_loader
from wilds import get_dataset

torch.manual_seed(0)
np.random.seed(0)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
%matplotlib inline

Data

We will use the Camelyon17 dataset, one of the WILDS datasets of Koh et al. (2020) that represent “in-the-wild” distribution shifts for various data modalities. It contains tissue scans to be classified as benign or cancerous. The pre-change distribution corresponds to scans from three hospitals and the post-change distribution corresponds to scans from a new fourth hospital.


Koh et al. (2020) show that models trained on scans from the pre-change distribution achieve an accuracy of 93.2% on unseen scans from the same distribution, but only 70.3% on scans from the post-change distribution.

First we create a function that converts the Camelyon dataset to a stream in order to simulate a live deployment environment. We extract N instances to act as the reference set on which a model of interest was trained. We then consider a stream of images from the pre-change (same) distribution and a stream of images from the post-change (drifted) distribution.

[2]:
WILDS_PATH = './data/wilds'
DOWNLOAD = False  # set to True for first run
N = 2000  # size of reference set

The following cell will download the Camelyon dataset (if DOWNLOAD=True). The download size is ~10GB and size on disk is ~15GB.

[3]:
def stream_camelyon(
    split: str = 'train',
    img_size: Tuple[int, int] = (96, 96),
    root_dir: Optional[str] = None,
    download: bool = False
) -> Generator[np.ndarray, None, None]:
    """Yield single images from a Camelyon17 split indefinitely, as numpy arrays."""

    camelyon = get_dataset('camelyon17', root_dir=root_dir, download=download)
    ds = camelyon.get_subset(
        split,
        transform=transforms.Compose([transforms.Resize(img_size), transforms.ToTensor()])
    )
    ds_iter = iter(get_train_loader('standard', ds, batch_size=1))

    while True:
        try:
            img = next(ds_iter)[0][0]  # first element of the batch is the image tensor
        except StopIteration:
            # Loader exhausted: restart it so that the stream never ends
            ds_iter = iter(get_train_loader('standard', ds, batch_size=1))
            img = next(ds_iter)[0][0]
        yield img.numpy()

stream_p = stream_camelyon(split='train', root_dir=WILDS_PATH, download=DOWNLOAD)
x_ref = np.stack([next(stream_p) for _ in range(N)], axis=0)

stream_q_h0 = stream_camelyon(split='id_val', root_dir=WILDS_PATH, download=DOWNLOAD)
stream_q_h1 = stream_camelyon(split='test', root_dir=WILDS_PATH, download=DOWNLOAD)
Downloading dataset to ./data/wilds/camelyon17_v1.0...
You can also download the dataset manually at https://wilds.stanford.edu/downloads.
Downloading https://worksheets.codalab.org/rest/bundles/0xe45e15f39fb54e9d9e919556af67aabe/contents/blob/ to ./data/wilds/camelyon17_v1.0/archive.tar.gz

Shown below are samples from the pre-change distribution:

[5]:
fig, axs = plt.subplots(nrows=1, ncols=6, figsize=(15,4))
for i in range(6):
    axs[i].imshow(np.transpose(next(stream_p), (1,2,0)))
    axs[i].axis('off')
[Figure: six sample scans from the pre-change distribution]

And samples from the post-change distribution:

[6]:
fig, axs = plt.subplots(nrows=1, ncols=6, figsize=(15,4))
for i in range(6):
    axs[i].imshow(np.transpose(next(stream_q_h1), (1,2,0)))
    axs[i].axis('off')
[Figure: six sample scans from the post-change distribution]

Kernel Projection

The images are of dimension 96x96x3. We train an autoencoder to define a more structured, lower-dimensional representation space. This projection can be thought of as an extension of the kernel. Any trained preprocessing component must be fitted on a split of the data that does not then form part of the reference data passed to the drift detector.

[7]:
ENC_DIM = 32
BATCH_SIZE = 32
EPOCHS = 5
LEARNING_RATE = 1e-3
[8]:
encoder = nn.Sequential(
    nn.Conv2d(3, 8, 5, stride=3, padding=1),    # [batch, 8, 32, 32]
    nn.ReLU(),
    nn.Conv2d(8, 12, 4, stride=2, padding=1),   # [batch, 12, 16, 16]
    nn.ReLU(),
    nn.Conv2d(12, 16, 4, stride=2, padding=1),   # [batch, 16, 8, 8]
    nn.ReLU(),
    nn.Conv2d(16, 20, 4, stride=2, padding=1),   # [batch, 20, 4, 4]
    nn.ReLU(),
    nn.Conv2d(20, ENC_DIM, 4, stride=1, padding=0),   # [batch, enc_dim, 1, 1]
    nn.Flatten(),
)
decoder = nn.Sequential(
    nn.Unflatten(1, (ENC_DIM, 1, 1)),
    nn.ConvTranspose2d(ENC_DIM, 20, 4, stride=1, padding=0),  # [batch, 20, 4, 4]
    nn.ReLU(),
    nn.ConvTranspose2d(20, 16, 4, stride=2, padding=1),  # [batch, 16, 8, 8]
    nn.ReLU(),
    nn.ConvTranspose2d(16, 12, 4, stride=2, padding=1),  # [batch, 12, 16, 16]
    nn.ReLU(),
    nn.ConvTranspose2d(12, 8, 4, stride=2, padding=1),  # [batch, 8, 32, 32]
    nn.ReLU(),
    nn.ConvTranspose2d(8, 3, 5, stride=3, padding=1),   # [batch, 3, 96, 96]
    nn.Sigmoid(),
)
ae = nn.Sequential(encoder, decoder).to(device)

x_fit, x_ref = np.split(x_ref, [len(x_ref)//2])
x_fit = torch.as_tensor(x_fit)
x_fit_dl = DataLoader(TensorDataset(x_fit, x_fit), BATCH_SIZE, shuffle=True)
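
The shape comments in the encoder and decoder above can be checked by hand with the standard output-size formulas for convolutions and transposed convolutions. The sketch below is plain Python and assumes the default dilation of 1 and no output padding, as in the layers above; the helper names conv_out and conv_transpose_out are illustrative and not part of any library.

```python
def conv_out(size: int, kernel: int, stride: int, padding: int) -> int:
    """Spatial output size of a convolution (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


def conv_transpose_out(size: int, kernel: int, stride: int, padding: int) -> int:
    """Spatial output size of a transposed convolution (dilation 1, no output padding)."""
    return (size - 1) * stride - 2 * padding + kernel


# (kernel, stride, padding) for the five encoder layers, in order
encoder_layers = [(5, 3, 1), (4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 1, 0)]

size = 96  # input images are 96x96
encoder_sizes = []
for kernel, stride, padding in encoder_layers:
    size = conv_out(size, kernel, stride, padding)
    encoder_sizes.append(size)

# The decoder mirrors the encoder, so it uses the same settings in reverse order
decoder_sizes = []
for kernel, stride, padding in reversed(encoder_layers):
    size = conv_transpose_out(size, kernel, stride, padding)
    decoder_sizes.append(size)

print(encoder_sizes)  # [32, 16, 8, 4, 1]
print(decoder_sizes)  # [4, 8, 16, 32, 96]
```

The final encoder size of 1 is what allows nn.Flatten() to produce a vector of length ENC_DIM per image.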

We can train the autoencoder using a helper function provided for convenience in alibi-detect.

[9]:
from alibi_detect.models.pytorch import trainer

trainer(ae, nn.MSELoss(), x_fit_dl, device, learning_rate=LEARNING_RATE, epochs=EPOCHS)
Epoch 1/5: 100%|██████████| 32/32 [00:00<00:00, 57.14it/s, loss=0.0797]
Epoch 2/5: 100%|██████████| 32/32 [00:00<00:00, 126.92it/s, loss=0.0334]
Epoch 3/5: 100%|██████████| 32/32 [00:00<00:00, 127.28it/s, loss=0.0234]
Epoch 4/5: 100%|██████████| 32/32 [00:00<00:00, 127.19it/s, loss=0.0182]
Epoch 5/5: 100%|██████████| 32/32 [00:00<00:00, 128.28it/s, loss=0.0164]

The preprocessing/projection functions are expected to map numpy arrays to numpy arrays, so we wrap the encoder in the function below.

[10]:
def encoder_fn(x: np.ndarray) -> np.ndarray:
    x = torch.as_tensor(x).to(device)
    with torch.no_grad():
        x_proj = encoder(x)
    return x_proj.cpu().numpy()

Drift Detection

alibi-detect’s online drift detectors window the stream of data in an ‘overlapping window’ manner, such that a test is performed at every time step. We will use an estimator of the maximum mean discrepancy (MMD) between the reference data and the current test window as the test statistic. The estimate is updated incrementally at low cost. The thresholds are configured via simulation in an initial configuration phase to target the desired expected runtime (ERT) in the absence of change. For a detailed description of this calibration procedure see Cobb et al. (2021).
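
To give a feel for the test statistic, here is a minimal numpy sketch of an unbiased squared-MMD estimate with a Gaussian RBF kernel, comparing a reference sample against a test window of size 20. It is illustrative only and is not how alibi-detect computes or incrementally updates the statistic: the synthetic 32-dimensional Gaussian data, the median-heuristic bandwidth and the helper names sq_dists and mmd2 are all assumptions made for this example.

```python
import numpy as np


def sq_dists(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Pairwise squared Euclidean distances between the rows of a and the rows of b."""
    d = (a ** 2).sum(1)[:, None] + (b ** 2).sum(1)[None, :] - 2.0 * a @ b.T
    return np.maximum(d, 0.0)  # clip tiny negatives caused by rounding


def mmd2(x: np.ndarray, y: np.ndarray, sigma: float) -> float:
    """Unbiased estimate of the squared MMD between samples x and y (RBF kernel)."""
    n, m = len(x), len(y)
    k_xx = np.exp(-sq_dists(x, x) / (2 * sigma ** 2))
    k_yy = np.exp(-sq_dists(y, y) / (2 * sigma ** 2))
    k_xy = np.exp(-sq_dists(x, y) / (2 * sigma ** 2))
    # Drop the diagonal (self-similarity) terms from the within-sample averages
    xx = (k_xx.sum() - np.trace(k_xx)) / (n * (n - 1))
    yy = (k_yy.sum() - np.trace(k_yy)) / (m * (m - 1))
    return float(xx + yy - 2.0 * k_xy.mean())


rng = np.random.default_rng(0)
ref = rng.normal(size=(500, 32))                   # stand-in for encoded reference data
window_same = rng.normal(size=(20, 32))            # window from the same distribution
window_drift = rng.normal(loc=1.0, size=(20, 32))  # window from a shifted distribution

# Median heuristic: bandwidth set to the median pairwise distance in the reference data
d_ref = sq_dists(ref, ref)
sigma = float(np.sqrt(np.median(d_ref[np.triu_indices(len(ref), k=1)])))

mmd_same = mmd2(ref, window_same, sigma)
mmd_drift = mmd2(ref, window_drift, sigma)
print(f"no drift: {mmd_same:.4f}, drift: {mmd_drift:.4f}")
```

The estimate stays near zero when the window comes from the reference distribution and grows when it does not; the online detector flags drift once its statistic exceeds the calibrated threshold for that time step.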

[11]:
ERT = 150  # expected run-time in absence of change
W = 20  # size of test window
B = 50_000  # number of simulations to configure threshold (defined but not passed to the detector below)
[12]:
from alibi_detect.cd import MMDDriftOnline

dd = MMDDriftOnline(x_ref, ERT, W, backend='pytorch', preprocess_fn=encoder_fn)
Generating permutations of kernel matrix..
100%|██████████| 1000/1000 [00:00<00:00, 16177.23it/s]
Computing thresholds: 100%|██████████| 20/20 [00:02<00:00,  7.70it/s]

We define a function which will apply the detector to the streams and return the time at which drift was detected.

[13]:
def compute_runtime(detector: Callable, stream: Generator) -> int:

    t = 0
    detector.reset()
    detected = False

    while not detected:
        t += 1
        z = next(stream)
        pred = detector.predict(z)
        detected = pred['data']['is_drift']
    print(t)
    return t
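
The function above relies on only two parts of the detector's interface: a reset() method and a predict() method that returns a dictionary with an is_drift flag under 'data'. The sketch below exercises the same loop without any dataset, using a hypothetical stand-in detector (ToyThresholdDetector is not an alibi-detect class) and a variant of the loop, runtime_with_cap, with an assumed cap on the number of steps so that it cannot run forever on a stream that never triggers a detection.

```python
from itertools import cycle
from typing import Any, Dict, Iterator


class ToyThresholdDetector:
    """Hypothetical detector: flags drift when an observation exceeds a threshold."""

    def __init__(self, threshold: float) -> None:
        self.threshold = threshold
        self.t = 0

    def reset(self) -> None:
        self.t = 0

    def predict(self, x: float) -> Dict[str, Any]:
        self.t += 1
        return {'data': {'is_drift': int(x > self.threshold), 'time': self.t}}


def runtime_with_cap(detector: Any, stream: Iterator, max_steps: int = 10_000) -> int:
    """Return the number of steps until the detector first flags drift."""
    detector.reset()
    for t in range(1, max_steps + 1):
        if detector.predict(next(stream))['data']['is_drift']:
            return t
    raise RuntimeError(f"no drift detected within {max_steps} steps")


stream = cycle([0.1, 0.2, 0.3, 0.9])  # endless toy stream
toy = ToyThresholdDetector(threshold=0.5)
first = runtime_with_cap(toy, stream)
second = runtime_with_cap(toy, stream)  # the stream carries on where it left off
print(first, second)  # 4 4
```

As in the notebook, the stream is not rewound between runs; only the detector state is reset.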

First we apply the detector multiple times to the pre-change stream where the distribution is unchanged.

[14]:
times_h0 = [compute_runtime(dd, stream_p) for i in range(15)]
print(f"Average runtime in absence of change: {np.array(times_h0).mean()}")
575
134
328
35
28
166
79
158
33
298
54
122
28
216
80
Average runtime in absence of change: 155.6

We see that the average runtime in the absence of change is close to the desired ERT, as expected. We can inspect the detector’s test_stats and thresholds properties to see how the test statistic varied over time and how close it got to exceeding the threshold.

[15]:
ts = np.arange(dd.t)
plt.plot(ts, dd.test_stats, label='Test statistic')
plt.plot(ts, dd.thresholds, label='Thresholds')
plt.xlabel('t', fontsize=16)
plt.ylabel('$T_t$', fontsize=16)
plt.legend(loc='upper right', fontsize=14)
plt.show()
[Figure: test statistic and thresholds against time step t for the last run on the pre-change stream]
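
The individual run times printed above range from 28 to 575, which may look surprising for a target ERT of 150. As a rough illustration, assume the calibrated detector has a constant probability of 1/ERT of raising a false alarm at each step, so that run times are geometrically distributed (an idealisation made for this sketch, not a statement about the detector's internals). The simulation below shows how much spread to expect, both in single run times and in an average over 15 runs.

```python
import numpy as np

ERT = 150
N_RUNS = 15          # runs per average, as in the cell above
N_REPEATS = 10_000   # number of simulated 15-run experiments

rng = np.random.default_rng(0)
# Geometric run times: number of steps until the first false alarm
run_times = rng.geometric(p=1 / ERT, size=N_RUNS * N_REPEATS)

print(f"mean run time: {run_times.mean():.1f}")
print(f"std of a single run time: {run_times.std():.1f}")

# Spread of the average over 15 runs
batch_means = run_times.reshape(N_REPEATS, N_RUNS).mean(axis=1)
print(f"std of a 15-run average: {batch_means.std():.1f}")
```

Under this assumption a single run time has a standard deviation close to the ERT itself, so averages over only 15 runs can easily differ from 150 by a few tens of steps.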

Now we apply it to the post-change stream where the images are from a drifted distribution.

[16]:
times_h1 = [compute_runtime(dd, stream_q_h1) for i in range(15)]
print(f"Average detection delay following change: {np.array(times_h1).mean()}")
15
17
12
20
5
15
12
12
20
9
11
6
12
14
11
Average detection delay following change: 12.733333333333333

We see that the detector is quick to flag drift when it has occurred.

[17]:
ts = np.arange(dd.t)
plt.plot(ts, dd.test_stats, label='Test statistic')
plt.plot(ts, dd.thresholds, label='Thresholds')
plt.xlabel('t', fontsize=16)
plt.ylabel('$T_t$', fontsize=16)
plt.legend(loc='upper right', fontsize=14)
plt.show()
[Figure: test statistic and thresholds against time step t for the last run on the post-change stream]