GitHub Repository: probml/pyprobml
Path: blob/master/notebooks/book2/34/batch_bald_mnist_pytorch.ipynb
Kernel: Python 3.7.13 ('colab': conda)
import math
import matplotlib.pyplot as plt
import seaborn as sns
import random
import numpy as np
import os
from dataclasses import dataclass

%matplotlib inline

try:
    import torch
except ModuleNotFoundError:
    %pip install -qq torch
    import torch

from torch import nn as nn
from torch.nn import functional as F

try:
    from tqdm.notebook import tqdm
except ModuleNotFoundError:
    %pip install -qq tqdm
    from tqdm.notebook import tqdm

try:
    from torchvision import datasets, transforms
except ModuleNotFoundError:
    %pip install -qq torchvision
    from torchvision import datasets, transforms

try:
    from batchbald_redux import (
        active_learning,
        batchbald,
        consistent_mc_dropout,
        joint_entropy,
        repeated_mnist,
    )
except ModuleNotFoundError:
    %pip install -qq batchbald_redux
    from batchbald_redux import (
        active_learning,
        batchbald,
        consistent_mc_dropout,
        joint_entropy,
        repeated_mnist,
    )

try:
    import probml_utils as pml
    from probml_utils import savefig, latexify
except ModuleNotFoundError:
    %pip install -qq git+https://github.com/probml/probml-utils.git
    import probml_utils as pml
    from probml_utils import savefig, latexify
use_cuda = torch.cuda.is_available()
print(f"use_cuda: {use_cuda}")

device = "cuda" if use_cuda else "cpu"
kwargs = {"num_workers": 0, "pin_memory": True} if use_cuda else {}
class BayesianCNN(consistent_mc_dropout.BayesianModule):
    def __init__(self, num_classes=10):
        super().__init__()

        self.conv1 = nn.Conv2d(1, 32, kernel_size=5)
        self.conv1_drop = consistent_mc_dropout.ConsistentMCDropout2d()
        self.conv2 = nn.Conv2d(32, 64, kernel_size=5)
        self.conv2_drop = consistent_mc_dropout.ConsistentMCDropout2d()
        self.fc1 = nn.Linear(1024, 128)
        self.fc1_drop = consistent_mc_dropout.ConsistentMCDropout()
        self.fc2 = nn.Linear(128, num_classes)

    def mc_forward_impl(self, input: torch.Tensor):
        input = F.relu(F.max_pool2d(self.conv1_drop(self.conv1(input)), 2))
        input = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(input)), 2))
        input = input.view(-1, 1024)
        input = F.relu(self.fc1_drop(self.fc1(input)))
        input = self.fc2(input)
        input = F.log_softmax(input, dim=1)
        return input
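As a quick sanity check (a hypothetical snippet, not part of the original notebook; it assumes the `model(input, k)` call convention of `consistent_mc_dropout.BayesianModule` that the training loop below relies on, where the second argument is the number of consistent MC-dropout samples), the model maps a batch of MNIST-sized inputs to per-sample class log-probabilities of shape (batch, k, num_classes):

# Hypothetical shape check: one log-probability vector per consistent
# MC-dropout sample, i.e. output shape (batch, k, num_classes).
_check_model = BayesianCNN().to(device=device)
_dummy_x = torch.zeros(2, 1, 28, 28, device=device)  # two blank MNIST-sized images
with torch.no_grad():
    _dummy_out = _check_model(_dummy_x, 4)  # k = 4 dropout samples per input
print(_dummy_out.shape)  # expected: torch.Size([2, 4, 10])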
@dataclass
class CandidateBatch:
    scores: list
    indices: list


def get_random(
    log_probs_N_K_C: torch.Tensor,
    batch_size: int,
    num_samples: int,
    dtype=None,
    device=None,
) -> CandidateBatch:
    # Random-acquisition baseline: ignores the model's predictions and simply
    # draws `batch_size` pool indices uniformly at random.
    N, K, C = log_probs_N_K_C.shape

    batch_size = min(batch_size, N)

    candidate_indices = []
    candidate_scores = []

    if batch_size == 0:
        return CandidateBatch(candidate_scores, candidate_indices)

    # We always keep these on the CPU. The scores are placeholders only;
    # random acquisition does not use them to rank pool points.
    scores_N = torch.empty(N, dtype=torch.double, pin_memory=torch.cuda.is_available())

    picked_indices = torch.randperm(N)[:batch_size].numpy()

    candidate_score, candidate_index = scores_N.max(dim=0)

    candidate_indices.append(picked_indices)
    candidate_scores.append(candidate_score.item())

    return CandidateBatch(candidate_scores, candidate_indices)
batch_list = [4, 8, 16, 32]
algo_list = ["bald", "batchbald"]
final_test_accs = []
final_indices = []

max_training_samples = 180  # Maximum limit of train samples needed
num_inference_samples = 100
num_test_inference_samples = 5
num_samples = 100000  # Total number of samples
test_batch_size = 512  # Test loader batch size
batch_size = 64  # Train loader batch size
scoring_batch_size = 128  # Pool loader batch size
training_iterations = 4096 * 6
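For context (these are the standard definitions from Houlsby et al., 2011 and Kirsch et al., 2019, stated here as a reference rather than taken from the notebook), `get_bald_batch` scores each pool point $x$ by the mutual information between its label $y$ and the model parameters $\omega$, while `get_batchbald_batch` scores a candidate batch jointly and builds it greedily:

$$a_{\text{BALD}}(x) = \mathbb{I}[y;\omega \mid x,\mathcal{D}] = \mathbb{H}[y \mid x,\mathcal{D}] - \mathbb{E}_{p(\omega\mid\mathcal{D})}\big[\mathbb{H}[y \mid x,\omega]\big]$$

$$a_{\text{BatchBALD}}(x_{1:b}) = \mathbb{I}[y_{1:b};\omega \mid x_{1:b},\mathcal{D}] = \mathbb{H}[y_1,\dots,y_b \mid x_{1:b},\mathcal{D}] - \mathbb{E}_{p(\omega\mid\mathcal{D})}\Big[\sum_{i=1}^{b}\mathbb{H}[y_i \mid x_i,\omega]\Big]$$

Both expectations are approximated with the `num_inference_samples` consistent MC-dropout samples that the loop below stores in `logits_N_K_C`.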
for algo in algo_list:
    if algo == "bald":
        print("******************************************BALD Implementation******************************************")
    else:
        print(
            "******************************************BatchBALD Implementation******************************************"
        )
    for acquisition_batch_size in batch_list:  # Batch size per acquisition iteration
        print(
            "******************************************Batch Size: "
            + str(acquisition_batch_size)
            + "******************************************"
        )
        seed_value = 0
        torch.manual_seed(seed_value)
        torch.cuda.manual_seed(seed_value)
        torch.backends.cudnn.deterministic = True
        torch.cuda.manual_seed_all(seed_value)
        random.seed(seed_value)
        np.random.seed(seed_value)
        os.environ["PYTHONHASHSEED"] = str(seed_value)

        num_initial_samples = 20  # Number of initial samples required
        num_classes = 10  # Total classes in MNIST dataset

        train_dataset, test_dataset = repeated_mnist.create_repeated_MNIST_dataset(num_repetitions=1, add_noise=False)

        # Generates 20 samples (2 from each class) and returns their indices
        initial_samples = active_learning.get_balanced_sample_indices(
            repeated_mnist.get_targets(train_dataset),
            num_classes=num_classes,
            n_per_digit=num_initial_samples / num_classes,
        )

        test_accs = []
        test_loss = []
        added_indices = []

        active_learning_data = active_learning.ActiveLearningData(
            train_dataset
        )  # Splits the dataset into a training dataset and a pool dataset

        active_learning_data.acquire(
            initial_samples
        )  # Separates the initial indices from the pool and fixes them as the initial train dataset

        active_learning_data.extract_dataset_from_pool(
            40000
        )  # Extracts 40000 samples from the pool and uses them as a validation dataset

        train_loader = torch.utils.data.DataLoader(
            active_learning_data.training_dataset,
            sampler=active_learning.RandomFixedLengthSampler(
                active_learning_data.training_dataset, training_iterations
            ),
            batch_size=batch_size,
            **kwargs,
        )

        pool_loader = torch.utils.data.DataLoader(
            active_learning_data.pool_dataset,
            batch_size=scoring_batch_size,
            shuffle=False,
            **kwargs,
        )

        test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=test_batch_size, shuffle=False, **kwargs)

        pbar = tqdm(
            initial=len(active_learning_data.training_dataset),
            total=max_training_samples,
            desc="Training Set Size",
        )

        while True:
            model = BayesianCNN(num_classes).to(device=device)  # initialise model
            optimizer = torch.optim.Adam(model.parameters())

            model.train()

            # Train
            for data, target in tqdm(train_loader, desc="Training", leave=False):
                data = data.to(device=device)
                target = target.to(device=device)

                optimizer.zero_grad()

                prediction = model(data, 1).squeeze(1)
                loss = F.nll_loss(prediction, target)

                loss.backward()
                optimizer.step()

            # Test
            loss = 0
            correct = 0
            with torch.no_grad():
                for data, target in tqdm(test_loader, desc="Testing", leave=False):
                    data = data.to(device=device)
                    target = target.to(device=device)

                    prediction = torch.logsumexp(model(data, num_test_inference_samples), dim=1) - math.log(
                        num_test_inference_samples
                    )
                    loss += F.nll_loss(prediction, target, reduction="sum")

                    prediction = prediction.max(1)[1]
                    correct += prediction.eq(target.view_as(prediction)).sum().item()

            loss /= len(test_loader.dataset)
            test_loss.append(loss)

            percentage_correct = 100.0 * correct / len(test_loader.dataset)
            test_accs.append(percentage_correct)

            print("Test set: Average loss: {:.4f}, Accuracy: ({:.2f}%)".format(loss, percentage_correct))

            if len(active_learning_data.training_dataset) >= max_training_samples:
                break

            # Acquire pool predictions
            N = len(active_learning_data.pool_dataset)
            logits_N_K_C = torch.empty(
                (N, num_inference_samples, num_classes),
                dtype=torch.double,
                pin_memory=use_cuda,
            )

            with torch.no_grad():
                model.eval()

                for i, (data, _) in enumerate(tqdm(pool_loader, desc="Evaluating Acquisition Set", leave=False)):
                    data = data.to(device=device)

                    lower = i * pool_loader.batch_size
                    upper = min(lower + pool_loader.batch_size, N)
                    logits_N_K_C[lower:upper].copy_(model(data, num_inference_samples).double(), non_blocking=True)

            with torch.no_grad():
                # Returns the indices and scores (mutual information) for the batch
                # selected by the BatchBALD/BALD strategy.
                if algo == "batchbald":
                    candidate_batch = batchbald.get_batchbald_batch(
                        logits_N_K_C,
                        acquisition_batch_size,
                        num_samples,
                        dtype=torch.double,
                        device=device,
                    )
                elif algo == "bald":
                    candidate_batch = batchbald.get_bald_batch(
                        logits_N_K_C,
                        acquisition_batch_size,
                        dtype=torch.double,
                        device=device,
                    )

            targets = repeated_mnist.get_targets(active_learning_data.pool_dataset)  # Returns the target labels
            dataset_indices = active_learning_data.get_dataset_indices(
                candidate_batch.indices
            )  # Returns dataset indices for the candidate batch

            print("Dataset indices: ", dataset_indices)
            # print("Scores: ", candidate_batch.scores)
            print("Labels: ", targets[candidate_batch.indices])

            active_learning_data.acquire(candidate_batch.indices)  # add the new indices to the training dataset
            added_indices.append(dataset_indices)
            pbar.update(len(dataset_indices))

        final_test_accs.append(test_accs)
        final_indices.append(added_indices)
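To connect the library calls above to the BALD formula, here is a minimal, hypothetical re-implementation of the per-point score from an N x K x C tensor of log-probabilities (an illustrative sketch only; the experiment itself uses batchbald_redux's get_bald_batch and get_batchbald_batch):

# Sketch: per-point BALD score = entropy of the mean prediction minus the
# mean entropy over dropout samples. Not used by the experiment above.
def bald_scores_sketch(log_probs_N_K_C: torch.Tensor) -> torch.Tensor:
    probs_N_K_C = log_probs_N_K_C.exp()
    mean_probs_N_C = probs_N_K_C.mean(dim=1)
    # H[y | x, D]: entropy of the averaged (posterior-predictive) distribution
    entropy_of_mean_N = -torch.xlogy(mean_probs_N_C, mean_probs_N_C).sum(dim=1)
    # E_w[H[y | x, w]]: expected entropy over the K dropout samples
    expected_entropy_N = -torch.xlogy(probs_N_K_C, probs_N_K_C).sum(dim=2).mean(dim=1)
    return entropy_of_mean_N - expected_entropy_N


# Example (uses the pool predictions left over from the last acquisition round):
# print(bald_scores_sketch(logits_N_K_C).topk(10).indices)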
print(final_test_accs)
latexify(width_scale_factor=2, fig_height=2)

BALD Test Accuracy Curve

p = plt.rcParams
p["axes.grid"] = True
p["grid.color"] = "#999999"
p["grid.linestyle"] = "--"
p["lines.linewidth"] = 2

plt.plot(np.arange(0, 132, 4), final_test_accs[0][:33], label="4")
plt.plot(np.arange(0, 132, 8), final_test_accs[1][:17], label="8")
plt.plot(np.arange(0, 132, 16), final_test_accs[2][:9], label="16")
plt.plot(np.arange(0, 132, 32), final_test_accs[3][:-1], label="32")
plt.legend(title="Batch-Size", loc="lower right")
plt.xlabel("No. of Points Queried", fontsize=9)
plt.ylabel("Test Accuracy", fontsize=9)
plt.xticks([i for i in range(0, 132, 16)], rotation=90)
plt.tight_layout()
sns.despine()
savefig("test_accuracy_bald")
plt.show()

BatchBALD Test Accuracy Curve

p = plt.rcParams
p["axes.grid"] = True
p["grid.color"] = "#999999"
p["grid.linestyle"] = "--"
p["lines.linewidth"] = 2

plt.plot(np.arange(0, 132, 4), final_test_accs[4][:33], label="4")
plt.plot(np.arange(0, 132, 8), final_test_accs[5][:17], label="8")
plt.plot(np.arange(0, 132, 16), final_test_accs[6][:9], label="16")
plt.plot(np.arange(0, 132, 32), final_test_accs[7][:-1], label="32")
plt.legend(title="Batch-Size", loc="lower right")
plt.xlabel("No. of Points Queried", fontsize=9)
plt.ylabel("Test Accuracy", fontsize=9)
plt.xticks([i for i in range(0, 132, 16)], rotation=90)
plt.tight_layout()
sns.despine()
savefig("test_accuracy_batchbald")
plt.show()
latexify(width_scale_factor=3, fig_height=2)

BALD Samples

rows = ["Batch {}".format(row) for row in [1, 2, 3]] plt.rcParams["axes.titlesize"] = 10 fig, axes = plt.subplots(nrows=3, ncols=4) plot_indices = [final_indices[0][i][j] for i in range(4, 7) for j in range(0, 4)] for i, ax in zip(range(1, 4 * 3 + 1), axes.flatten()): image = train_dataset[plot_indices[i - 1]][0].reshape((28, 28)) ax.imshow(image, cmap="gray") ax.grid(False) ax.tick_params( axis="both", labelsize=0, length=0, left=False, bottom=False, labelleft=False, labelbottom=False, ) for ax, row in zip(axes[:, 0], rows): ax.set_ylabel(row, rotation=90, size="large") fig.subplots_adjust(wspace=0, hspace=0) savefig("bald_samples") plt.show()

BatchBALD Samples

rows = ["Batch {}".format(row) for row in [1, 2, 3]] plt.rcParams["axes.titlesize"] = 10 fig, axes = plt.subplots(nrows=3, ncols=4) plot_indices = [final_indices[4][i][j] for i in range(0, 3) for j in range(0, 4)] for i, ax in zip(range(1, 4 * 3 + 1), axes.flatten()): image = train_dataset[plot_indices[i - 1]][0].reshape((28, 28)) ax.imshow(image, cmap="gray") ax.grid(False) ax.tick_params( axis="both", labelsize=0, length=0, left=False, bottom=False, labelleft=False, labelbottom=False, ) for ax, row in zip(axes[:, 0], rows): ax.set_ylabel(row, rotation=90, size="large") fig.subplots_adjust(wspace=0, hspace=0) savefig("batchbald_samples") plt.show()

Random Acquisition

final_random_accs = []

for seed in range(5):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    random.seed(seed)
    np.random.seed(seed)

    max_training_samples = 148  # Maximum limit of train samples needed
    acquisition_batch_size = 4  # Batch size per acquisition iteration

    labels_list = []
    select_indices = []

    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=test_batch_size, shuffle=False, **kwargs)

    active_learning_data = active_learning.ActiveLearningData(train_dataset)

    # Split off the initial samples first.
    active_learning_data.acquire(initial_samples)  # Initial train set

    # THIS REMOVES MOST OF THE POOL DATA. COMMENT THIS OUT TO TAKE ALL UNLABELLED DATA INTO ACCOUNT!
    active_learning_data.extract_dataset_from_pool(40000)  # Validation data

    train_loader = torch.utils.data.DataLoader(
        active_learning_data.training_dataset,
        sampler=active_learning.RandomFixedLengthSampler(active_learning_data.training_dataset, training_iterations),
        batch_size=batch_size,
        **kwargs,
    )

    pool_loader = torch.utils.data.DataLoader(
        active_learning_data.pool_dataset,
        batch_size=scoring_batch_size,
        shuffle=False,
        **kwargs,
    )

    # Run experiment
    test_accs = []
    test_loss = []
    added_indices = []

    pbar = tqdm(
        initial=len(active_learning_data.training_dataset),
        total=max_training_samples,
        desc="Training Set Size",
    )

    while True:
        model = BayesianCNN(num_classes).to(device=device)
        optimizer = torch.optim.Adam(model.parameters())

        model.train()

        # Train
        for data, target in tqdm(train_loader, desc="Training", leave=False):
            data = data.to(device=device)
            target = target.to(device=device)

            optimizer.zero_grad()

            prediction = model(data, 1).squeeze(1)
            loss = F.nll_loss(prediction, target)

            loss.backward()
            optimizer.step()

        # Test
        loss = 0
        correct = 0
        with torch.no_grad():
            for data, target in tqdm(test_loader, desc="Testing", leave=False):
                data = data.to(device=device)
                target = target.to(device=device)

                prediction = torch.logsumexp(model(data, num_test_inference_samples), dim=1) - math.log(
                    num_test_inference_samples
                )
                loss += F.nll_loss(prediction, target, reduction="sum")

                prediction = prediction.max(1)[1]
                correct += prediction.eq(target.view_as(prediction)).sum().item()

        loss /= len(test_loader.dataset)
        test_loss.append(loss)

        percentage_correct = 100.0 * correct / len(test_loader.dataset)
        test_accs.append(percentage_correct)

        print(
            "Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)".format(
                loss, correct, len(test_loader.dataset), percentage_correct
            )
        )

        if len(active_learning_data.training_dataset) >= max_training_samples:
            break

        # Acquire pool predictions
        N = len(active_learning_data.pool_dataset)
        logits_N_K_C = torch.empty(
            (N, num_inference_samples, num_classes),
            dtype=torch.double,
            pin_memory=use_cuda,
        )

        with torch.no_grad():
            model.eval()

            for i, (data, _) in enumerate(tqdm(pool_loader, desc="Evaluating Acquisition Set", leave=False)):
                data = data.to(device=device)

                lower = i * pool_loader.batch_size
                upper = min(lower + pool_loader.batch_size, N)
                logits_N_K_C[lower:upper].copy_(model(data, num_inference_samples).double(), non_blocking=True)

        with torch.no_grad():
            candidate_batch = get_random(
                logits_N_K_C,
                acquisition_batch_size,
                num_samples,
                dtype=torch.double,
                device=device,
            )

        targets = repeated_mnist.get_targets(active_learning_data.pool_dataset)
        dataset_indices = active_learning_data.get_dataset_indices(candidate_batch.indices)

        print("Dataset indices: ", dataset_indices)
        print("Scores: ", candidate_batch.scores)
        print("Labels: ", targets[candidate_batch.indices])

        labels_list.append(targets[candidate_batch.indices])
        select_indices.append(dataset_indices)

        active_learning_data.acquire(candidate_batch.indices)
        added_indices.append(dataset_indices)
        pbar.update(len(dataset_indices))

    final_random_accs.append(test_accs)
random_scores_array = np.array(final_random_accs)
random_mean = np.mean(random_scores_array, axis=0)
random_std = np.std(random_scores_array, axis=0)
p = plt.rcParams
p["axes.grid"] = True
p["grid.color"] = "#999999"
p["grid.linestyle"] = "--"
p["lines.linewidth"] = 2

plt.plot(np.arange(0, 100, acquisition_batch_size), random_mean[:25], label="Random")
plt.fill_between(
    np.arange(0, 100, acquisition_batch_size),
    random_mean[:25] - random_std[:25],
    random_mean[:25] + random_std[:25],
    color="lightskyblue",
)
plt.plot(np.arange(0, 100, acquisition_batch_size), final_test_accs[0][:25], label="BALD")
plt.plot(
    np.arange(0, 100, acquisition_batch_size),
    final_test_accs[4][:25],
    label="BatchBALD",
)
plt.legend(loc="lower right", fontsize=7)
plt.xlabel("No. of Points Queried", fontsize=10)
plt.ylabel("Test Accuracy", fontsize=10)
plt.xticks([i for i in range(0, 100, 16)], rotation=90)
plt.yticks([i for i in range(55, 95, 5)], rotation=90)
plt.tight_layout()
sns.despine()
savefig("accuracy_comparison")
plt.show()