2025-02-22
whoami
What is Triton?
Why Optimize with Triton?
brrrrrrr
Example Problem: KL Divergence
\[ D_{KL}(P \| Q) = \sum_{i} P_i \log\left(\frac{P_i}{Q_i}\right) \]
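Before reaching for a custom kernel, it helps to see the same computation as a naive PyTorch baseline. This is a minimal sketch (the tensor values are illustrative, not from the slides): each element contributes the term P_i * log(P_i / Q_i), and summing the terms gives the divergence.

```python
import torch

def kl_div_torch(p: torch.Tensor, q: torch.Tensor) -> torch.Tensor:
    # Elementwise KL terms: p_i * (log p_i - log q_i).
    # Sum the result to get the full D_KL(P || Q).
    return p * (torch.log(p) - torch.log(q))

p = torch.tensor([0.5, 0.3, 0.2])
q = torch.tensor([0.4, 0.4, 0.2])
kl = kl_div_torch(p, q).sum()  # ≈ 0.0253 for these example distributions
```

Each of these eager-mode ops launches its own kernel and materializes an intermediate tensor, which is exactly the overhead a fused Triton kernel avoids.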
How about Triton?
import triton
import triton.language as tl

@triton.jit
def kl_div_triton(
    p_ptr, q_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr
):
    # Each program instance processes one BLOCK_SIZE chunk of the input.
    pid = tl.program_id(0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements  # guard against out-of-bounds accesses
    p = tl.load(p_ptr + offsets, mask=mask)
    q = tl.load(q_ptr + offsets, mask=mask)
    # Elementwise KL term, computed in a single fused pass: p_i * (log p_i - log q_i)
    output = p * (tl.log(p) - tl.log(q))
    tl.store(output_ptr + offsets, output, mask=mask)
How to integrate with PyTorch?
import torch

class KlDiv(torch.autograd.Function):
    @staticmethod
    def forward(ctx, p, q):
        ctx.save_for_backward(q)
        output = torch.empty_like(p)
        n_elements = p.numel()
        # One program per BLOCK_SIZE chunk; the launch grid must be a tuple.
        grid = ((n_elements + 512 - 1) // 512,)
        kl_div_triton[grid](p, q, output, n_elements, BLOCK_SIZE=512)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        q = ctx.saved_tensors[0]
        # Calculate gradients (another Triton kernel)
        return ...
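For reference, the gradients that backward kernel would need to produce can be derived by hand: for L_i = p_i (log p_i - log q_i), we get ∂L_i/∂p_i = log p_i - log q_i + 1 and ∂L_i/∂q_i = -p_i / q_i. A plain-PyTorch sketch of that math (a hypothetical helper for illustration, not the actual kernel), checked against autograd:

```python
import torch

def kl_div_backward_torch(grad_output, p, q):
    # d/dp [p * (log p - log q)] = log p - log q + 1
    grad_p = grad_output * (torch.log(p) - torch.log(q) + 1.0)
    # d/dq [p * (log p - log q)] = -p / q
    grad_q = grad_output * (-p / q)
    return grad_p, grad_q

p = torch.tensor([0.5, 0.3, 0.2], requires_grad=True)
q = torch.tensor([0.4, 0.4, 0.2], requires_grad=True)
loss = (p * (torch.log(p) - torch.log(q))).sum()
loss.backward()  # autograd's gradients should match the hand-derived ones
gp, gq = kl_div_backward_torch(torch.ones_like(p), p.detach(), q.detach())
```

Note that computing grad_p needs p as well, so a full implementation would save both tensors in forward, not just q.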
Some benchmarks
Do I have to write everything?
So how can I use this in my LLM?
- from transformers import AutoModelForCausalLM
+ from liger_kernel.transformers import AutoLigerKernelForCausalLM
model_path = "meta-llama/Meta-Llama-3-8B-Instruct"
- model = AutoModelForCausalLM.from_pretrained(model_path)
+ model = AutoLigerKernelForCausalLM.from_pretrained(model_path)
# training/inference logic...
Key Optimization Techniques Used by Liger Kernel
Aaand some more benchmarks
Last benchmark I promise...
Attention is all you need, so I thank you for yours!