
TPUs vs GPUs: When to Choose What for AI/ML Workloads

TPU vs GPU for AI/ML workloads: silicon architecture, JAX vs PyTorch fit, H100 pricing, spot automation, and total cost of ownership. A practical decision framework for ML infrastructure teams running training and inference at scale.

Roberto Pesce

Your team gets approved to train a 70B parameter model. The H100 reservation you submitted three months ago is still 6 months out. Someone on Slack drops a link to TPU v5p pricing on GCP. Now you’re reading Google’s XLA documentation at midnight, trying to figure out whether your team can adopt JAX in a quarter. This is the decision most ML infra engineers don’t want to make under pressure, but that’s usually when it arrives.

This post covers the TPU vs GPU decision as an infrastructure engineer would: silicon architecture first, then use-case fit, then total cost of ownership. By the end, you’ll have a concrete framework for making the call before the compute crisis forces it on you.

The Architecture Divide

GPUs and TPUs solve the same matrix-heavy problem through different hardware designs. The difference shapes which workloads each handles well, and which it handles poorly.

GPUs pack thousands of CUDA cores, each a general-purpose arithmetic logic unit, alongside dedicated Tensor Cores for mixed-precision matrix operations. The NVIDIA H100 SXM has 16,896 CUDA cores and 528 fourth-generation Tensor Cores. That design handles arbitrary computation: branching, elementwise ops, irregular memory access. That flexibility is why CUDA became the de facto programming model for ML.

The H100 delivers 989 BF16 TFLOPS dense (1,979 TFLOPS with structured sparsity) with 80GB HBM3 at 3.35 TB/s memory bandwidth. NVLink connects GPUs within a node at 900 GB/s; InfiniBand NDR handles cross-node traffic at 400 Gbps. The A100 SXM runs 312 BF16 TFLOPS at 2.0 TB/s. For cost-efficient inference, the L4 provides 24GB GDDR6 at roughly $0.70-1.00/GPU-hr on-demand. The H200 SXM5 (production 2025) extends the H100 with 141GB HBM3e at 4.8 TB/s, which matters for KV-cache-intensive inference where 80GB becomes the binding constraint.
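A quick roofline calculation using the spec-sheet numbers above shows why memory bandwidth, not peak TFLOPS, often sets the ceiling for inference. This is a simplified sketch: it ignores caches, kernel efficiency, and interconnect. It assumes the H200's dense BF16 compute equals the H100's, and it treats batch-1 decode as roughly 1 FLOP per byte (each 2-byte BF16 weight is read once for one multiply-add).

```python
# Spec-sheet figures quoted above: (dense BF16 FLOP/s, HBM bytes/s).
SPECS = {
    "H100 SXM": (989e12, 3.35e12),
    "A100 SXM": (312e12, 2.0e12),
    "H200 SXM": (989e12, 4.8e12),  # assumption: same dense BF16 compute as H100
}


def ridge_point(peak_flops, bandwidth):
    """Arithmetic intensity (FLOP/byte) above which a kernel becomes compute-bound."""
    return peak_flops / bandwidth


def attainable_flops(peak_flops, bandwidth, intensity):
    """Roofline model: throughput is capped by compute or by memory traffic."""
    return min(peak_flops, bandwidth * intensity)


if __name__ == "__main__":
    decode_intensity = 1.0  # approximate FLOP/byte for batch-1 BF16 matrix-vector work
    for name, (peak, bw) in SPECS.items():
        ceiling = attainable_flops(peak, bw, decode_intensity) / 1e12
        print(f"{name}: ridge {ridge_point(peak, bw):.0f} FLOP/byte, "
              f"batch-1 decode ceiling {ceiling:.2f} TFLOPS")
```

Any kernel whose arithmetic intensity sits below the ridge point is memory-bound, so extra bandwidth (H100 to H200) helps it more than extra peak compute would.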

Distributed training typically achieves 30-50% of theoretical peak FLOPS due to memory bandwidth bottlenecks, collective communication overhead, and pipeline bubbles. MFU (Model FLOPS Utilization) is what matters for production planning. The spec sheet is a ceiling, not an expectation.
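MFU is simple arithmetic once you pick a FLOPs-per-token estimate. The sketch below uses the common approximation of 6 FLOPs per parameter per token for dense transformer training (forward plus backward, ignoring attention-score FLOPs). The 70B-parameter model, 1,024 GPUs, and 900k tokens/s in the usage lines are hypothetical.

```python
def training_flops_per_token(n_params):
    """Common dense-transformer approximation: ~6 FLOPs per parameter per token."""
    return 6 * n_params


def mfu(n_params, tokens_per_sec, n_devices, peak_flops_per_device):
    """Model FLOPS Utilization: useful model FLOPs achieved / theoretical peak."""
    achieved = training_flops_per_token(n_params) * tokens_per_sec
    return achieved / (n_devices * peak_flops_per_device)


def tokens_per_sec_at(target_mfu, n_params, n_devices, peak_flops_per_device):
    """Invert MFU: the throughput a cluster must sustain to hit a target MFU."""
    return target_mfu * n_devices * peak_flops_per_device / training_flops_per_token(n_params)


if __name__ == "__main__":
    h100_peak = 989e12  # dense BF16, from the spec sheet
    print(f"MFU: {mfu(70e9, 9e5, 1024, h100_peak):.1%}")
    print(f"tokens/s needed for 40% MFU: {tokens_per_sec_at(0.4, 70e9, 1024, h100_peak):,.0f}")
```

Planning against a target MFU in the 30-50% range, rather than the spec-sheet peak, gives throughput and schedule estimates that match what a cluster actually delivers.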

TPUs take the opposite approach. The compute primitive is a Matrix Multiply Unit (MXU), implemented as a systolic array. Pre-v6e MXUs are 128×128; v6e and TPU7x (Ironwood) use 256×256. Each pre-v6e MXU executes 16,384 (128×128) multiply-accumulate operations per cycle; the 256×256 MXUs in v6e and Ironwood execute 65,536 per cycle, quadrupling per-MXU throughput. The key point: data flows through the systolic array without memory access during the matrix multiply itself. That design is why TPUs sustain higher raw matrix throughput than GPUs at comparable chip area. TPU7x (Ironwood) reaches 2,307 BF16 TFLOPS per chip with 192GB HBM at 7,380 GiB/s bandwidth. The tradeoff is rigidity: MXUs handle only matrix operations, with no branching inside the array. JAX's Pallas kernel language supports custom ops but is far more constrained than CUDA or Triton, and arbitrary CUDA-style control flow is not available.
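A toy simulation makes the systolic dataflow concrete. This sketch models a small output-stationary array in plain Python: each processing element (PE) keeps one output sum, A values move right and B values move down one PE per cycle, and every PE does one multiply-accumulate per cycle. It is an illustration of the technique only; real MXUs use a different arrangement, pipelining, and sizes.

```python
def systolic_matmul(A, B):
    """Simulate an n x m output-stationary systolic array computing A @ B.

    PE (i, j) accumulates C[i][j]. Operands are handed from neighbor to
    neighbor through registers instead of being re-read from memory.
    Returns (C, cycles).
    """
    n, K, m = len(A), len(B), len(B[0])
    if any(len(row) != K for row in A):
        raise ValueError("inner dimensions must match")
    a_reg = [[0] * m for _ in range(n)]  # value each PE received from the left
    b_reg = [[0] * m for _ in range(n)]  # value each PE received from above
    acc = [[0] * m for _ in range(n)]
    cycles = K + n + m - 2  # until the last skewed operands meet at PE (n-1, m-1)
    for t in range(cycles):
        new_a = [[0] * m for _ in range(n)]
        new_b = [[0] * m for _ in range(n)]
        for i in range(n):
            for j in range(m):
                if j == 0:
                    k = t - i  # row i of A enters the left edge skewed by i cycles
                    new_a[i][j] = A[i][k] if 0 <= k < K else 0
                else:
                    new_a[i][j] = a_reg[i][j - 1]  # pass along from left neighbor
                if i == 0:
                    k = t - j  # column j of B enters the top edge skewed by j cycles
                    new_b[i][j] = B[k][j] if 0 <= k < K else 0
                else:
                    new_b[i][j] = b_reg[i - 1][j]  # pass along from upper neighbor
        a_reg, b_reg = new_a, new_b
        # All n*m PEs perform one multiply-accumulate in the same cycle.
        for i in range(n):
            for j in range(m):
                acc[i][j] += a_reg[i][j] * b_reg[i][j]
    return acc, cycles


if __name__ == "__main__":
    C, cycles = systolic_matmul([[1, 2], [3, 4]], [[5, 6], [7, 8]])
    print(C, "in", cycles, "cycles")
```

The skewing is what lets each operand be reused by a whole row or column of PEs; a 128×128 array does 16,384 of these multiply-accumulates per cycle once the pipeline is full.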

The interconnect story also diverges. H100 nodes rely on NVLink within a node and InfiniBand across nodes, requiring expensive networking switches at scale. TPUs use ICI (Inter-Chip Interconnect), a direct nearest-neighbor topology arranged in a 2D/3D torus. TPU v5p delivers 4,800 Gbps (600 GB/s) aggregate unidirectional ICI bandwidth per chip across 6 links in a 3D torus, totaling 1.2 TB/s bidirectional per chip. That works out to ~100 GB/s per link per direction. This topology is simpler to wire at pod scale. TPU v5p pods reach 8,960 chips; TPU7x pods scale to 9,216. For model parallelism spanning thousands of chips, ICI's direct chip-to-chip links outperform InfiniBand at comparable cost.
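The torus arithmetic can be sketched directly. This assumes a hypothetical 4×4×4 slice with wraparound links on every axis (whether a real slice has wraparound depends on its shape); it lists a chip's six direct neighbors, the worst-case hop count, and the per-link bandwidth implied by the v5p figures above.

```python
def torus_neighbors(coord, dims):
    """Direct ICI neighbors of a chip in a 3D torus: +/-1 on each axis, with wraparound."""
    result = []
    for axis in range(len(dims)):
        for step in (-1, 1):
            n = list(coord)
            n[axis] = (n[axis] + step) % dims[axis]
            n = tuple(n)
            # Axes of size 1 or 2 fold links onto self or onto the same neighbor.
            if n != tuple(coord) and n not in result:
                result.append(n)
    return result


def torus_diameter(dims):
    """Worst-case hops between two chips: wraparound halves the distance per axis."""
    return sum(d // 2 for d in dims)


def per_link_gbytes(aggregate_gbps, links):
    """Split aggregate unidirectional Gbps across links and convert to GB/s."""
    return aggregate_gbps / 8 / links


if __name__ == "__main__":
    print(torus_neighbors((0, 0, 0), (4, 4, 4)))
    print("max hops:", torus_diameter((4, 4, 4)))
    print("GB/s per link per direction:", per_link_gbytes(4800, 6))
```

Without the wraparound links, the corner-to-corner distance in the same slice would be 9 hops instead of 6, which is part of why the torus scales well for neighbor-heavy collectives.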

One constraint catches teams by surprise in production: TPUs require static input shapes at XLA compile time. Each unique shape triggers a separate compilation pass. For inference serving with variable-length sequences, this recompilation overhead is prohibitive. GPUs handle dynamic shapes throughout the stack. This is the sharpest practical difference between the two accelerator classes. It alone rules out TPUs for most production inference workloads.
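A toy model shows why variable lengths hurt. `ShapeKeyedCompileCache` below is a hypothetical stand-in for an XLA-style cache that compiles once per distinct input shape; it is not a real JAX or XLA API. Padding requests into a few fixed length buckets is the usual workaround: it caps the number of compilations but spends compute on padding tokens, and the bucket layout has to be tuned per workload.

```python
class ShapeKeyedCompileCache:
    """Toy stand-in for a compile cache: one compilation per distinct input shape."""

    def __init__(self):
        self.compiled = set()
        self.compilations = 0

    def run(self, shape):
        if shape not in self.compiled:
            self.compiled.add(shape)
            self.compilations += 1  # in a real stack this is the expensive step


def bucket(length, buckets):
    """Pad a sequence length up to the smallest bucket that holds it."""
    for b in buckets:
        if length <= b:
            return b
    raise ValueError(f"length {length} exceeds largest bucket {buckets[-1]}")


def compare(lengths, batch=8, buckets=(128, 256, 512, 1024, 2048)):
    """Compilations with exact shapes vs. bucketed shapes, plus padding waste fraction."""
    exact, padded = ShapeKeyedCompileCache(), ShapeKeyedCompileCache()
    real_tokens = padded_tokens = 0
    for n in lengths:
        exact.run((batch, n))
        b = bucket(n, buckets)
        padded.run((batch, b))
        real_tokens += n
        padded_tokens += b
    return exact.compilations, padded.compilations, 1 - real_tokens / padded_tokens


if __name__ == "__main__":
    print(compare(range(1, 2049)))
```

With every length from 1 to 2,048 arriving, exact shapes trigger 2,048 compilations while five buckets trigger five, at the cost of roughly a quarter of the processed tokens being padding.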

When TPUs Make Sense

TPUs are not a general-purpose GPU replacement. They work best when workload type, framework choice, and cloud commitment all align.

The ideal TPU workload is large-scale pretraining (100B+ parameters) in JAX, for example with MaxText, on GCP. Effective batch sizes should be multiples of 128 (or 256 for v6e and Ironwood) to keep MXUs fully utilized. Recommendation and ranking models with massive embedding tables also benefit from SparseCores, dedicated compute subsystems in newer TPU generations built for sparse embedding lookups.
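The multiple-of-128 rule follows from tile padding. A minimal sketch, assuming each matrix dimension is padded up to a whole number of MXU tiles (128 before v6e, 256 on v6e and Ironwood), estimates the fraction of MXU work that is useful:

```python
import math


def mxu_tile_utilization(dim, tile):
    """Useful fraction of one dimension after padding it to a multiple of the tile size."""
    padded = math.ceil(dim / tile) * tile
    return dim / padded


def matmul_tile_utilization(m, n, tile):
    """Useful fraction of an m x n output when both dimensions are padded."""
    return mxu_tile_utilization(m, tile) * mxu_tile_utilization(n, tile)


if __name__ == "__main__":
    for dim in (128, 129, 192, 256):
        print(dim, f"{mxu_tile_utilization(dim, 128):.0%} on 128x128",
              f"{mxu_tile_utilization(dim, 256):.0%} on 256x256")
```

A batch of 129 on a 128-wide MXU wastes almost half the array, and a size that was perfectly aligned for 128 (such as 128 itself) runs at 50% on a 256-wide MXU. That is why batch and hidden sizes tuned for one TPU generation may need retuning for the next.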

The economics align on long-running training with committed-use discounts. TPU v5e on-demand is $1.20/chip-hr, dropping to $0.54 on a 3-year CUD. TPU v5p is $4.20/chip-hr on-demand, $1.89 with a 3-year CUD. TPU v6e (Trillium), at 918 BF16 TFLOPS per chip and 32GB HBM, runs $2.70/chip-hr on-demand or $1.22 on a 3-year CUD. TPU7x (Ironwood) is $12.00/chip-hr on-demand, $5.40 with a 3-year CUD. For GCP-committed teams running JAX-native workloads on extended training runs, these numbers are competitive.
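The list prices above translate into a simple break-even rule. A 3-year CUD bills every hour whether or not the chips are busy, so it only beats on-demand when expected utilization exceeds the committed-to-on-demand price ratio. The 256-chip, 30-day run in the usage lines is hypothetical.

```python
# Per-chip-hour prices quoted above: (on-demand, 3-year CUD).
TPU_PRICES = {
    "v5e": (1.20, 0.54),
    "v5p": (4.20, 1.89),
    "v6e": (2.70, 1.22),
    "tpu7x": (12.00, 5.40),
}


def cud_discount(on_demand, committed):
    """Fractional discount of the committed rate versus on-demand."""
    return 1 - committed / on_demand


def breakeven_utilization(on_demand, committed):
    """Fraction of hours the chips must be busy for the CUD to beat on-demand."""
    return committed / on_demand


def run_cost(rate_per_chip_hr, chips, hours):
    return rate_per_chip_hr * chips * hours


if __name__ == "__main__":
    for name, (od, cud) in TPU_PRICES.items():
        print(f"{name}: {cud_discount(od, cud):.0%} off, "
              f"break-even at {breakeven_utilization(od, cud):.0%} utilization")
    od, cud = TPU_PRICES["v5p"]
    print(f"256x v5p for 720h: ${run_cost(od, 256, 720):,.0f} on-demand "
          f"vs ${run_cost(cud, 256, 720):,.0f} committed")
```

At these prices the discount is about 55% on every generation, so a commitment pays off only if the chips stay busy more than roughly 45% of the three years.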

The organizations running TPUs at scale have a common profile: Google (Gemini, AlphaFold, Waymo), DeepMind, and academic labs, all standardized on JAX. If your team comes from that lineage, TPUs are a defensible default. If not, treat the framework migration cost as an infrastructure investment, not a pricing optimization.

When GPUs Make Sense (Which Is Most of the Time)

The majority of production ML models, including LLaMA, Mistral, Falcon, and most open-source deployments, ship with PyTorch-native code and CUDA kernels. Running these on TPUs via PyTorch-XLA is technically possible but operationally rough. Coverage gaps exist, PyTorch-XLA limits dynamic shape support, and complex CUDA kernels don’t port directly. PyTorch-XLA supports custom ops via torch.library extensions, but optimized CUDA-only kernels require rewriting or elimination. For those models, GPU is the only realistic path without a significant rewrite.

The CUDA ecosystem represents 15+ years of accumulated tooling: cuBLAS, cuDNN, FlashAttention, vLLM, TensorRT-LLM, Triton. The major inference serving stacks (vLLM, TGI, TensorRT-LLM) are GPU-first and optimized for NVIDIA hardware. TPU inference runs through JetStream and Sax, a smaller ecosystem with fewer production deployments behind it. For latency-sensitive serving, GPU infrastructure is more mature. Within the GPU category, NVIDIA dominates the tooling layer. If the H100's 80GB HBM becomes the binding constraint on KV cache size, AMD's MI300X is worth evaluating. It offers 192GB HBM3 per GPU and runs on Azure via the ND MI300X v5 series, and ROCm support has improved across PyTorch and major open-source inference stacks.
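Whether 80GB binds depends on KV cache arithmetic. This sketch assumes a Llama-2-70B-like configuration (80 layers, 8 grouped-query KV heads, head dimension 128, BF16) and roughly 140GB of BF16 weights (70B parameters × 2 bytes); it ignores activations and framework overhead unless you pass `reserve_gb`. Capacities are in decimal GB.

```python
def kv_cache_bytes_per_token(layers, kv_heads, head_dim, bytes_per_elem=2):
    """K and V tensors per layer, each kv_heads * head_dim elements per token."""
    return 2 * layers * kv_heads * head_dim * bytes_per_elem


def max_cached_tokens(hbm_gb, weights_gb, per_token_bytes, reserve_gb=0.0):
    """Tokens of KV cache that fit in the HBM left over after weights and reserve."""
    free_bytes = (hbm_gb - weights_gb - reserve_gb) * 1e9
    return max(0, int(free_bytes // per_token_bytes))


if __name__ == "__main__":
    per_token = kv_cache_bytes_per_token(layers=80, kv_heads=8, head_dim=128)
    weights_gb = 70e9 * 2 / 1e9  # BF16 weights
    for label, hbm in (("1x H100 (80GB)", 80), ("2x H100 (160GB)", 160), ("1x MI300X (192GB)", 192)):
        print(label, max_cached_tokens(hbm, weights_gb, per_token), "tokens of KV cache")
```

Under these assumptions a single H100 cannot hold the weights at all, two H100s leave room for about 61k cached tokens, and one MI300X leaves room for about 158k, which is the capacity argument for larger-HBM parts in long-context serving.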

Fine-tuning (LoRA, QLoRA, PEFT) is GPU territory. The tooling ecosystem for parameter-efficient fine-tuning is far thinner on TPUs than on CUDA. Research and rapid prototyping also favor GPU. Iteration speed depends on framework flexibility: switching optimizers, testing experimental kernels, running ablations before committing to a full training run.

Most important for infrastructure decisions: GPUs run everywhere. The same CUDA codebase runs on AWS, GCP, Azure, and bare metal. Kubernetes GPU operators and K8s Dynamic Resource Allocation work across EKS, GKE, AKS, and self-managed clusters. TPUs are GCP-only. There’s no on-prem option. JAX/XLA code doesn’t port cleanly to other platforms. As multi-cloud GPU infrastructure becomes standard in 2026, that portability gap is a real architectural risk.

TCO: Beyond the Chip Price

Raw chip pricing is the wrong comparison metric. Total cost of ownership includes framework migration, team ramp time, committed-use structure, and whether spot automation is viable.

GPU spot instances cut on-demand compute costs by a wide margin. An AWS p4d.24xlarge (8x A100 40GB) lists at roughly $22/hr on-demand; spot pricing typically runs 50-70% lower. H100 on-demand runs $2.30-3.50/GPU-hr on smaller cloud providers. On AWS and GCP it hits $10-12/GPU-hr. That’s where spot automation delivers the most leverage. The spot discount is available across all major clouds and applies to the same workloads you’re already running. For teams with automated spot interruption handling, GPU spot economics are among the best cost levers in AI infrastructure.
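The spot discount only translates into savings after fallback and rework are counted. This is a first-order sketch with hypothetical inputs: `spot_coverage` is the fraction of hours actually served by spot (the rest falls back to on-demand), and each interruption is assumed to repeat `lost_hours_per_interruption` of work, roughly half the checkpoint interval plus restart time.

```python
def effective_hourly_cost(on_demand, spot_discount, spot_coverage,
                          interruptions_per_hour=0.0, lost_hours_per_interruption=0.0):
    """Approximate cost per hour of useful progress for a spot-first fleet."""
    spot_rate = on_demand * (1 - spot_discount)
    blended = spot_coverage * spot_rate + (1 - spot_coverage) * on_demand
    # First-order rework: every hour of progress also pays for repeated work.
    rework = interruptions_per_hour * lost_hours_per_interruption
    return blended * (1 + rework)


if __name__ == "__main__":
    p4d_on_demand = 22.0  # approximate on-demand $/hr quoted above
    print(effective_hourly_cost(p4d_on_demand, 0.6, 1.0))
    print(effective_hourly_cost(p4d_on_demand, 0.6, 0.8))
    print(effective_hourly_cost(p4d_on_demand, 0.6, 0.8, 0.1, 2.0))
```

The third case shows how the savings erode: at 80% spot coverage with one interruption every ten hours and two hours of lost work each time, the effective rate climbs from $8.80 toward $14, still below on-demand but far from the headline discount.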

TPUs lack spot automation of comparable maturity. Spot and preemptible TPUs exist on GCP, but managing interruptions at pod scale requires more manual overhead than GPU spot automation. CUDs add financial lock-in on top of technical lock-in. Committing to 3-year TPU capacity on GCP is a hard bet on both the platform and the framework.

For teams not already on GCP, the migration cost to adopt TPUs starts with the framework. JAX is not a weekend project. Teams with mid-size PyTorch codebases (20-50k lines) typically report 2-4 engineer-quarters for a full migration. Larger stacks with custom CUDA kernels often take 6-12 months. The XLA compilation model requires static shape discipline throughout the training loop. Custom ops need rewrites or elimination. Your existing observability stack, job schedulers, and autoscaling infrastructure all require TPU-aware modifications. None of this appears in a chip-price comparison.

Running GPU Infrastructure Without Overpaying

The team in the opening scenario had a 6-month H100 waitlist, a PyTorch stack, and no appetite for a framework migration. They chose GPU. That was the right call. But choosing a GPU doesn’t make the cost problem disappear; it changes shape. The question shifts from “TPU or GPU?” to “how do we run this fleet without burning money on capacity we’re not actually using?”

At 1,000 GPU-hours per week, the question “are we on the right instance type?” compounds fast. A training job running at 30% SM utilization on an 80GB A100 but filling 60GB of HBM still needs that A100: memory is the binding constraint, not compute. Tools like DCGM expose both dimensions at once. Without that distinction, rightsizing decisions are guesswork: you downgrade to a smaller SKU, the job OOMs, and you’re back where you started. The gap between what jobs request and what they actually use often amounts to 30-40% wasted GPU capacity at cluster scale. Comparing resource requests against actual DCGM utilization across both compute and memory is what makes rightsizing safe to automate, because it separates a genuine SKU mismatch from a memory-bound job that only looks underutilized.
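The two-dimensional check can be sketched as a small rule. The inputs are the kind of values DCGM reports (for example the `DCGM_FI_PROF_SM_ACTIVE` ratio and `DCGM_FI_DEV_FB_USED` memory), but the catalog, its assumed cost ordering, the headroom, and the projected-utilization threshold are all hypothetical. Scaling SM activity by a TFLOPS ratio is a crude proxy for how busy the job would be on another SKU.

```python
# Hypothetical catalog, assumed to be listed in ascending cost order:
# (name, HBM GB, dense BF16 TFLOPS)
CATALOG = [
    ("L4", 24, 121),
    ("A100-40GB", 40, 312),
    ("A100-80GB", 80, 312),
    ("H100-80GB", 80, 989),
]


def recommend_sku(current, peak_mem_gb, sm_active, headroom=0.15, max_projected_util=0.6):
    """Suggest the cheapest SKU that fits both memory and compute, else keep current."""
    specs = {name: tflops for name, _, tflops in CATALOG}
    need_gb = peak_mem_gb * (1 + headroom)  # memory must fit first, or the job OOMs
    for name, mem_gb, tflops in CATALOG:
        if name == current:
            break  # only consider SKUs cheaper than the current one
        projected = sm_active * specs[current] / tflops  # crude compute projection
        if mem_gb >= need_gb and projected <= max_projected_util:
            return name
    return current


if __name__ == "__main__":
    print(recommend_sku("A100-80GB", peak_mem_gb=60, sm_active=0.30))  # memory-bound
    print(recommend_sku("A100-80GB", peak_mem_gb=10, sm_active=0.20))  # genuine mismatch
```

The first job looks idle by SM utilization alone but stays on its A100 because nothing cheaper holds 60GB plus headroom; the second is a genuine mismatch that fits on a much smaller part.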

Spot economics are equally straightforward on paper and equally treacherous in practice. An H100 on AWS or GCP runs $10-12/GPU-hr on-demand; spot cuts that 60-70%. The math is obvious. The risk lies in handling interruptions gracefully. When a spot instance is reclaimed mid-training without that handling, progress since the last checkpoint is lost, the job restarts cold, and the “savings” partially evaporate. Making spot viable for production training takes four things: catch the preemption signal early, drain pods before the instance is reclaimed, reschedule onto available capacity, and fall back to on-demand when spot is exhausted. That is exactly the kind of operational plumbing that teams build badly once and inherit forever. Yotpo addressed this directly: a 30-40% reduction in Kubernetes compute costs through spot automation, with workloads completing rather than failing on interruption.
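Inside the training process, “catch the signal early” usually means handling SIGTERM, which Kubernetes sends to containers when a pod is evicted, followed by SIGKILL once the termination grace period runs out. The sketch below is hypothetical scaffolding, not any particular framework's API: a JSON file stands in for real model and optimizer state, and the step increment stands in for an optimizer step.

```python
import json
import os
import signal
import tempfile


class PreemptionGuard:
    """Turns SIGTERM into a flag the training loop polls between steps."""

    def __init__(self, install_handler=True):
        self.preempted = False
        if install_handler:
            signal.signal(signal.SIGTERM, self._on_sigterm)

    def _on_sigterm(self, signum, frame):
        self.request_stop()

    def request_stop(self):
        self.preempted = True


def save_checkpoint(path, state):
    """Write to a temp file, then rename, so a crash never leaves a half-written checkpoint."""
    directory = os.path.dirname(os.path.abspath(path))
    fd, tmp_path = tempfile.mkstemp(dir=directory, suffix=".tmp")
    with os.fdopen(fd, "w") as f:
        json.dump(state, f)
    os.replace(tmp_path, path)  # replaces the old checkpoint in a single rename


def load_checkpoint(path):
    if not os.path.exists(path):
        return {"step": 0}
    with open(path) as f:
        return json.load(f)


def train(path, total_steps, guard, checkpoint_every=100, on_step=None):
    """Resume from the last checkpoint; on preemption, save and exit cleanly.
    Returns (last_step, preempted)."""
    state = load_checkpoint(path)
    while state["step"] < total_steps:
        state["step"] += 1  # placeholder for a real optimizer step
        if on_step is not None:
            on_step(state["step"])
        if guard.preempted:
            save_checkpoint(path, state)  # must finish inside the grace period
            return state["step"], True
        if state["step"] % checkpoint_every == 0:
            save_checkpoint(path, state)
    save_checkpoint(path, state)
    return state["step"], False
```

The rescheduled pod simply calls `train` again and resumes from the saved step. Draining, rescheduling, and on-demand fallback happen outside the process, in the cluster's scheduling and autoscaling layer.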

Capacity constraints are the third problem, and they’re the one that sends teams reading TPU documentation at midnight. H100 backlogs ran 6-12 months for many cloud customers in 2024-2025. If your workload is on AWS and H100 capacity isn’t available, the options without multi-cloud access are: wait, overpay, or compromise on hardware.

Inference density is a separate lever. GPU time-slicing creates up to 48 virtual replicas per physical GPU. MIG partitioning creates up to 7 isolated instances with dedicated HBM on A100/H100. Both let multiple inference workloads share a single GPU, but only MIG isolates them: time-sliced replicas share compute and memory, so one busy workload can slow the others. A note on MIG: profile changes can’t be applied live. Reconfiguration requires draining the node and evicting its pods, so plan changes around inference availability windows. Combined with GPU-optimized bin-packing and K8s Dynamic Resource Allocation, these tools increase inference deployment density without additional hardware spend.
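Bin-packing is the piece that decides how many GPUs a set of replicas actually needs. This is a minimal first-fit-decreasing sketch that packs by memory footprint only; the replica names and sizes are hypothetical, and a real scheduler would also weigh compute, MIG profile boundaries, and isolation requirements.

```python
def first_fit_decreasing(footprints_gb, gpu_capacity_gb):
    """Pack inference replicas onto as few GPUs as possible by memory (FFD heuristic)."""
    gpus = []  # each entry: [remaining_gb, [replica names]]
    for name, size in sorted(footprints_gb.items(), key=lambda kv: kv[1], reverse=True):
        if size > gpu_capacity_gb:
            raise ValueError(f"{name} ({size} GB) does not fit on one GPU")
        for gpu in gpus:
            if gpu[0] >= size:  # first GPU with enough free memory wins
                gpu[0] -= size
                gpu[1].append(name)
                break
        else:
            gpus.append([gpu_capacity_gb - size, [name]])  # open a new GPU
    return [names for _, names in gpus]


if __name__ == "__main__":
    replicas = {"chat": 30, "embed": 20, "rerank": 50, "summarize": 40, "classify": 10}
    print(first_fit_decreasing(replicas, gpu_capacity_gb=80))
```

Placing the largest replicas first leaves small ones to fill the gaps, so the five hypothetical replicas fit on two 80GB GPUs instead of one GPU each.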

Rightsizing, spot automation, cross-cloud capacity routing, and inference density optimization compound at cluster scale. Akamai reported 40-70% cloud cost reduction against unoptimized on-demand baselines. The pattern holds across organizations: the waste is predictable, and teams know how to fix it. Most infrastructure budgets leak in the gap between running GPU and running it efficiently. Cast AI handles this across the full stack: SKU mismatch detection via DCGM, automated spot lifecycle management, and inference bin-packing. No changes to training code or framework choice required.

The Decision Framework

If you’re back at that midnight decision point, here is the short version:

  • Running JAX on GCP, pretraining at 100B+ params, long-term committed: evaluate TPU v5p, v6e (Trillium), or TPU7x (Ironwood).
  • Running PyTorch, any open-source model, any fine-tuning, any custom CUDA work: GPU.
  • Serving inference with variable sequence lengths (vLLM, TGI, TensorRT-LLM): GPU.
  • Multi-cloud, hybrid, or on-prem: GPU.
  • H100 unavailable, backlog unacceptable: pursue multi-cloud GPU access or spot capacity before committing to a platform migration to TPU.
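The bullets above can be encoded as a first-pass filter. `choose_accelerator` is a hypothetical helper, not a tool from this post. It defaults to GPU whenever the TPU conditions aren't all met, in line with the “most of the time” conclusion, and it leaves the capacity-shortage case to the last bullet, since that case is about sourcing GPUs rather than switching platforms.

```python
def choose_accelerator(framework, cloud, params_b, workload,
                       multi_cloud=False, long_term_commit=False):
    """First-pass accelerator choice following the decision bullets."""
    if multi_cloud or cloud != "gcp":  # multi-cloud, hybrid, on-prem, or non-GCP
        return "GPU"
    if framework != "jax":  # PyTorch, open-source models, custom CUDA work
        return "GPU"
    if workload in ("fine-tuning", "variable-length inference"):
        return "GPU"
    if workload == "pretraining" and params_b >= 100 and long_term_commit:
        return "evaluate TPU (v5p, v6e, TPU7x)"
    return "GPU"


if __name__ == "__main__":
    print(choose_accelerator("jax", "gcp", 175, "pretraining", long_term_commit=True))
    print(choose_accelerator("pytorch", "aws", 70, "pretraining"))
```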

This framing covers the dominant decision path. AWS Trainium2 and Intel Gaudi3 are production alternatives for teams with deep AWS or HPC commitments. The decision criteria are the same: framework compatibility, portability constraints, and total cost including migration overhead.

For most teams, the answer is GPU. Two questions settle the accelerator choice: What framework is your team already running? How much migration overhead can it absorb? For the team in the opening scenario, the answer wasn’t TPU. It was GPU capacity accessible wherever it exists, managed without manual overhead.


