Practical Guide to Transformer Fine-Tuning

Transformers have taken an ever more important place in our lives. In this article, we focus on how to fine-tune such a model for our own needs.

Artificial Intelligence Jul 7, 2025 8 min read

Introduction

The field of machine learning has witnessed a seismic shift in recent years, largely driven by the rise of transformer networks. Originally conceived for natural language processing, transformers and their attention mechanisms have outperformed previous architectures such as recurrent neural networks (RNNs) on many sequence tasks. They have since spread to a wide array of applications, from text generation and translation to image recognition and even protein structure prediction.

Their ability to process information in parallel and capture long-range dependencies has fueled unprecedented performance gains, making them the de facto standard for cutting-edge AI models such as ChatGPT, Gemini, and Claude [1]. However, an equally exciting aspect of this revolution lies not just in the inherent power of transformers, but in the potential unlocked by fine-tuning pre-trained models. Adapting an existing model to a new task avoids the difficult and resource-intensive process of designing and training a model from scratch.

You can read more about previous neural network types in this article.

In today’s article, we dive into how to fine-tune a transformer model to fit our needs.

Transformers

What Held Back Previous Models?

Before transformers emerged, recurrent neural networks, such as LSTMs and GRUs, were the standard approach for handling sequence data like text and audio. These networks process data sequentially, essentially “remembering” past information as they go.

However, RNNs face several limitations [2][3][4]:

  1. Parallelization: Each step depends on the hidden state of the previous step, so RNNs must process a sequence one element at a time. This makes them slow and hard to parallelize, that is, to spread the work across multiple processors running simultaneously.
  2. Context Length: Training RNNs on long sequences is also challenging because of vanishing or exploding gradients, where the values used to update the network's parameters become either too small or too large during the learning process.
  3. Effectiveness: Finally, RNNs often struggle to effectively capture relationships between words or elements that are far apart within a sequence, hindering their ability to understand context across longer stretches of data.
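The following plain-Python sketch makes the first two limitations concrete. It runs a toy scalar RNN with made-up weights; it is not an LSTM or a GRU. The loop is sequential because each hidden state needs the previous one. The sketch also tracks the gradient of the last hidden state with respect to the initial one. By the chain rule, that gradient is a product with one factor per time step. With the assumed recurrent weight w_h = 0.5, every factor is at most 0.5, so the gradient shrinks exponentially with sequence length.

```python
import math

def run_rnn(inputs, w_h=0.5, w_x=1.0):
    """Toy scalar RNN with h_t = tanh(w_h * h_{t-1} + w_x * x_t).

    Returns all hidden states and the gradient d h_T / d h_0 via the chain rule.
    """
    h = 0.0
    states = []
    grad = 1.0  # running product of the per-step derivatives d h_t / d h_{t-1}
    for x in inputs:  # strictly sequential: step t needs the result of step t - 1
        h = math.tanh(w_h * h + w_x * x)
        grad *= w_h * (1.0 - h * h)  # tanh'(a) = 1 - tanh(a)^2, times d a / d h_{t-1} = w_h
        states.append(h)
    return states, grad

for length in (5, 20, 50):
    _, g = run_rnn([0.1] * length)
    print(f"sequence length {length:>2}: d h_T / d h_0 = {g:.3e}")
```

If the recurrent weight's magnitude exceeds one, the same product can grow instead, which is the exploding-gradient case.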

What Drives Transformers: The Key Concept

The transformer architecture, introduced in the 2017 paper “Attention Is All You Need”, discards recurrent connections entirely and relies on a mechanism called attention. The fundamental idea is that the model does not process a sequence step by step. Instead, it looks directly at all parts of the input sequence simultaneously to understand the relationships between them.
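The paper's core computation is scaled dot-product attention, softmax(QKᵀ / √d_k) V. The sketch below writes it out in plain Python for a handful of vectors. The query, key, and value numbers are illustrative assumptions. Real models derive Q, K, and V from learned projections of the embeddings and run several attention heads in parallel.

```python
import math

def softmax(scores):
    m = max(scores)  # subtract the maximum for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def scaled_dot_product_attention(queries, keys, values):
    """Return (outputs, weights) for lists of equally sized vectors."""
    d_k = len(keys[0])
    outputs, weights = [], []
    for q in queries:  # every query is compared with every key; no recurrence involved
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d_k) for k in keys]
        w = softmax(scores)
        # Output = attention-weighted average of the value vectors
        out = [sum(wj * v[dim] for wj, v in zip(w, values)) for dim in range(len(values[0]))]
        outputs.append(out)
        weights.append(w)
    return outputs, weights

# Three tokens with 2-dimensional query/key/value vectors (illustrative numbers).
Q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
K = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
V = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
out, attn = scaled_dot_product_attention(Q, K, V)
for row in attn:
    print([round(w, 3) for w in row])
```

The loop over queries is only there for readability. In practice, the whole computation reduces to two matrix multiplications and a softmax, which is what makes it easy to parallelize.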

In that paper, the transformer achieved better translation quality than previous state-of-the-art models while requiring a fraction of their training cost. These results were instrumental in the rise of transformers.

Key Components of Transformers

A transformer consists of several components, some of which (such as positional encodings) have no counterpart in classic recurrent neural networks. We therefore briefly introduce them here:

  1. Input Embedding: The input sequence (e.g., a sentence) is first converted into numerical representations called embeddings. Each word or token becomes a vector of numbers.
  2. Positional Encoding: Because transformers don’t inherently know the order of words (unlike RNNs), positional encodings are added to the embeddings. These encodings provide information about the position of each word in the sequence.
  3. Encoder: The encoder’s job is to process the input sequence and create a contextualized representation of it. It’s composed of multiple identical layers. Each layer consists of two main sub-layers:
    • Self-Attention: This is the heart of the transformer. It allows the model to weigh the importance of every other word in the input sequence when processing a specific word. For instance, when processing the word “it” in the sentence “The cat sat on the mat, and it was fluffy,” a trained model can assign a high attention weight to “cat”. This captures that “it” most likely refers to the cat. The model learns such relationships from the surrounding words rather than from explicit rules.
    • Feed-Forward Neural Network: A standard feed-forward network applies a transformation to the output of the self-attention layer, further refining the representation.
  4. Decoder: The decoder’s job is to generate the output sequence (e.g., a translation). It also consists of multiple identical layers. Each layer contains:
    • Masked Self-Attention: Similar to the encoder’s self-attention, but with a mask that prevents the decoder from “peeking” at future words during training. This ensures the decoder only uses past information to predict the next word.
    • Encoder-Decoder Attention: This lets the decoder attend to the output of the encoder, so that it can incorporate information from the entire input sequence.
    • Feed-Forward Neural Network: Another feed-forward network refines the output.
  5. Output Layer: The final layer of the decoder generates the output sequence, typically by predicting the probability of each word in the vocabulary at each position.
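Two of the components above are simple enough to write out directly. The plain-Python sketch below follows the sinusoidal positional encoding from the original paper, PE(pos, 2i) = sin(pos / 10000^(2i/d_model)) and PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model)). It also builds the causal mask used by masked self-attention. Some models, including BERT and DistilBERT, use learned position embeddings instead of this sinusoidal variant. The helper names are hypothetical.

```python
import math

def sinusoidal_positional_encoding(seq_len, d_model):
    """Return a seq_len x d_model table of sinusoidal position encodings."""
    table = []
    for pos in range(seq_len):
        row = []
        for dim in range(d_model):
            i = dim // 2  # dimensions 2i and 2i+1 share one frequency
            angle = pos / (10000 ** (2 * i / d_model))
            row.append(math.sin(angle) if dim % 2 == 0 else math.cos(angle))
        table.append(row)
    return table

def causal_mask(seq_len):
    """mask[i][j] is True when position i may attend to position j (j <= i)."""
    return [[j <= i for j in range(seq_len)] for i in range(seq_len)]

def apply_mask(scores, mask):
    """Set disallowed scores to -inf so that a softmax assigns them zero weight."""
    return [[s if allowed else float("-inf") for s, allowed in zip(row, mask_row)]
            for row, mask_row in zip(scores, mask)]

pe = sinusoidal_positional_encoding(seq_len=4, d_model=6)
print([round(x, 3) for x in pe[1]])
for row in causal_mask(4):
    print(["x" if allowed else "." for allowed in row])
```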

To help you understand further, we have included a short overview of the components.

| Component | Function | Key Idea |
|---|---|---|
| Input Embedding | Converts words to numerical vectors | Representation |
| Positional Encoding | Adds information about word order | Order Matters |
| Encoder | Processes the input sequence | Contextual Understanding |
| Self-Attention | Weighs the importance of words | Relationships Between Words |
| Decoder | Generates the output sequence | Prediction |
| Encoder-Decoder Attention | Connects encoder and decoder | Information Flow |

If you are interested in learning more about transformers, you can find a really good visualization over here.

Example

We will now showcase the capabilities of transformers using the DistilBertForSequenceClassification model. This model will be applied to a binary classification task: determining whether short messages are spam or not. The dataset for this task can be downloaded from Kaggle.

More information about the model we use can be found here.

Implementation

We begin by adding all necessary imports to our Python file.

from transformers import AutoTokenizer, DistilBertForSequenceClassification

import torch
import torch.nn as nn
import torch.nn.functional as F

import pandas as pd
import numpy as np

import math

transformer-finetuning.py

To prepare our dataset for model training, we first convert the labels into a numerical representation. The ham label is assigned a value of zero, while the spam label is assigned a value of one. Following this, we perform a train-test split, using a custom function to divide the data into training and testing sets.

def prepareDataset(filepath):
    dataset = pd.read_csv(filepath)

    # Map the text labels in the first column to integers: ham -> 0, spam -> 1
    labels = dataset.iloc[:, 0].map({"ham": 0, "spam": 1})
    labels = torch.tensor(labels.to_numpy(), dtype=torch.long).to("cuda")

    # The message texts are in the second column
    texts = dataset.iloc[:, 1].tolist()

    return texts, labels

def splitDataset(texts, labels, trainingPercentage=0.8):
    # Index of the first sample that belongs to the test set
    latestIndex = math.floor(len(texts) * trainingPercentage)

    trainingTexts = texts[:latestIndex]
    trainingLabels = labels[:latestIndex]

    # The test set is everything from latestIndex to the end, so no sample is skipped
    testingTexts = texts[latestIndex:]
    testingLabels = labels[latestIndex:]

    return trainingTexts, trainingLabels, testingTexts, testingLabels

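`splitDataset` takes the first rows of the file as training data, so any ordering in the CSV carries over into the split. A common alternative, swapped in here and not part of the original code, is to shuffle the sample indices with a fixed seed before cutting. The plain-Python sketch below does this with toy data standing in for the CSV columns. It also keeps each text paired with its label. `shuffled_split` and the example messages are hypothetical.

```python
import math
import random

LABEL_MAP = {"ham": 0, "spam": 1}  # same encoding as prepareDataset

def shuffled_split(texts, labels, training_percentage=0.8, seed=42):
    """Shuffle sample indices reproducibly, then cut them into train and test sets."""
    indices = list(range(len(texts)))
    random.Random(seed).shuffle(indices)  # local RNG, leaves the global state untouched
    cut = math.floor(len(indices) * training_percentage)

    def pick(seq, idx):
        return [seq[i] for i in idx]

    train_idx, test_idx = indices[:cut], indices[cut:]
    return (pick(texts, train_idx), pick(labels, train_idx),
            pick(texts, test_idx), pick(labels, test_idx))

# Toy data standing in for the two CSV columns.
raw = [("ham", "See you at 5"), ("spam", "WIN a prize now"), ("ham", "Call me later"),
       ("ham", "Lunch?"), ("spam", "Free entry!!!")]
texts = [text for _, text in raw]
labels = [LABEL_MAP[label] for label, _ in raw]

tr_x, tr_y, te_x, te_y = shuffled_split(texts, labels)
print(len(tr_x), len(te_x))  # 4 1
```

In pandas, a similar effect can be achieved by shuffling the whole DataFrame with `dataset.sample(frac=1, random_state=42)` before extracting the columns.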

Following data preparation, we implement a custom dataset batcher to organize the data into batches suitable for model training.

The batcher stores the input IDs, attention masks, and labels, and returns randomly sampled batches of them. Before training begins, the tokenizer converts all texts into input IDs and attention masks. It pads them to a common length so that they fit into a single tensor. Without this padding, sequences of different lengths could not be stacked into batches.

class DatasetBatcher():
    def __init__(self, inputIds, attentionMasks, labels):
        self.inputIds = inputIds
        self.attentionMasks = attentionMasks
        self.labels = labels

    def getBatch(self, samplesAmount):
        # Draw random sample indices (with replacement); `high` is exclusive,
        # so passing the full dataset size lets the last sample be drawn as well
        randomIndices = np.random.randint(low=0, high=self.inputIds.size(0), size=samplesAmount)
        return self.inputIds[randomIndices], self.attentionMasks[randomIndices], self.labels[randomIndices]

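Tokenizing with `padding=True` must happen up front because tensors need rectangular shapes. The plain-Python sketch below mimics what padding to the longest sequence produces. Shorter sequences are filled with a padding ID, and the attention mask marks real tokens with 1 and padding with 0, so attention ignores the padding. The token IDs are made up, `pad_batch` is a hypothetical helper, and `pad_id=0` is an assumption. The real tokenizer also handles special tokens and vocabulary lookup.

```python
def pad_batch(sequences, pad_id=0):
    """Pad token-ID lists to the longest length and build matching attention masks."""
    max_len = max(len(seq) for seq in sequences)
    input_ids, attention_mask = [], []
    for seq in sequences:
        padding = max_len - len(seq)
        input_ids.append(seq + [pad_id] * padding)
        attention_mask.append([1] * len(seq) + [0] * padding)  # 1 = real token, 0 = padding
    return input_ids, attention_mask

ids, mask = pad_batch([[101, 2054, 102], [101, 2183, 2085, 999, 102]])
print(ids)   # [[101, 2054, 102, 0, 0], [101, 2183, 2085, 999, 102]]
print(mask)  # [[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]]
```

Padding an entire split to its single longest message is simple but can waste computation. Passing `truncation=True` to the tokenizer additionally caps each sequence at the model's maximum input length.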

Next, we convert the labels into one-hot vectors (for example, spam becomes [0, 1]). Since PyTorch 1.10, nn.CrossEntropyLoss accepts such class-probability targets. The conversion is optional, however, because the loss also accepts the integer class indices directly.

class ToOneHot(nn.Module):
    def __init__(self, numClasses):
        super(ToOneHot, self).__init__()
        self.numClasses = numClasses
    def forward(self, x):
        return F.one_hot(x, num_classes=self.numClasses)

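`nn.CrossEntropyLoss` applies a log-softmax to the logits and compares the result with the target. The plain-Python sketch below computes this loss for one example in two ways: from a class index and from the equivalent one-hot vector. With a one-hot target, the sum over classes collapses to the single term for the true class, so both values match. The logits are illustrative numbers. By default, the PyTorch loss also averages over the batch, which the sketch omits.

```python
import math

def log_softmax(logits):
    m = max(logits)
    log_total = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - log_total for z in logits]

def cross_entropy_from_index(logits, target_index):
    return -log_softmax(logits)[target_index]

def cross_entropy_from_probs(logits, target_probs):
    # General form: -sum over classes c of p_c * log softmax(z)_c
    return -sum(p * lp for p, lp in zip(target_probs, log_softmax(logits)))

def one_hot(index, num_classes):
    return [1.0 if c == index else 0.0 for c in range(num_classes)]

logits = [2.0, -1.0]  # model output for one message: [ham score, spam score]
label = 1             # the message is spam
print(cross_entropy_from_index(logits, label))
print(cross_entropy_from_probs(logits, one_hot(label, 2)))
```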

Having completed the supporting code, we now turn to the binary classifier class. It inherits from nn.Module to access the core neural network features provided by PyTorch. It wraps the pre-trained distilbert-base-uncased model, configured with two output labels that correspond to the ‘ham’ and ‘spam’ labels in the dataset. Only the DistilBERT body is pre-trained; the classification head on top is newly initialized and learned during fine-tuning.

class BinaryClassifier(nn.Module):
    def __init__(self):
        super(BinaryClassifier, self).__init__()

        # Load the pre-trained DistilBERT model with a two-label classification head
        self.model = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=2)
    
    def forward(self, input_ids, attention_mask):
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)

        return outputs.logits

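`forward` returns raw logits, one unnormalized score per label. A softmax turns them into probabilities, and the largest one gives the predicted class. `testingStep` obtains that class directly with `torch.argmax`. The plain-Python sketch below uses made-up logits and the 0 = ham, 1 = spam encoding from `prepareDataset`. `interpret` is a hypothetical helper.

```python
import math

LABEL_NAMES = ["ham", "spam"]  # index order matches prepareDataset's encoding

def softmax(logits):
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def interpret(logits):
    """Convert one row of classifier logits into (label name, probability)."""
    probs = softmax(logits)
    best = max(range(len(probs)), key=probs.__getitem__)  # argmax
    return LABEL_NAMES[best], probs[best]

print(interpret([-1.2, 3.4]))  # ('spam', 0.990...)
```

Softmax preserves the order of the scores, so the argmax of the logits and the argmax of the probabilities always agree. The softmax is only needed when the probability itself matters, for example when applying a custom decision threshold.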

We then begin our implementation of the training handler class, which encapsulates all model interactions for us and holds important parameters in memory.

class TrainingHandler():
    def __init__(self, model, tokenizer, epochs, batchSize, numberBatches):
        self.model = model
        self.tokenizer = tokenizer
        self.epochs = epochs
        self.batchSize = batchSize
        self.numberBatches = numberBatches

        self.datasetPath = "spam.csv"
        self.splitFactor = 0.8

        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=5e-5)
        self.criterion = nn.CrossEntropyLoss()

        self.OneHotConverter = ToOneHot(2)
        self.trainingLoss, self.testingAccuracy = [], []

    def retrieveDataset(self):
        self.texts, self.labels = prepareDataset(self.datasetPath)
        trainingTexts, trainingLabels, testingTexts, testingLabels = splitDataset(self.texts, self.labels, self.splitFactor)

        # Prepare Training Dataset
        encodedTrainingTexts = self.tokenizer(trainingTexts, padding=True, return_tensors="pt")
        trainingInputIds = encodedTrainingTexts['input_ids'].to("cuda")
        trainingAttentionMask = encodedTrainingTexts['attention_mask'].to("cuda")
        self.trainingBatcher = DatasetBatcher(trainingInputIds, trainingAttentionMask, trainingLabels)

        # Prepare Testing Dataset
        encodedTestingTexts = self.tokenizer(testingTexts, padding=True, return_tensors="pt")
        testingInputIds = encodedTestingTexts['input_ids'].to("cuda")
        testingAttentionMask = encodedTestingTexts['attention_mask'].to("cuda")
        self.testingBatcher = DatasetBatcher(testingInputIds, testingAttentionMask, testingLabels)

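The handler uses `torch.optim.Adam` with a learning rate of 5e-5. Adam keeps running averages of each parameter's gradient and squared gradient. The plain-Python sketch below writes out the standard Adam update for a single scalar parameter, assuming PyTorch's default `betas=(0.9, 0.999)`, `eps=1e-8`, and no weight decay. It shows a useful property: after bias correction, the first step moves the parameter by roughly the learning rate, whatever the gradient's magnitude. `adam_step` is a hypothetical helper, not part of PyTorch.

```python
import math

def adam_step(param, grad, state, lr=5e-5, betas=(0.9, 0.999), eps=1e-8):
    """One Adam update for a scalar parameter; `state` holds m, v and the step count t."""
    beta1, beta2 = betas
    state["t"] += 1
    state["m"] = beta1 * state["m"] + (1 - beta1) * grad          # first moment
    state["v"] = beta2 * state["v"] + (1 - beta2) * grad * grad   # second moment
    m_hat = state["m"] / (1 - beta1 ** state["t"])                # bias correction
    v_hat = state["v"] / (1 - beta2 ** state["t"])
    return param - lr * m_hat / (math.sqrt(v_hat) + eps)

for grad in (1e-3, 10.0):
    state = {"m": 0.0, "v": 0.0, "t": 0}
    new_param = adam_step(1.0, grad, state)
    print(f"gradient {grad:>6}: parameter moved by {1.0 - new_param:.2e}")
```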

The current training handler is missing the essential functionality to train the model. We will now implement this crucial step. The training process for each iteration can be broken down into the following steps:

  1. Retrieve a batch of data from the training batcher.
  2. Reset the accumulated gradient values.
  3. Generate model predictions and evaluate them using the defined loss function.
  4. Backpropagate the loss to compute the gradients, then let the optimizer update the model parameters.

    def trainingStep(self):
        # Set Model to Training Mode
        self.model.train()
        totalLoss = 0

        for _ in range(self.numberBatches):
            # Retrieve Batch
            currentIds, currentAttentionMask, currentLabels = self.trainingBatcher.getBatch(self.batchSize)
            currentLabels = self.OneHotConverter(currentLabels).to(torch.float)

            # Reset Gradients
            self.optimizer.zero_grad()

            # Generate Outputs and Evaluation
            outputs = self.model(currentIds, currentAttentionMask)
            loss = self.criterion(outputs, currentLabels)

            totalLoss += loss.item()

            # Backpropagation Step
            loss.backward()
            self.optimizer.step()

        return totalLoss

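The four steps listed above are not specific to transformers. The self-contained plain-Python sketch below shows them in isolation on a one-feature logistic regression with made-up data. It samples a batch, resets the gradients, computes predictions and the loss, then backpropagates and updates. The gradients are derived by hand instead of with autograd, and plain gradient descent stands in for Adam. `train_logistic` and the data are hypothetical.

```python
import math
import random

def train_logistic(xs, ys, batch_size=4, steps=200, lr=0.5, seed=0):
    """Mini-batch loop with the same four steps as trainingStep, for p = sigmoid(w*x + b)."""
    rng = random.Random(seed)
    w, b = 0.0, 0.0
    losses = []
    for _ in range(steps):
        # 1. Retrieve a random batch (sampled with replacement, like DatasetBatcher)
        batch = [rng.randrange(len(xs)) for _ in range(batch_size)]
        # 2. Reset the accumulated gradients
        grad_w, grad_b, loss = 0.0, 0.0, 0.0
        for i in batch:
            # 3. Prediction and binary cross-entropy, averaged over the batch
            p = 1.0 / (1.0 + math.exp(-(w * xs[i] + b)))
            p = min(max(p, 1e-12), 1.0 - 1e-12)  # keep log() finite
            loss += -(ys[i] * math.log(p) + (1 - ys[i]) * math.log(1.0 - p)) / batch_size
            # 4a. Backpropagation by hand: d loss / d logit = p - y
            grad_w += (p - ys[i]) * xs[i] / batch_size
            grad_b += (p - ys[i]) / batch_size
        # 4b. Parameter update (plain gradient descent instead of Adam)
        w -= lr * grad_w
        b -= lr * grad_b
        losses.append(loss)
    return w, b, losses

xs = [-2.0, -1.5, -1.0, -0.5, 0.5, 1.0, 1.5, 2.0]
ys = [0, 0, 0, 0, 1, 1, 1, 1]  # label is 1 exactly when x > 0
w, b, losses = train_logistic(xs, ys)
print(f"first loss {losses[0]:.3f}, last loss {losses[-1]:.3f}, w = {w:.2f}")
```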

For the testing data, we follow a similar process, but without backpropagation or parameter updates:

    def testingStep(self):
        # Set Model to Evaluation Mode (disables dropout)
        self.model.eval()
        totalCorrect = 0

        # No gradients are needed for evaluation, which saves memory and time
        with torch.no_grad():
            for _ in range(self.numberBatches):
                currentIds, currentAttentionMask, currentLabels = self.testingBatcher.getBatch(self.batchSize)

                # Generate Model Output and pick the most likely class per sample
                outputs = self.model(currentIds, currentAttentionMask)
                predictions = torch.argmax(outputs, dim=1)

                # Compare predictions with the integer labels directly
                totalCorrect += (predictions == currentLabels).sum().item()

        percentageCorrect = totalCorrect / (self.numberBatches * self.batchSize)
        return percentageCorrect

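`testingStep` draws random batches with replacement, so its accuracy is an estimate: some test messages may be counted several times and others not at all. For an exact figure, one option is to walk through the test set once in fixed-size chunks. This is an alternative, not the article's method. The plain-Python sketch below shows only the indexing logic. `exact_accuracy` is hypothetical, and the keyword-based `toy_predict` is a made-up stand-in for the model call.

```python
def exact_accuracy(predict, inputs, labels, batch_size=32):
    """Evaluate every sample exactly once, in order, in chunks of batch_size."""
    correct = 0
    for start in range(0, len(inputs), batch_size):
        batch_inputs = inputs[start:start + batch_size]  # the last chunk may be smaller
        batch_labels = labels[start:start + batch_size]
        predictions = predict(batch_inputs)
        correct += sum(int(p == y) for p, y in zip(predictions, batch_labels))
    return correct / len(inputs)

def toy_predict(batch):
    # Stand-in model: flags a message as spam (1) if it contains "free"
    return [int("free" in text.lower()) for text in batch]

texts = ["Free entry now", "See you soon", "free tickets", "Call me", "WIN cash"]
labels = [1, 0, 1, 0, 1]
print(exact_accuracy(toy_predict, texts, labels, batch_size=2))  # 0.8
```

With PyTorch tensors, the same loop would slice `input_ids`, `attention_mask`, and labels by `start` and run under `torch.no_grad()`.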

Lastly, we can combine both training and testing methods into a single encapsulated method, in order to simplify access from outside the class.

    def trainingEpochs(self):
        for i in range(self.epochs):
            # Perform Training Step
            loss = self.trainingStep()

            print(f"Training Loss in Epoch {i + 1}: {loss}")
            self.trainingLoss.append(loss)

            # Perform Evaluation Step
            accuracy = self.testingStep()

            print(f"Testing Accuracy in Epoch {i + 1}: {accuracy}")
            self.testingAccuracy.append(accuracy)
        
        return self.trainingLoss, self.testingAccuracy


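With the implementation complete, we can test it! To run it:

  1. Load the tokenizer with AutoTokenizer.from_pretrained("distilbert-base-uncased").
  2. Create a BinaryClassifier and move it to the GPU with .to("cuda"). The data tensors are created on "cuda", so the model must live there as well.
  3. Pass the tokenizer and model to a TrainingHandler.
  4. Call retrieveDataset(), followed by trainingEpochs().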

Results

To evaluate the training process, we ran ten independent training runs and averaged the results. We begin by presenting the loss history:

The rapid decrease in loss suggests that the model is effectively learning from the training data.

Next, we examine the testing accuracy during training:

The increasing mean testing accuracy indicates that the model is learning to classify messages it has not seen during training. This concludes our demonstration that the implementation works as expected.

Note: Due to the limited size of the dataset, the model converged relatively easily. However, small datasets carry a greater risk of overfitting, where the model memorizes the training examples instead of learning patterns that generalize. For production use, we recommend a dataset with substantially more data points.

TL;DR

This article explores the reasons behind the widespread adoption of transformers, which have largely superseded recurrent neural networks, and details their significant advantages. We then examine the individual components of transformer models and their function within the overall architecture. Finally, we present an implementation for fine-tuning a DistilBERT model on a binary classification task, demonstrating its effectiveness and acknowledging its limitations.

Sources

  1. wikipedia.org
  2. baeldung.com
  3. medium.com
  4. ibm.com