When Archer Zhang set out to port his GPU Flash Attention kernel to TPU, he expected a translation job. He ended up with something closer to a full reimplementation, and a detailed account of why two platforms that share the same workloads can differ so sharply in how they want you to code for them.

The core mismatch is mutability. Triton hands developers raw pointer arithmetic and mutable memory: you write where you want, when you want. JAX forbids this entirely. Because JAX traces Python functions into pure computation graphs for XLA compilation, its arrays are immutable and in-place writes would break that model. Flash Attention's canonical stateful loop, which carries a running softmax maximum, a running sum, and an output accumulator from block to block, has to be restructured from scratch. Python for-loops become `jax.lax.fori_loop` constructs that thread that state through as explicit loop carries. Every tensor write becomes a `lax.dynamic_update_slice` call that semantically returns a new array. XLA can often perform the update in place under the hood, but the code itself cannot rely on mutation. The algorithm is the same. Almost nothing else is.
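The restructuring described above can be sketched as a single-head Flash Attention forward pass in JAX. This is a minimal sketch, not Zhang's code. The function names, the 64-row block sizes, the float32 inputs, and the assumption that sequence lengths divide evenly into blocks are all ours. The inner `fori_loop` carries the running max `m`, running sum `l`, and accumulator `acc` as explicit state. The outer loop "writes" each finished query block with `lax.dynamic_update_slice`, which returns a new array rather than mutating one.

```python
import jax
import jax.numpy as jnp
from jax import lax


def reference_attention(q, k, v):
    """Standard single-head attention; under jax.jit, XLA chooses the fusion."""
    scale = 1.0 / (q.shape[-1] ** 0.5)
    return jax.nn.softmax((q @ k.T) * scale, axis=-1) @ v


def flash_attention(q, k, v, block_q=64, block_k=64):
    """Blockwise attention with an online softmax, written without mutation."""
    seq_q, d = q.shape
    seq_k = k.shape[0]
    # Simplifying assumption: sequence lengths are multiples of the block sizes.
    assert seq_q % block_q == 0 and seq_k % block_k == 0
    scale = 1.0 / (d ** 0.5)

    def q_block(i, out):
        q_blk = lax.dynamic_slice(q, (i * block_q, 0), (block_q, d)) * scale

        def kv_block(j, carry):
            # m: running row max, l: running softmax denominator,
            # acc: unnormalized output accumulator for this query block.
            m, l, acc = carry
            k_blk = lax.dynamic_slice(k, (j * block_k, 0), (block_k, d))
            v_blk = lax.dynamic_slice(v, (j * block_k, 0), (block_k, d))
            s = q_blk @ k_blk.T  # (block_q, block_k) attention scores
            m_new = jnp.maximum(m, s.max(axis=-1))
            # Rescale earlier statistics to the new max (exp(-inf) = 0 on step 0).
            correction = jnp.exp(m - m_new)
            p = jnp.exp(s - m_new[:, None])
            l_new = correction * l + p.sum(axis=-1)
            acc_new = correction[:, None] * acc + p @ v_blk
            # Return the new state instead of updating it in place.
            return m_new, l_new, acc_new

        init = (
            jnp.full((block_q,), -jnp.inf, dtype=q.dtype),
            jnp.zeros((block_q,), dtype=q.dtype),
            jnp.zeros((block_q, d), dtype=q.dtype),
        )
        _, l, acc = lax.fori_loop(0, seq_k // block_k, kv_block, init)
        # "Writing" the finished block yields a new output array value.
        return lax.dynamic_update_slice(out, acc / l[:, None], (i * block_q, 0))

    out = jnp.zeros((seq_q, d), dtype=q.dtype)
    return lax.fori_loop(0, seq_q // block_q, q_block, out)


# Usage: run the blockwise version on random inputs under jax.jit.
keys = jax.random.split(jax.random.PRNGKey(0), 3)
q, k, v = (jax.random.normal(key, (256, 64)) for key in keys)
out = jax.jit(flash_attention)(q, k, v)
```

Comparing `out` against `reference_attention(q, k, v)` is a cheap correctness check before any performance work. It is also the same kind of baseline an XLA-fused version would be measured against.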

The hardware gap runs deeper. TPUs carry substantially more on-chip SRAM than GPUs. Zhang compares ~128MB on the TPU v5e against ~164KB of shared memory per streaming multiprocessor on a comparable GPU. The two figures sit at different levels of each platform's memory hierarchy, however, so they are not a like-for-like comparison. The TPU's systolic array also pipelines matrix multiply-accumulate operations differently from GPU warp-level parallelism, and Zhang built an emulator to work through the data flow before writing a line of kernel code. His benchmark on a Colab TPU v5e delivered the headline result: standard attention with XLA auto-fusion was already fast enough that his hand-tiled Flash Attention implementation offered no meaningful throughput gain. In effect, XLA's compiler was already doing the work that manual tiling does on GPU. Getting beyond that baseline requires dropping down to Pallas, Google's lower-level kernel framework for JAX.
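The post doesn't describe the internals of Zhang's systolic-array emulator, so the toy below is a hypothetical stand-in for that kind of data-flow model, not his code. It simulates an output-stationary systolic array cycle by cycle in NumPy. Each processing element (PE) owns one output value, operands enter skewed by one cycle per row or column, and every PE performs one multiply-accumulate per cycle. The TPU's matrix unit is commonly described as weight-stationary instead, so treat this as an illustration of wavefront-style data movement rather than a model of the actual hardware.

```python
import numpy as np


def systolic_matmul(a, b):
    """Toy cycle-level model of an output-stationary systolic array.

    PE (i, j) accumulates C[i, j]. Row i of `a` enters from the left delayed
    by i cycles, and column j of `b` enters from the top delayed by j cycles,
    so PE (i, j) multiplies a[i, k] * b[k, j] at cycle t = i + j + k.
    """
    m, k_dim = a.shape
    k_dim_b, n = b.shape
    assert k_dim == k_dim_b, "inner dimensions must match"
    c = np.zeros((m, n), dtype=np.result_type(a, b))
    # Operand registers: `a` values move right, `b` values move down.
    a_reg = np.zeros((m, n), dtype=a.dtype)
    b_reg = np.zeros((m, n), dtype=b.dtype)
    # The last PE (m-1, n-1) finishes at t = (m-1) + (n-1) + (k_dim-1).
    cycles = m + n + k_dim - 2
    for t in range(cycles):
        # Every operand advances one PE per cycle.
        a_reg[:, 1:] = a_reg[:, :-1].copy()
        b_reg[1:, :] = b_reg[:-1, :].copy()
        # Feed skewed inputs at the array edges (zeros outside the valid range).
        for i in range(m):
            kk = t - i
            a_reg[i, 0] = a[i, kk] if 0 <= kk < k_dim else 0
        for j in range(n):
            kk = t - j
            b_reg[0, j] = b[kk, j] if 0 <= kk < k_dim else 0
        # All PEs multiply-accumulate in lockstep.
        c += a_reg * b_reg
    return c, cycles


# Usage: a 2x3 by 3x4 product on a 2x4 array of PEs.
a = np.arange(6).reshape(2, 3)
b = np.arange(12).reshape(3, 4)
c, cycles = systolic_matmul(a, b)
```

Here `c` matches `a @ b` after 7 cycles (`m + n + k - 2`). The fill-and-drain overhead in that count is one intuition for why tile shapes matter on systolic hardware: for small operands, most PEs sit idle for most cycles.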

The practical implications hit hardest in long-context inference. Multi-step agentic workloads, such as tool-calling chains, retrieval-augmented pipelines, and extended reasoning loops, are where attention costs concentrate and where providers make hardware-specific optimization decisions. Google runs Gemini on TPUs, and GPU-tuned kernels don't port over without rethinking. Zhang's post is a worked example of the gap between understanding Flash Attention as an algorithm and knowing how to implement it for a specific target. As inference infrastructure spreads across GPU, TPU, and custom silicon, that gap is widening, and it falls on the engineers writing inference stacks to bridge it.