NN regularization

Author

Gorka Muñoz-Gil and Marcin Płodzień


1 Overfitting

As we saw when learning about polynomial regression, overfitting is a common problem in deep learning. It occurs when a model is excessively complex and manages to fit exactly each of the training set's data points, rather than learning general patterns that would allow it to generalize to new, unseen data.

We identified two main factors for the overfitting of a model:

  1. The number of parameters is too large.
  2. The number of training examples is too small.

As we build neural networks (NN), these two factors become even more important. Adding too many layers to a NN will almost certainly lead to overfitting, while training on too small a dataset will likewise cause training to fail.

2 Regularization techniques

In this notebook, we will explore a few common regularization techniques that help us deal with overfitting. In particular, regularization techniques allow for:

  1. Improved generalization: By preventing overfitting, regularization can help to improve the generalization performance of the model on new, unseen data.

  2. Increased robustness: Regularized models are often more robust to noise and other types of perturbations, as they are less sensitive to specific patterns in the training data.

  3. Simplified models: Regularization can help to reduce the complexity of the model, which can make it easier to interpret and understand.

  4. Better interpretability: Regularization can help to identify the most important features in the data, which can improve the interpretability of the model.

  5. Improved efficiency: Regularization can help to reduce the number of parameters in the model, which can make the model more efficient to train and deploy.

  6. Improved optimization: Regularization can help to stabilize the optimization process, which can lead to faster convergence and better performance.

We will now review different regularization techniques, training a different model with each. We will consider the same problem as in the previous notebook: the multiclass classification of MNIST. We first get the data from torchvision.

Code
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader, random_split
import torch
import matplotlib.pyplot as plt

torch.manual_seed(7)

# We will train on a GPU when available
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


mnist_train = MNIST(root="data", train=True, download=True, transform=ToTensor())
mnist_test = MNIST(root="data", train=False, download=True, transform=ToTensor())

train_data, validation_data = random_split(mnist_train, [55000, 5000])

batch_size = 256
train_loader = DataLoader(train_data, batch_size, shuffle=True)
val_loader = DataLoader(validation_data, batch_size, shuffle=False)
test_loader = DataLoader(mnist_test, batch_size, shuffle=False)

Let’s define our loss function, again the cross-entropy loss.

criterion = torch.nn.CrossEntropyLoss() 
# criterion is typically used in ML coding jargon interchangeably with loss function.
# It will be clear later why we didn't call this loss_func

and the accuracy

def accuracy(predictions, targets):
    """Compute the accuracy of predictions given the true targets."""
    return (predictions.argmax(dim=1) == targets).float().mean()
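
As a quick sanity check (a hypothetical example, not from the original notebook), note that this function works directly on raw logits, since taking the argmax is unaffected by the softmax:

logits = torch.tensor([[0.1, 0.2, 0.7],
                       [0.9, 0.05, 0.05]])
targets = torch.tensor([2, 1])
print(accuracy(logits, targets))  # tensor(0.5000): first prediction correct, second wrong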

We will also need a model. We will use the same one as in the previous notebook.

import torch.nn as nn
import torch.nn.functional as F

class FullyConnected(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.linear_1 = nn.Linear(input_size, hidden_size)
        self.linear_2 = nn.Linear(hidden_size, output_size, bias=False)
        
    def forward(self, x):
        x = x.reshape(-1, 28 * 28)
        z = self.linear_1(x)
        x = F.relu(z)
        z = self.linear_2(x)
        return z

Let us now create the training loop. Because we will need it a few times, we wrap it in a function:

from tqdm.auto import tqdm

def training_loop(model, loss_func, optimizer, n_epochs, train_loader, val_loader, device = 'cpu'):

    for _ in tqdm(range(n_epochs)):
        epoch_loss = 0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            logits = model(images)
            loss = loss_func(logits, labels)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
        
            epoch_loss += loss.item()
    
    with torch.no_grad():
        epoch_loss = 0
        val_preds, val_targets = [], []
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            
            logits = model(images)
            loss = loss_func(logits, labels)
        
            epoch_loss += loss.item()
            val_preds.append(F.softmax(logits, dim=1))
            val_targets.append(labels)
        
        val_acc = accuracy(torch.cat(val_preds), torch.cat(val_targets))
        validation_loss = epoch_loss / len(val_loader)
        
        print(f"Validation Loss: {validation_loss:.4f}, Accuracy: {val_acc:.4f}")
    
    return model
n_epochs = 10
model = FullyConnected(28 * 28, 500, 10).to(DEVICE)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)

trained_model = training_loop(model, criterion, optimizer, n_epochs, train_loader, val_loader, device = DEVICE)

2.1 Weight decay (L1 and L2 regularization)

We already saw this method when performing polynomial regression, although in the context of neural networks we typically refer to it as weight decay rather than weight regularization. As we saw, it involves adding a penalty to the cost function during training to discourage the model from learning excessively large weights, based on the idea that large weights can lead to overfitting by allowing the model to fit the training data too closely. The two most common variants are L1 and L2 regularization.
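
As a side note, pytorch optimizers already implement L2-style weight decay through their weight_decay argument; for plain SGD this is equivalent to adding an L2 penalty on all parameters (biases included) to the loss. A minimal sketch:

# Built-in weight decay: equivalent to an L2 penalty for plain SGD
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2, weight_decay=1e-4)

(For adaptive optimizers such as Adam, decoupled weight decay, torch.optim.AdamW, behaves slightly differently from an explicit L2 penalty.) In what follows we implement the penalties by hand instead, which makes the mechanics explicit.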

2.1.1 L1 regularization

L1 regularization, also known as Lasso regularization, adds a penalty term to the cost function that is proportional to the absolute value of the weights. The L1 regularization term has the form:

\[\begin{equation} L_1 = \lambda \sum |W| \end{equation}\] where \(\lambda\) is the regularization parameter, and \(W\) is the weight matrix.

The effect of L1 regularization is to push the weights towards zero, which can help to reduce overfitting by decreasing the complexity of the model. L1 regularization can also be useful for feature selection, as it tends to drive the weights of unimportant features to zero, effectively removing them from the model.

Now that we know pytorch, let’s use it to create our own loss function!

class LassoLoss(nn.Module):
    def __init__(self, base_criterion: nn.Module, model: nn.Module, l1_lambda: float = 1e-4):
        super().__init__()
        
        self.base = base_criterion
        
        self.params = [p for n,p in model.named_parameters() if p.requires_grad and not n.endswith(".bias") and p.ndim > 1]
        
        self.l1_lambda = l1_lambda

    def forward(self, outputs, targets):
        base_loss = self.base(outputs, targets)
        l1 = torch.zeros((), device=base_loss.device)
        for p in self.params:
            l1 = l1 + p.abs().sum()
        return base_loss + self.l1_lambda * l1
n_epochs = 10
model = FullyConnected(28 * 28, 500, 10).to(DEVICE)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)

loss_lasso = LassoLoss(base_criterion = torch.nn.CrossEntropyLoss(),
                       model = model,
                       l1_lambda = 1e-4)

trained_model_lasso = training_loop(model, loss_lasso, optimizer, n_epochs, train_loader, val_loader, device = DEVICE)
Validation Loss: 1.0699, Accuracy: 0.8940
Exercise

From the above we can see that the validation loss is bigger than in the previous training, although the accuracy is comparable. What happened?

Beyond improving our validation accuracy, our goal was to reduce the average magnitude of the weights. Did we succeed? Let's take a look:

weight_og_model = torch.cat([trained_model.linear_1.weight.flatten(),
                             trained_model.linear_2.weight.flatten()]).detach().cpu()

weight_lasso_model = torch.cat([trained_model_lasso.linear_1.weight.flatten(),
                                trained_model_lasso.linear_2.weight.flatten()]).detach().cpu()
plt.hist(weight_og_model, bins = 100, label = f'OG model - Avg. = {weight_og_model.abs().mean():.4f}', density = True)
plt.hist(weight_lasso_model, bins = 100, label = f'Lasso model - Avg. = {weight_lasso_model.abs().mean():.4f}', alpha = 0.4, density = True)
plt.legend(loc = 'lower left')
plt.xlim([-.1, .1])
plt.xlabel('Weight value'); plt.ylabel('Frequency')
plt.show()

2.1.2 L2 regularization

L2 regularization, also known as Ridge regularization, adds a penalty term to the cost function that is proportional to the square of the weights. The L2 regularization term has the form:

\[\begin{equation} L_2 = \lambda \sum W^2 \end{equation}\]

where again \(\lambda\) is the regularization parameter, and \(W\) are the weights of the model.

The effect of L2 regularization is similar to that of L1: it shrinks the weights towards zero, which can help to reduce overfitting by decreasing the complexity of the model.

However, unlike L1 regularization, L2 regularization does not lead to the complete removal of weights: it only shrinks them rather than setting them exactly to zero. The reason lies in the gradients: the L1 penalty contributes a constant-magnitude gradient \(\lambda\,\text{sign}(W)\) that keeps pushing weights all the way to zero, while the L2 gradient \(2\lambda W\) vanishes as the weights approach zero.

In general, L2 regularization is more commonly used than L1 regularization, as it tends to be more stable and easier to optimize. However, L1 regularization can be useful in situations where it is important to select a subset of features, as it can drive some weights exactly to zero. Let's see that in code:

class RidgeLoss(nn.Module):
    def __init__(self, base_criterion: nn.Module, model: nn.Module, l2_lambda: float = 1e-4):
        super().__init__()
        
        self.base = base_criterion
        
        # Regularize only weight matrices: skip biases and frozen parameters
        self.params = [p for n,p in model.named_parameters() if p.requires_grad and not n.endswith(".bias") and p.ndim > 1]
        
        self.l2_lambda = l2_lambda

    def forward(self, outputs, targets):
        base_loss = self.base(outputs, targets)
        l2 = torch.zeros((), device=base_loss.device)
        for p in self.params:
            l2 = l2 + p.square().sum()
        return base_loss + self.l2_lambda * l2
n_epochs = 10
model = FullyConnected(28 * 28, 500, 10).to(DEVICE)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)

loss_ridge = RidgeLoss(base_criterion = torch.nn.CrossEntropyLoss(),
                       model = model,
                       l2_lambda = 1e-4)

trained_model_ridge = training_loop(model, loss_ridge, optimizer, n_epochs, train_loader, val_loader, device = DEVICE)
Validation Loss: 1.0605, Accuracy: 0.8962
weight_og_model = torch.cat([trained_model.linear_1.weight.flatten(),
                             trained_model.linear_2.weight.flatten()]).detach().cpu()

weight_ridge_model = torch.cat([trained_model_ridge.linear_1.weight.flatten(),
                                trained_model_ridge.linear_2.weight.flatten()]).detach().cpu()
plt.hist(weight_og_model, bins = 100, label = f'OG model - Avg. = {weight_og_model.abs().mean():.4f}', density = True)
plt.hist(weight_ridge_model, bins = 100, label = f'Ridge model - Avg. = {weight_ridge_model.abs().mean():.4f}', alpha = 0.4, density = True)
plt.legend(loc = 'lower left')
plt.xlim([-.1, .1])
plt.xlabel('Weight value'); plt.ylabel('Frequency')
plt.show()

Exercise

As we can see in the previous plot, not much changed… Increase the \(\lambda\) parameter of the Ridge regularization and see what happens to: 1) the loss; 2) the weights.

Exercise

Let’s come up with a completely new loss function: implement both the Ridge and Lasso regularizations, but with a twist: the former will only apply to the first layer and the latter to the second. Use the same loss function construction as above. Use a plot similar to the one above to showcase what happened to the first and second layer’s weights.

### Your code here
Solution
class MixedL1L2Loss(nn.Module):
    def __init__(self, base_criterion: nn.Module, model: nn.Module,
                 l1_lambda: float = 1e-4, l2_lambda: float = 1e-4):
        super().__init__()
        self.base = base_criterion
        self.model = model
        self.l1_lambda = l1_lambda
        self.l2_lambda = l2_lambda

    def forward(self, outputs, targets):
        base_loss = self.base(outputs, targets)

        l1_reg = torch.zeros((), device=base_loss.device)
        l2_reg = torch.zeros((), device=base_loss.device)

        # Iterate over named parameters
        for name, param in self.model.named_parameters():
            if not param.requires_grad:
                continue
            if "linear_1" in name:                # Ridge (L2) on the first layer
                l2_reg += param.pow(2).sum()
            elif "linear_2" in name:              # Lasso (L1) on the second layer
                l1_reg += param.abs().sum()

        return base_loss + self.l1_lambda * l1_reg + self.l2_lambda * l2_reg
## Rather than training your model, which may take some time, you can test the loss on a non-trained model:

# We first get a batch
images, targets = next(iter(train_loader))

# Define the model and loss
model = FullyConnected(28 * 28, 500, 10).to(DEVICE)
loss_mixed = MixedL1L2Loss(base_criterion = nn.CrossEntropyLoss(), model = model, l1_lambda=1e-4, l2_lambda=1e-3)

# Do forward pass
outputs = model(images.to(DEVICE))

# Compute the loss
loss_mixed(outputs, targets.to(DEVICE))
tensor(3.0156, device='cuda:0', grad_fn=<AddBackward0>)

2.2 Dropout

Another popular technique to prevent overfitting in neural networks is dropout (Hinton et al. 2012). While ridge and lasso constrain the magnitude of the model parameters, dropout acts directly on the network’s activations: during training, each neuron is randomly “dropped out” (i.e., temporarily set to zero) with a fixed probability. This prevents the network from relying too heavily on any single neuron or feature and encourages the development of redundant, robust representations. At inference time, all neurons are active, but their outputs are scaled to account for the missing activations during training. In essence, dropout can be viewed as a stochastic regularizer that approximates training an ensemble of many smaller subnetworks, improving generalization without adding explicit parameter penalties.
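
As a quick illustration (a minimal sketch, not from the original notebook), pytorch’s nn.Dropout implements “inverted” dropout: surviving activations are already scaled by \(1/(1-p)\) at training time, so nothing needs to be rescaled at inference:

drop = nn.Dropout(p=0.5)
x = torch.ones(1, 8)

drop.train()    # training mode: random zeros, survivors scaled by 1/(1-p) = 2
print(drop(x))  # e.g. tensor([[2., 0., 2., 2., 0., 2., 0., 2.]])

drop.eval()     # evaluation mode: dropout is a no-op
print(drop(x))  # tensor([[1., 1., 1., 1., 1., 1., 1., 1.]])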

Where to place the dropout: in most cases, dropout should be placed after the activation function. With ReLU specifically, the numerical result is the same whether dropout is placed before or after the activation (since ReLU(0) = 0 and ReLU commutes with positive scaling), so one can choose whichever is more efficient.

As opposed to L1 or L2 regularization, which act on the loss function, dropout acts directly on the model. Let’s see how to do this in pytorch:

class FullyConnected_dropout(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, dropout_prob = 0.2):
        super().__init__()
        self.linear_1 = nn.Linear(input_size, hidden_size)
        self.linear_2 = nn.Linear(hidden_size, output_size, bias=False)

        self.dropout = nn.Dropout(p = dropout_prob)
        
    def forward(self, x):
        x = x.reshape(-1, 28 * 28)
        z = self.linear_1(x)      
        x = F.relu(z)
        # After the activation, we apply the dropout layer:
        x = self.dropout(x)
        z = self.linear_2(x)
        return z

Let’s now train the new model:

n_epochs = 10
model = FullyConnected_dropout(28 * 28, 500, 10, dropout_prob=0.2).to(DEVICE)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)

trained_model_dropout = training_loop(model, criterion, optimizer, n_epochs, train_loader, val_loader, device = DEVICE)
Validation Loss: 0.4138, Accuracy: 0.8830
Tip

Very important: dropout is only useful during training. During inference, we don’t want our neurons to randomly shut down, as this would degrade the quality of the output. To solve this, pytorch’s models have two modes: model.train(), which sets the model’s layers into training mode (e.g. dropout is applied); and model.eval(), which sets the model to inference mode, so dropout and other training-only layer behaviors are disabled.

Let’s see the difference between the two modes:

for i in range(3):
    pred = trained_model_dropout(next(iter(val_loader))[0].to(DEVICE))
    print(pred[:4,:4])
tensor([[ 0.5062, -4.7286, -1.7473,  0.6674],
        [-1.5978, -1.3038,  2.4630,  6.5667],
        [ 2.2861, -5.4669,  5.9600,  4.6127],
        [-2.7885,  2.0934, -1.2175,  0.7194]], device='cuda:0',
       grad_fn=<SliceBackward0>)
tensor([[ 0.4503, -4.5496, -1.8042,  0.8437],
        [-0.0629, -2.2050,  1.7746,  7.6969],
        [ 0.3124, -4.2750,  5.9855,  4.3752],
        [-2.8338,  1.8375, -1.4099,  0.6517]], device='cuda:0',
       grad_fn=<SliceBackward0>)
tensor([[-0.3726, -3.7200, -1.7582,  0.7903],
        [ 0.5540, -2.1269,  1.7473,  5.9036],
        [ 1.4706, -4.9557,  5.8581,  4.6562],
        [-3.0979,  1.9791, -1.6882,  1.1849]], device='cuda:0',
       grad_fn=<SliceBackward0>)
trained_model_dropout.eval()
for i in range(3):
    pred = trained_model_dropout(next(iter(val_loader))[0].to(DEVICE))
    print(pred[:4,:4])
tensor([[ 0.3722, -4.6152, -1.9534,  0.9229],
        [-0.7638, -1.5321,  1.8037,  7.0875],
        [ 1.3205, -4.8303,  6.0047,  4.7763],
        [-3.0167,  2.0212, -1.4367,  0.9608]], device='cuda:0',
       grad_fn=<SliceBackward0>)
tensor([[ 0.3722, -4.6152, -1.9534,  0.9229],
        [-0.7638, -1.5321,  1.8037,  7.0875],
        [ 1.3205, -4.8303,  6.0047,  4.7763],
        [-3.0167,  2.0212, -1.4367,  0.9608]], device='cuda:0',
       grad_fn=<SliceBackward0>)
tensor([[ 0.3722, -4.6152, -1.9534,  0.9229],
        [-0.7638, -1.5321,  1.8037,  7.0875],
        [ 1.3205, -4.8303,  6.0047,  4.7763],
        [-3.0167,  2.0212, -1.4367,  0.9608]], device='cuda:0',
       grad_fn=<SliceBackward0>)

As you can see, by setting the model into evaluation mode we transform it into a deterministic model, just as expected!

2.3 Batch normalization

Batch normalization (Ioffe and Szegedy 2015) is a technique that normalizes the activations of each mini-batch to stabilize and speed up the training of deep neural networks.

Instead of computing normalization statistics over the entire training set, which would be impractical during stochastic optimization, batch normalization operates on each mini-batch. This helps keep the distribution of activations consistent across layers and reduces overfitting.

During training, for each mini-batch, the layer computes the mean and standard deviation of the activations and normalizes them as:

\[\begin{equation} \hat{x} = \frac{x - \mu_{\text{batch}}}{\sigma_{\text{batch}}} \end{equation}\]

where \(\mu_{\text{batch}}\) and \(\sigma_{\text{batch}}\) are the mean and standard deviation of the mini-batch activations. In practice, a small constant \(\epsilon\) is added to the denominator for numerical stability, and the normalized activations are then scaled and shifted by learnable parameters \(\gamma\) and \(\beta\).

The layer maintains running averages of these statistics, which are updated during training. At evaluation time, these running averages are used instead of the batch statistics, ensuring stable behavior and consistent outputs between training and inference. It is hence very important to set model.eval(): otherwise the output for a given sample will change depending on, for instance, the batch size or the other samples in the batch!
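
A minimal sketch (not from the original notebook) of this train/eval difference, using nn.BatchNorm1d directly:

bn = nn.BatchNorm1d(3)

batch = torch.randn(16, 3) * 5 + 2   # a batch with non-trivial mean and std
bn.train()
_ = bn(batch)                        # normalizes with batch statistics, updates the running averages

bn.eval()                            # the stored running averages are used from now on
single = torch.randn(1, 3)
print(bn(single))                    # well-defined even for a single sample
print(bn.running_mean, bn.running_var)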

In pytorch, batch normalization is implemented as a layer, similar to what we did with Dropout, although here we also need to specify the number of features:

class FullyConnected_batchnorm(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.linear_1 = nn.Linear(input_size, hidden_size)
        self.bn_1     = nn.BatchNorm1d(hidden_size)   # batch normalization layer
        self.linear_2 = nn.Linear(hidden_size, output_size, bias=False)

    def forward(self, x):
        x = x.reshape(x.size(0), -1)          
        x = self.linear_1(x)
        
        # We now apply batch normalization to the pre-activations (i.e. before the activation function!)
        x = self.bn_1(x)  
        
        x = F.relu(x)
        x = self.linear_2(x)
        return x
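
As a sketch (assuming the training_loop and data loaders defined above), the batch-normalized model can be trained with the same wrapper as before:

n_epochs = 10
model = FullyConnected_batchnorm(28 * 28, 500, 10).to(DEVICE)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)

trained_model_bn = training_loop(model, criterion, optimizer, n_epochs, train_loader, val_loader, device = DEVICE)
trained_model_bn.eval()  # remember: switch to the running statistics before inference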

3 Data augmentation

Sometimes our dataset is very small but also very complex. This means that, while we have very few samples, we will need a quite powerful model that will inevitably overfit… The solution here is not to regularize the model but the opposite: extend our dataset. While in some cases this is possible, generally extending a dataset is either impossible or very costly. In this situation, we rely on data augmentation: applying transformations to our dataset samples that, while maintaining their general aspect, change them in particular ways. A common one, for instance, is rotation. Intuitively, this also leads to better generalization: an image of a dog is an image of a dog no matter how we rotate it. By training a model with many different rotations, we have a better chance of it learning the concept of a dog!

Let’s look at some built-in data augmentation transformations in torchvision:

Figure code
from torchvision import datasets, transforms

# --- Load a sample image (MNIST example) ---
dataset = datasets.MNIST(root='data', train=True, download=True, transform=transforms.ToTensor())
img, label = dataset[0]  # take first image

# --- Define different transformations ---
transformations = {
    "Original": transforms.Compose([]),
    "Horizontal Flip": transforms.RandomHorizontalFlip(p=1.0),
    "Rotation (30°)": transforms.RandomRotation(30),
    "Crop + Resize": transforms.RandomResizedCrop(28, scale=(0.6, 1.0)),
    "Color Jitter": transforms.ColorJitter(brightness=0.5, contrast=0.5),
    "Gaussian Blur": transforms.GaussianBlur(kernel_size=3)
}

# --- Apply transformations ---
fig, axes = plt.subplots(1, len(transformations), figsize=(15, 3))

for ax, (name, tf) in zip(axes, transformations.items()):
    # Apply transformation
    img_t = tf(img)
    # Convert to numpy for plotting
    ax.imshow(img_t.squeeze(), cmap='gray')
    ax.set_title(name)
    ax.axis('off')

plt.tight_layout()
plt.show()

torchvision allows us to easily apply transformations to an existing dataset. In this case, each transformation is applied sequentially. Important: DO NOT apply these transformations to the test set: we want to test that the model correctly predicts real images, not their transformations!

from torchvision import transforms

train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),     # flip image horizontally with p=0.5
    transforms.RandomRotation(10),         # rotate image by ±10 degrees
    transforms.RandomResizedCrop(28, scale=(0.8, 1.0)),  # random crop and resize
    transforms.ToTensor(),                 # convert to PyTorch tensor
])

test_transform = transforms.Compose([
    transforms.ToTensor(),                 # only convert, no randomness in test
])

train_data = MNIST(root='data', train=True, download=True, transform=train_transform)
test_data  = MNIST(root='data', train=False, download=True, transform=test_transform)

batch_size = 256
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
test_loader  = DataLoader(test_data, batch_size=batch_size, shuffle=False)
Exercise

Train a model as we did in the previous class. Track and plot the accuracy on the test dataset (you can also track the loss on both the training and test datasets) for two datasets: 1) the “normal” MNIST; 2) the data-augmented MNIST from above.

Because the MNIST dataset is quite big, it’s hard to overfit it with the simple model we have. Let’s make it a bit smaller by considering only a few samples per class:

def select_n_per_class(dataset, n_per_class):
    ''' Given an MNIST dataset, returns the indices of the first n_per_class samples of each class '''
    targets = dataset.targets
    indices = []
    for c in range(10):  # digits 0–9
        class_idx = (targets == c).nonzero(as_tuple=True)[0]
        selected = class_idx[:n_per_class]
        indices.append(selected)
    return torch.cat(indices)

You can use this function on the previous dataset to take only 100 samples per category:

# Get MNIST (with ToTensor, so that the DataLoader below returns tensors)
train_data_full = MNIST(root='data', train=True, transform=ToTensor())

# Get the indices of your new smaller dataset and use Subset to get your smaller dataset
train_idx = select_n_per_class(train_data_full, n_per_class = 100)

from torch.utils.data import Subset
train_data = Subset(train_data_full, train_idx)

batch_size = 32
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)

Note that for the test set you can take as many samples as you want, as this won’t affect training.

Solution - Data augmented dataset
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),     # flip image horizontally with p=0.5
    transforms.RandomRotation(10),         # rotate image by ±10 degrees
    transforms.RandomResizedCrop(28, scale=(0.8, 1.0)),  # random crop and resize
    transforms.ToTensor(),                 # convert to PyTorch tensor
])

test_transform = transforms.Compose([
    transforms.ToTensor(),                 # only convert, no randomness in test
])

train_data_full = MNIST(root='data', train=True, download=True, transform=train_transform)
test_data_full  = MNIST(root='data', train=False, download=True, transform=test_transform)

train_idx = select_n_per_class(train_data_full, 100)
test_idx  = select_n_per_class(test_data_full, 300)

train_data = Subset(train_data_full, train_idx)
test_data  = Subset(test_data_full, test_idx)

batch_size = 32
train_loader_aug = DataLoader(train_data, batch_size=batch_size, shuffle=True)
test_loader_aug  = DataLoader(test_data, batch_size=batch_size, shuffle=False)
Solution - Data augmented training
model_aug = FullyConnected(28 * 28, 500, 10).to(DEVICE)
optimizer = torch.optim.SGD(model_aug.parameters(), lr=1e-2)
n_epochs = 50

accuracy_aug = []
training_loss_aug = []
validation_loss_aug = []

for _ in tqdm(range(n_epochs)):
    epoch_loss = 0
    for images, labels in train_loader_aug:
        images, labels = images.to(DEVICE), labels.to(DEVICE)
        logits = model_aug(images)
        loss = F.cross_entropy(logits, labels)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        epoch_loss += loss.item()

    training_loss_aug.append(epoch_loss / len(train_loader_aug))

    with torch.no_grad():
        epoch_loss = 0
        val_preds, val_targets = [], []
        for images, labels in test_loader_aug:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            logits = model_aug(images)
            loss = F.cross_entropy(logits, labels)

            epoch_loss += loss.item()
            val_preds.append(F.softmax(logits, dim=1))
            val_targets.append(labels)

        accuracy_aug.append(accuracy(torch.cat(val_preds), torch.cat(val_targets)).detach().cpu())
        validation_loss_aug.append(epoch_loss / len(test_loader_aug))
Solution - Non-augmented dataset
train_transform = transforms.Compose([
    transforms.ToTensor(),                 # convert to PyTorch tensor
])

test_transform = transforms.Compose([
    transforms.ToTensor(),                 # only convert, no randomness in test
])


train_data_full = MNIST(root='data', train=True, download=True, transform=train_transform)
test_data_full  = MNIST(root='data', train=False, download=True, transform=test_transform)

train_idx = select_n_per_class(train_data_full, 100)
test_idx  = select_n_per_class(test_data_full, 300)

train_data = Subset(train_data_full, train_idx)
test_data  = Subset(test_data_full, test_idx)

batch_size = 32
train_loader_Naug = DataLoader(train_data, batch_size=batch_size, shuffle=True)
test_loader_Naug  = DataLoader(test_data, batch_size=batch_size, shuffle=False)
Solution - Non-augmented training
model_Naug = FullyConnected(28 * 28, 500, 10).to(DEVICE)
optimizer = torch.optim.SGD(model_Naug.parameters(), lr=1e-2)

accuracy_Naug = []
training_loss_Naug = []
validation_loss_Naug = []

for _ in tqdm(range(n_epochs)):
    epoch_loss = 0
    for images, labels in train_loader_Naug:
        images, labels = images.to(DEVICE), labels.to(DEVICE)
        logits = model_Naug(images)
        loss = F.cross_entropy(logits, labels)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        epoch_loss += loss.item()

    training_loss_Naug.append(epoch_loss / len(train_loader_Naug))

    with torch.no_grad():
        epoch_loss = 0
        val_preds, val_targets = [], []
        for images, labels in test_loader_Naug:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            logits = model_Naug(images)
            loss = F.cross_entropy(logits, labels)

            epoch_loss += loss.item()
            val_preds.append(F.softmax(logits, dim=1))
            val_targets.append(labels)

        accuracy_Naug.append(accuracy(torch.cat(val_preds), torch.cat(val_targets)).detach().cpu())
        validation_loss_Naug.append(epoch_loss / len(test_loader_Naug))
Solution - Result figure
import plotly.graph_objs as go
import plotly.io as pio 
pio.renderers.default='notebook'

# Matplotlib 'C0' and 'C1' as hex
C0 = '#1f77b4'  # blue
C1 = '#ff7f0e'  # orange

fig = go.Figure()

# Augmented
fig.add_scatter(x=torch.arange(len(training_loss_aug)), y=accuracy_aug,
                mode='lines', name='Augmented dataset', line=dict(color=C0), visible='legendonly',)
# fig.add_scatter(x=torch.arange(len(validation_loss_aug)), y=validation_loss_aug,
#                 mode='lines', name='Aug. test', line=dict(color=C0, dash='dash'), visible='legendonly',)

# Non-augmented
fig.add_scatter(x=torch.arange(len(training_loss_Naug)), y=accuracy_Naug,
                mode='lines', name='Non-augmented dataset', line=dict(color=C1), visible='legendonly',)
# fig.add_scatter(x=torch.arange(len(validation_loss_Naug)), y=validation_loss_Naug,
#                 mode='lines', name='No-Aug. test', line=dict(color=C1, dash='dash'), visible='legendonly',)

fig.update_layout(
    xaxis_title='Training epoch',
    yaxis_title='Accuracy',
    template='simple_white',
    legend_title='Dataset',
)
fig.show()

References

Hinton, Geoffrey E, Nitish Srivastava, Alex Krizhevsky, Ilya Sutskever, and Ruslan R Salakhutdinov. 2012. “Improving Neural Networks by Preventing Co-Adaptation of Feature Detectors.” arXiv Preprint arXiv:1207.0580.
Ioffe, Sergey, and Christian Szegedy. 2015. “Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift.” In International Conference on Machine Learning, 448–56. PMLR.