2025-02-22
whoami

What is Triton?

Why Optimize with Triton?

brrrrrrr

Example Problem: KL Divergence

\[ D_{KL}(P \| Q) = \sum_{i} P_i \log\left(\frac{P_i}{Q_i}\right) \]
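Before any GPU code, the formula can be sanity-checked in plain Python (a reference sketch, not the Triton path; `kl_div` is a hypothetical helper name):

import math

def kl_div(p, q):
    """Sum of the elementwise terms P_i * log(P_i / Q_i)."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q))

# KL divergence is zero iff the two distributions match,
# and strictly positive otherwise.
assert kl_div([0.25, 0.75], [0.25, 0.75]) == 0.0
assert kl_div([0.5, 0.5], [0.9, 0.1]) > 0.0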
How about Triton?

import triton
import triton.language as tl
@triton.jit
def kl_div_triton(
    p_ptr, q_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr
):
    pid = tl.program_id(0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    
    p = tl.load(p_ptr + offsets, mask=mask)
    q = tl.load(q_ptr + offsets, mask=mask)
    
    output = p * (tl.log(p) - tl.log(q))
    tl.store(output_ptr + offsets, output, mask=mask)

How to integrate with PyTorch?

import torch
class KlDiv(torch.autograd.Function):
    @staticmethod
    def forward(ctx, p, q):
        ctx.save_for_backward(p, q)  # backward needs both p and q
        output = torch.empty_like(p)
        n_elements = p.numel()
        grid = (triton.cdiv(n_elements, 512),)  # launch grid must be a tuple
        kl_div_triton[grid](p, q, output, n_elements, BLOCK_SIZE=512)
        return output
    @staticmethod
    def backward(ctx, grad_output):
        p, q = ctx.saved_tensors
        # Calculate gradients (another triton kernel):
        # elementwise, d/dp = log(p / q) + 1 and d/dq = -p / q
        return ...

Some benchmarks
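As an aside, the backward pass elided above has to compute, for each element of `p * (log p - log q)`, the gradients `d/dp = log(p / q) + 1` and `d/dq = -p / q`. A pure-Python finite-difference check of those formulas (no Triton involved; `f` and `grads` are hypothetical helper names):

import math

def f(p, q):
    # One elementwise term of the KL sum: p * (log p - log q)
    return p * (math.log(p) - math.log(q))

def grads(p, q):
    # Analytical gradients of the term above
    return math.log(p / q) + 1.0, -p / q

# Central finite differences at an arbitrary point
p, q, eps = 0.3, 0.6, 1e-6
gp = (f(p + eps, q) - f(p - eps, q)) / (2 * eps)
gq = (f(p, q + eps) - f(p, q - eps)) / (2 * eps)
ap, aq = grads(p, q)
assert abs(gp - ap) < 1e-4 and abs(gq - aq) < 1e-4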

Do I have to write everything?

So how can I use this in my LLM?

- from transformers import AutoModelForCausalLM
+ from liger_kernel.transformers import AutoLigerKernelForCausalLM
model_path = "meta-llama/Meta-Llama-3-8B-Instruct"
- model = AutoModelForCausalLM.from_pretrained(model_path)
+ model = AutoLigerKernelForCausalLM.from_pretrained(model_path)
# training/inference logic...

Key Optimization Techniques adapted by Liger Kernel

Aaand some more benchmarks

Last benchmark I promise...
Attention is all you need, so I thank you for yours!
