Study Note: ML Compilation

2025/11/08

This note, mostly generated by Gemini, briefly explores the different compilation paths a simple line of Python like torch.matmul can take before it is executed by a specific hardware unit like the Tensor Core.

Understanding these different paths is key to seeing why the ML compilation stack exists, what “programmability” really means, and how developers make crucial trade-offs between ease of use, flexibility, and raw hardware performance.

Core Problem: The Two-Language Dilemma

The entire ML compilation stack exists to solve one problem: bridging the gap between how humans write AI code and how hardware runs it.

The stack’s job is to translate that dynamic, high-level code into static, low-level instructions that a specific piece of hardware can execute.
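As a concrete reference point, here is roughly the high-level side of that gap: the one line the rest of this note follows down the stack. The shapes, device, and dtype are illustrative choices; half precision is what makes the matmul eligible for Tensor Cores on recent GPUs.

```python
import torch

# Dynamic, high-level "human" code: one line, no mention of kernels,
# memory layouts, or instruction sets.
a = torch.randn(1024, 512, device="cuda", dtype=torch.float16)
b = torch.randn(512, 2048, device="cuda", dtype=torch.float16)
c = torch.matmul(a, b)  # somewhere below this call, a Tensor Core kernel runs
```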

Path A: The “Eager” Default Path (What happens without torch.compile)

This is the traditional, step-by-step path. It’s simple but has a major bottleneck.

The Workflow:

  1. You (Python): You write torch.matmul(a, b).
  2. ATen Backend (CPU): Your call goes to the PyTorch C++ backend (ATen), which is a pre-compiled binary (.so or .dll) running on your CPU.
  3. Dispatcher (CPU): ATen acts as a “dispatcher.” It sees you’re on a GPU and need a matmul. It doesn’t write a new kernel; it makes a function call to a pre-compiled library like cuBLAS.
    • cuBLAS: For math (e.g., matmul).
    • cuDNN: For neural network layers (e.g., convolution).
  4. cuBLAS (CPU): The cuBLAS host code (also running on your CPU) looks at your matrix shapes (e.g., 1024x512) and uses a heuristic to select the best pre-compiled matmul kernel (e.g., “Kernel #73”) from its collection; the kernel it picks shows up by name in a profiler trace (see the sketch after this list).
  5. PTX to SASS (CPU): The selected kernel is shipped as PTX (NVIDIA’s portable, assembly-like intermediate representation) and, for known GPU architectures, as pre-built SASS binaries. When no matching SASS binary is available, the NVIDIA Driver (CPU software) performs a Just-in-Time (JIT) compilation to translate the PTX into SASS, the final binary machine code that is specific to your exact GPU (e.g., an H100).
  6. CUDA Runtime: Throughout this, the ATen/cuBLAS C++ code is using the CUDA Runtime API to manage the GPU. This API is the “middle manager” that sends commands to the driver like:
    • cudaMalloc() (to allocate memory)
    • cudaMemcpy() (to move data)
    • cudaLaunchKernel() (to tell the GPU to run the SASS binary)
  7. Tensor Core (GPU): The GPU’s schedulers finally run this SASS binary on the streaming multiprocessors, whose matrix-multiply (MMA) instructions execute on the Tensor Cores to produce the result.
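One way to watch this dispatch happen is to profile the eager call and look at the device-side rows: the pre-compiled cuBLAS kernel that was selected appears there by name. A minimal sketch (the exact kernel name varies by GPU, library version, and shapes):

```python
import torch
from torch.profiler import profile, ProfilerActivity

a = torch.randn(1024, 512, device="cuda", dtype=torch.float16)
b = torch.randn(512, 2048, device="cuda", dtype=torch.float16)

# Capture both the CPU-side dispatch and the GPU kernel that cuBLAS launched.
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    torch.matmul(a, b)
    torch.cuda.synchronize()

# The CUDA rows name the pre-compiled kernel the heuristic picked.
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
```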

Kernel Launch Overhead:

If your code is torch.relu(torch.matmul(a, b)), this path runs TWO full, separate kernels.

  1. Run the matmul kernel and write the entire intermediate result to GPU memory.
  2. Run the relu kernel: read that entire result back from memory, apply relu, and write it out again.

This memory round-trip is extremely slow; the sketch below makes it concrete.
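Eagerly, the full matmul output is materialized in GPU memory before relu ever reads it, and you can watch the memory it occupies. A rough sketch (sizes are illustrative, and the caching allocator rounds the numbers somewhat):

```python
import torch

a = torch.randn(4096, 4096, device="cuda", dtype=torch.float16)
b = torch.randn(4096, 4096, device="cuda", dtype=torch.float16)
torch.cuda.synchronize()

before = torch.cuda.memory_allocated()
tmp = torch.matmul(a, b)   # kernel 1: writes a full 4096x4096 (~32 MiB) tensor to GPU memory
out = torch.relu(tmp)      # kernel 2: reads all of it back, writes the result out again
torch.cuda.synchronize()
after = torch.cuda.memory_allocated()

# Roughly the intermediate plus the output (~64 MiB) were allocated along the way.
print(f"extra GPU memory held: {(after - before) / 2**20:.1f} MiB")
```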

Path B: The “Compiled” Path (Using @torch.compile)

This is the modern, high-performance path. Its entire goal is to solve the bottleneck of Path A through operator fusion.
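From the user’s side, the change is a single decorator. A minimal sketch of the kind of function the rest of this workflow compiles (names and shapes are illustrative):

```python
import torch

@torch.compile
def matmul_relu(a, b):
    # Eagerly this is two separate kernels; compiled, Inductor can fuse the
    # relu into the matmul's epilogue so the intermediate never round-trips
    # through GPU memory.
    return torch.relu(torch.matmul(a, b))

a = torch.randn(1024, 512, device="cuda", dtype=torch.float16)
b = torch.randn(512, 2048, device="cuda", dtype=torch.float16)
out = matmul_relu(a, b)  # first call traces and compiles; later calls reuse the compiled kernel
```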

The Workflow:

  1. You (Python): You add @torch.compile to your function.
  2. TorchDynamo (Frontend): This is a “tracer.” It runs your dynamic Python code once and captures a static graph of all the PyTorch operations.
    • Symbolic Shapes: It’s smart enough to use “symbols” (like s0 for batch size) so the same compiled graph can work for batch_size=8 or batch_size=16.
    • Graph Breaks: If it encounters Python that is too dynamic to trace (e.g., an if statement that branches on a tensor’s values), it “breaks” the graph, runs that part in normal Python, and then starts a new graph.
  3. TorchInductor (Backend): This is the optimizer. It takes the graph from Dynamo and finds ways to make it faster.
    • Operator Fusion: This is its main job. It sees the matmul -> relu graph and decides to fuse them into one single kernel. This eliminates the memory round-trip.
    • Heuristics: If there are many fusion options (e.g., (a+b) + (c+d)), it uses a scoring system to pick the best fusion that minimizes memory I/O.
    • Why JIT? Combinatorial Explosion. It’s impossible to pre-compile a library of all possible fusions (matmul+relu, matmul+bias+gelu, etc.). TorchInductor JIT-compiles the exact kernel you need, when you need it.
  4. Triton (Compiler): TorchInductor uses Triton to generate this new, fused kernel.
    • Triton is both a language (using Python syntax) and a compiler; a minimal hand-written Triton kernel is sketched after this list.
    • It is NOT CUDA C++. The Triton compiler uses LLVM (a general-purpose compiler toolchain) to generate PTX directly.
  5. Driver & GPU: From here, the path is the same as before. The driver takes the new, custom-fused PTX code from Triton, compiles it to SASS, and the GPU runs it.
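To get a feel for what Triton code looks like, here is a minimal hand-written kernel in the same spirit as the ones Inductor emits: a fused elementwise add + relu that makes a single pass over memory. This is an illustrative sketch (the block size and launch grid are arbitrary choices), not the code Inductor would actually generate:

```python
import torch
import triton
import triton.language as tl

@triton.jit
def add_relu_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)                        # which block of the array this program owns
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements                        # guard the ragged last block
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, tl.maximum(x + y, 0.0), mask=mask)  # add + relu in one pass

def add_relu(x, y):
    out = torch.empty_like(x)
    n = x.numel()
    grid = (triton.cdiv(n, 1024),)                     # one program instance per 1024 elements
    add_relu_kernel[grid](x, y, out, n, BLOCK_SIZE=1024)
    return out
```

The Triton compiler lowers this Python-syntax kernel through its own IR and LLVM down to PTX, and from there step 5 above takes over unchanged.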

Main Benefit: Performance & Portability. Fusion removes the memory round-trips of Path A, and because kernels are generated on demand rather than drawn from a fixed, vendor-specific library, the same high-level code can be retargeted to different backends.
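The size of the win depends heavily on the model, shapes, and GPU: a single large matmul is already compute-bound, so the fused relu saves relatively little, while chains of small, memory-bound ops benefit the most. A rough timing sketch (all shapes and numbers are illustrative):

```python
import time
import torch

def eager(a, b):
    return torch.relu(torch.matmul(a, b))

compiled = torch.compile(eager)

def bench(fn, a, b, iters=100):
    for _ in range(10):          # warm-up (also triggers compilation for the compiled variant)
        fn(a, b)
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        fn(a, b)
    torch.cuda.synchronize()
    return (time.perf_counter() - start) / iters * 1e3  # ms per call

a = torch.randn(4096, 4096, device="cuda", dtype=torch.float16)
b = torch.randn(4096, 4096, device="cuda", dtype=torch.float16)
print(f"eager:    {bench(eager, a, b):.3f} ms")
print(f"compiled: {bench(compiled, a, b):.3f} ms")
```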

Path C: The “Hand-Tuned” Kernel Path

This is the “speed-of-light” path: an engineer bypasses the compiler stack and writes the kernel by hand, typically in CUDA C++ (often with libraries like CUTLASS) or even at the PTX/SASS level, tuning it for one specific GPU and one specific problem shape. It’s not a path most developers take, but it’s crucial for understanding the performance landscape.

The Performance vs. Effort Trade-off

This is the core dilemma for high-performance computing.

Why Do It? When you are a company operating at a massive scale, that “tiny” 5% speedup on your main model saves millions of dollars in hardware and energy costs. The engineering cost is worth it.

It’s not for everyone

For 99.9% of all developers and researchers, Path B is the clear winner. It makes your code fast without requiring you to become an elite GPU kernel developer.