Optimizing LLM Performance Using Triton

Matej Sirovatka

2025-02-22

whoami

  • My name is Matej
  • I’m a Master’s student at Brno University of Technology
  • I’m currently working on distributed training at Hugging Face πŸ€—

What is Triton?

  • Open-source programming language for GPU kernels by OpenAI
  • Designed for AI/ML workloads
  • Simplifies GPU programming compared to CUDA

Why Optimize with Triton?

  • Simple yet effective
  • Less headache than CUDA
  • GPUs go brrrrrrr πŸš€
  • Feel cool when your kernel is faster than PyTorch 😎

Example Problem: KL Divergence

  • commonly used in LLMs for knowledge distillation
  • for probability distributions \(P\) and \(Q\), the Kullback-Leibler divergence is defined as:

\[ D_{KL}(P \| Q) = \sum_{i} P_i \log\left(\frac{P_i}{Q_i}\right) \]

import torch
from torch.nn.functional import kl_div

def kl_div_torch(p: torch.Tensor, q: torch.Tensor) -> torch.Tensor:
    # PyTorch's kl_div expects log-probabilities as its *first* argument and
    # computes target * (log(target) - input), so to get D_KL(P || Q):
    return kl_div(q.log(), p, reduction="sum")
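
  • A quick sanity check (the two distributions below are made up for illustration):

p = torch.tensor([0.4, 0.6])
q = torch.tensor([0.5, 0.5])
print(kl_div_torch(p, q))  # β‰ˆ 0.02, small since P and Q are close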

How about Triton?

import triton
import triton.language as tl

@triton.jit
def kl_div_triton(
    p_ptr, q_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr
):
    # Each program instance handles one contiguous block of elements
    pid = tl.program_id(0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements  # guard against out-of-bounds accesses

    p = tl.load(p_ptr + offsets, mask=mask)
    q = tl.load(q_ptr + offsets, mask=mask)

    # Elementwise KL term: P_i * (log P_i - log Q_i)
    output = p * (tl.log(p) - tl.log(q))
    tl.store(output_ptr + offsets, output, mask=mask)

How to integrate with PyTorch?

  • How to use our custom kernel with PyTorch autograd?
import torch

class KlDiv(torch.autograd.Function):
    @staticmethod
    def forward(ctx, p, q):
        ctx.save_for_backward(p, q)  # both are needed for the gradients
        output = torch.empty_like(p)
        n_elements = p.numel()
        # Triton expects the launch grid as a tuple: one program per block
        grid = ((n_elements + 512 - 1) // 512,)
        kl_div_triton[grid](p, q, output, n_elements, BLOCK_SIZE=512)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        p, q = ctx.saved_tensors
        # Calculate gradients (another triton kernel)
        return ...
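
  • Calling it then looks like any other autograd op (a sketch; tensors must live on the GPU, and backward only runs once the gradient kernel exists):

logits = torch.randn(1024, device="cuda", requires_grad=True)
p = logits.softmax(dim=0)
q = torch.rand(1024, device="cuda").softmax(dim=0)
loss = KlDiv.apply(p, q).sum()
loss.backward()  # dispatches to KlDiv.backward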

Some benchmarks

  • A KL divergence kernel written by @me that is currently used in Liger Kernel

Do I have to write everything?

  • TLDR: No
  • Many cool projects are already using Triton
  • Better integration with PyTorch and even Hugging Face πŸ€—
  • Liger Kernel, Unsloth AI, etc.

So how can I use this in my LLM? πŸš€

  • Liger Kernel is a great example, showing how to integrate with the Hugging Face πŸ€— Trainer:
- from transformers import AutoModelForCausalLM
+ from liger_kernel.transformers import AutoLigerKernelForCausalLM

model_path = "meta-llama/Meta-Llama-3-8B-Instruct"

- model = AutoModelForCausalLM.from_pretrained(model_path)
+ model = AutoLigerKernelForCausalLM.from_pretrained(model_path)

# training/inference logic...
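
  • Liger Kernel also ships model-specific patching helpers if you prefer the stock transformers classes (names from its docs, so double-check your version):

from liger_kernel.transformers import apply_liger_kernel_to_llama
from transformers import AutoModelForCausalLM

# Patch the Llama modules with Triton kernels before loading the model
apply_liger_kernel_to_llama()
model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")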

Key Optimization Techniques adopted by Liger Kernel

  • Kernel Fusion (see the sketch after this list)
  • Domain-specific optimizations
  • Memory Access Patterns
  • Preemptive memory freeing
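
  • A minimal sketch of the kernel-fusion idea (my own toy kernel, not Liger Kernel's actual code): the elementwise KL term and the sum reduction happen in one launch, so the intermediate tensor never touches global memory:

import triton
import triton.language as tl

@triton.jit
def kl_div_fused_kernel(
    p_ptr, q_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr
):
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    # other=1.0 makes masked-out lanes contribute 1 * (log 1 - log 1) = 0
    p = tl.load(p_ptr + offsets, mask=mask, other=1.0)
    q = tl.load(q_ptr + offsets, mask=mask, other=1.0)

    # Fused: elementwise term + block reduction, no intermediate in global memory
    block_sum = tl.sum(p * (tl.log(p) - tl.log(q)), axis=0)

    # One atomic add per block accumulates the scalar result in out_ptr[0]
    tl.atomic_add(out_ptr, block_sum)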

Aaand some more benchmarks πŸš€

  • Saving memory is key to running bigger batch sizes on smaller GPUs

Last benchmark I promise...

  • But is it faster? Yes, it is!

Attention is all you need, so I thank you for yours! πŸ€—