A Step by Step Guide to LLM Distillation

Jackson Stokes

Sep 18, 2024

In our last post, we introduced model distillation and explained why you may want to distill your model for efficient, task-specific deployment.

Today, we’ll dive into some of the technical details behind distillation and how you can experiment with the process yourself. This guide aims to provide a detailed, practical walkthrough of the distillation process, covering both platform-based solutions and sample code you can run on your own hardware.

Distillation Platforms

At Proxis, we’re building the simplest LLM distillation and fine-tuning platform in existence. Our customers use us to quickly build state-of-the-art models for their task, and seamlessly deploy those models into production with our serverless backend.

While we’re not quite ready to open to the public, we are currently running a closed beta for clients looking to distill and fine-tune models on their data. If interested, please fill out our closed-beta invitation request form, and we’ll get back to you within a day if approved.

Distilling a model from scratch

If you’d prefer to set up a distillation pipeline on your own compute infrastructure, the steps below walk you through the process end to end.

Step 1: Setting Up the Environment

Import Libraries

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from datasets import load_dataset

Check Device

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

Step 2: Preparing the Dataset

Choose a Task and Dataset

For this guide, we'll use the IMDb movie reviews dataset for sentiment analysis.

from datasets import load_dataset
dataset = load_dataset('imdb')

Preprocess the Data

tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

def preprocess(example):
    return tokenizer(example['text'], truncation=True, padding='max_length', max_length=128)

encoded_dataset = dataset.map(preprocess, batched=True)

Create DataLoaders

from torch.utils.data import DataLoader
  
train_dataset = encoded_dataset['train'].shuffle(seed=42).select(range(2000))  # Subset for quick training
test_dataset = encoded_dataset['test'].shuffle(seed=42).select(range(500))
# Format the datasets as PyTorch tensors so the DataLoaders yield tensors rather than lists
columns = ['input_ids', 'token_type_ids', 'attention_mask', 'label']
train_dataset.set_format(type='torch', columns=columns)
test_dataset.set_format(type='torch', columns=columns)
train_loader = DataLoader(train_dataset, batch_size=16)
test_loader = DataLoader(test_dataset, batch_size=16)

Step 3: Selecting the Teacher Model

Choose a pre-trained model suited to your task. Note that loading a base checkpoint such as bert-base-uncased into a sequence-classification class initializes the classification head randomly, so in practice you should fine-tune the teacher on your task first (or load a checkpoint that already has been) before distilling from it.

teacher_model_name = 'bert-base-uncased'
teacher_model = AutoModelForSequenceClassification.from_pretrained(teacher_model_name, num_labels=2).to(device)
teacher_model.eval()  # The teacher is only used for inference during distillation

Step 4: Designing the Student Model

Customize the Student Model

Reduce the number of layers and hidden dimensions to create a smaller model.

from transformers import BertConfig, BertForSequenceClassification
  
student_config = BertConfig.from_pretrained(teacher_model_name)
student_config.num_hidden_layers = 4  # Reduce layers from 12 to 4
student_config.hidden_size = 256      # Reduce hidden size from 768 to 256
student_config.num_attention_heads = 4  # Reduce attention heads from 12 to 4
student_model = BertForSequenceClassification(student_config).to(device)

Initialize Student Model Weights

Optionally, initialize the student's weights from the teacher. Only tensors whose names and shapes match are copied, and because we also shrank the hidden size above, most tensors here won't match; this trick is most useful when you reduce only the number of layers.

teacher_state_dict = teacher_model.state_dict()
student_state_dict = student_model.state_dict()

# Copy parameters whose names and shapes match between teacher and student
for name in student_state_dict.keys():
    if name in teacher_state_dict and student_state_dict[name].shape == teacher_state_dict[name].shape:
        student_state_dict[name] = teacher_state_dict[name]
student_model.load_state_dict(student_state_dict)
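If you instead keep the teacher's hidden size and only reduce depth, a common trick is to initialize each student layer from a spaced-out subset of teacher layers. Below is a minimal sketch of that idea; the depth_only_config, the 4-layer setup, and the specific layer mapping are illustrative assumptions, not part of the pipeline above.

# Sketch: depth-only student (hidden size unchanged), initialized from spaced teacher layers
depth_only_config = BertConfig.from_pretrained(teacher_model_name, num_hidden_layers=4)
depth_only_student = BertForSequenceClassification(depth_only_config).to(device)

teacher_sd = teacher_model.state_dict()
student_sd = depth_only_student.state_dict()
layer_map = {0: 0, 1: 3, 2: 6, 3: 9}  # student layer index -> teacher layer index (arbitrary choice)

for name in student_sd.keys():
    source_name = name
    if '.layer.' in name:
        # Redirect encoder-layer parameters to the mapped teacher layer
        prefix, rest = name.split('.layer.', 1)
        idx, remainder = rest.split('.', 1)
        source_name = f'{prefix}.layer.{layer_map[int(idx)]}.{remainder}'
    if source_name in teacher_sd and student_sd[name].shape == teacher_sd[source_name].shape:
        student_sd[name] = teacher_sd[source_name]

depth_only_student.load_state_dict(student_sd)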

Step 5: Implementing the Distillation Process

Define Loss Functions

  • Hard Loss: cross-entropy between the student's predictions and the true labels.

  • Soft Loss: KL divergence between the student's and the teacher's softened predictions.

import torch.nn as nn

hard_loss_fn = nn.CrossEntropyLoss()
soft_loss_fn = nn.KLDivLoss(reduction='batchmean')

Set Hyperparameters

temperature = 2.0  # Softens probability distribution
alpha = 0.5        # Balances hard and soft losses
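To build intuition for the temperature, here's a quick standalone sketch (not part of the pipeline above) that softens a logit vector and shows the input convention KLDivLoss expects: log-probabilities for the student, probabilities for the teacher.

import torch
import torch.nn as nn

# Illustrative logits for a single two-class example (values chosen arbitrarily)
logits = torch.tensor([[3.0, 0.5]])
print(nn.functional.softmax(logits, dim=-1))        # sharp: ~[0.92, 0.08]
print(nn.functional.softmax(logits / 2.0, dim=-1))  # softened by T=2: ~[0.78, 0.22]

# KLDivLoss expects log-probabilities as its input and probabilities as its target
kl = nn.KLDivLoss(reduction='batchmean')
student_log_probs = nn.functional.log_softmax(logits / 2.0, dim=-1)
teacher_probs = nn.functional.softmax(logits / 2.0, dim=-1)
print(kl(student_log_probs, teacher_probs))  # 0.0 here, since both distributions are identical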

Prepare Optimizer and Scheduler

from torch.optim import AdamW  # transformers' AdamW is deprecated in recent versions
from transformers import get_linear_schedule_with_warmup

optimizer = AdamW(student_model.parameters(), lr=5e-5)
num_epochs = 3
total_steps = len(train_loader) * num_epochs
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=total_steps)

Step 6: Training the Student Model

Training Loop

for epoch in range(num_epochs):
    student_model.train()
    total_loss = 0
    for batch in train_loader:
        inputs = {k: v.to(device) for k, v in batch.items() if k in tokenizer.model_input_names}
        labels = batch['label'].to(device)
        
        # Teacher predictions
        with torch.no_grad():
            teacher_outputs = teacher_model(**inputs)
        
        # Student predictions
        student_outputs = student_model(**inputs)
        
        # Calculate losses
        hard_loss = hard_loss_fn(student_outputs.logits, labels)
        
        # Soften logits
        teacher_logits = teacher_outputs.logits / temperature
        student_logits = student_outputs.logits / temperature
        
        soft_loss = soft_loss_fn(
            nn.functional.log_softmax(student_logits, dim=-1),
            nn.functional.softmax(teacher_logits, dim=-1)
        ) * (temperature ** 2)
        
        # Combined loss
        loss = alpha * hard_loss + (1 - alpha) * soft_loss
        total_loss += loss.item()
        
        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()
    
    avg_loss = total_loss / len(train_loader)
    print(f'Epoch {epoch + 1}/{num_epochs}, Loss: {avg_loss:.4f}')

Step 7: Evaluating the Distilled Model

Evaluation Function

def evaluate(model, loader):
    model.eval()
    total_correct = 0
    total_examples = 0
    with torch.no_grad():
        for batch in loader:
            inputs = {k: v.to(device) for k, v in batch.items() if k in tokenizer.model_input_names}
            labels = batch['label'].to(device)
            outputs = model(**inputs)
            predictions = torch.argmax(outputs.logits, dim=-1)
            total_correct += (predictions == labels).sum().item()
            total_examples += labels.size(0)
    accuracy = total_correct / total_examples
    return accuracy

Evaluate Teacher and Student Models

teacher_accuracy = evaluate(teacher_model, test_loader)
student_accuracy = evaluate(student_model, test_loader)
print(f'Teacher Model Accuracy: {teacher_accuracy * 100:.2f}%')
print(f'Student Model Accuracy: {student_accuracy * 100:.2f}%')

Compare Model Sizes

teacher_params = sum(p.numel() for p in teacher_model.parameters())
student_params = sum(p.numel() for p in student_model.parameters())
print(f'Teacher Model Parameters: {teacher_params / 1e6:.2f}M')
print(f'Student Model Parameters: {student_params / 1e6:.2f}M')

Measure Inference Time

import time

def measure_inference_time(model, loader):
    model.eval()
    start_time = time.time()
    with torch.no_grad():
        for batch in loader:
            inputs = {k: v.to(device) for k, v in batch.items() if k in tokenizer.model_input_names}
            model(**inputs)
    end_time = time.time()
    return end_time - start_time

teacher_time = measure_inference_time(teacher_model, test_loader)
student_time = measure_inference_time(student_model, test_loader)
print(f'Teacher Inference Time: {teacher_time:.2f}s')
print(f'Student Inference Time: {student_time:.2f}s')

Step 8: Fine-Tuning and Deployment

Fine-Tuning (Optional)

If performance is not satisfactory, consider the following adjustments:

  • Adjust hyperparameters (e.g., alpha, temperature, learning rate); see the sketch after this list.

  • Increase the size of the training dataset.

  • Use data augmentation techniques.
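As a rough illustration of the first point, here is a minimal sketch of a hyperparameter sweep. It assumes a hypothetical helper distill_and_evaluate(alpha, temperature) that wraps Steps 5–7 (build a fresh student, train it with the given settings, and return its test accuracy); that helper is not defined above.

# Hypothetical sweep over the distillation hyperparameters
best = None
for alpha in [0.25, 0.5, 0.75]:
    for temperature in [1.0, 2.0, 4.0]:
        accuracy = distill_and_evaluate(alpha=alpha, temperature=temperature)
        print(f'alpha={alpha}, T={temperature}: accuracy={accuracy * 100:.2f}%')
        if best is None or accuracy > best[0]:
            best = (accuracy, alpha, temperature)
print(f'Best: accuracy={best[0] * 100:.2f}% (alpha={best[1]}, T={best[2]})')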

Save the Student Model

student_model.save_pretrained('distilled_student_model')
tokenizer.save_pretrained('distilled_student_model')

Deploy the Model

You can deploy the model using frameworks like FastAPI or Flask.

from fastapi import FastAPI, Request
import uvicorn

app = FastAPI()

# Load the distilled model and tokenizer
model = AutoModelForSequenceClassification.from_pretrained('distilled_student_model').to(device)
tokenizer = AutoTokenizer.from_pretrained('distilled_student_model')

@app.post("/predict")
async def predict(request: Request):
    data = await request.json()
    inputs = tokenizer(data['text'], return_tensors='pt', truncation=True, padding='max_length', max_length=128).to(device)
    with torch.no_grad():
        outputs = model(**inputs)
    prediction = torch.argmax(outputs.logits, dim=-1).item()
    return {'prediction': prediction}  # 0 = negative, 1 = positive for IMDb

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)
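Once the server is running, you can sanity-check the endpoint with a quick request; the snippet below assumes the default host and port from the script above.

import requests

response = requests.post(
    'http://localhost:8000/predict',
    json={'text': 'One of the best films I have seen in years.'}
)
print(response.json())  # e.g. {'prediction': 1}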


Go Forth

Distillation is a crucial step in taking your models to production. We hope this post helps you understand the steps you can take to distill your own models, or points you toward an efficient distillation platform such as Proxis.