What’s Next in AI Inference?

As model providers look to compete on price, inference optimization is key.

Inference cost matters. GPUs are expensive and transformer inference is notoriously inefficient. As model providers look to compete on price, inference optimization is a key lever to pull.

BCV Labs regularly brings researchers and engineers together to discuss the latest optimizations and improvements. We recently met with 15 AI researchers from Microsoft, OpenAI, Anthropic and Databricks to discuss the latest in inference optimization. Here’s what we learned.

Key Takeaways

  • Transformer inference is memory bound. The fundamental constraint in AI inference is memory. Whereas training is largely bound by GPU FLOPS (compute bound), inference is almost always memory bound, largely because of the KV caching strategies used in calculating the attention matrix. The cache becomes massive quickly: for the Llama 7B model it is about 4.3 GB, and for the 65B model it is about 21.5 GB (see the sizing sketch after this list).
  • Selection is key to sparsity. One solution to inference memory constraints is sparsity and cache eviction: strategically choosing which calculations to skip and which cached entries to evict. The movement toward sparsity takes inspiration from BERT, which showed that only about 10% of the key and query values in the QKᵀ attention calculation contribute meaningfully to the final attention value. Two techniques, FastGen and Sparse Sharding Attention, propose different heuristics for choosing what to compute and store versus what to ignore or evict.
  • Mixture of Experts (MoE) has sparsity built in. Only the top few expert “mini-models”, K, out of the total number of experts, E, run on a given input; this sparsity is measured by the E/K ratio. The benefit shows up in the prefill stage and is especially pronounced for large context windows. As demand for larger context windows grows, we might see high E/K ratio (more sparse) MoE models become increasingly popular.
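
As a back-of-the-envelope check on those cache sizes, here is a minimal sketch of the standard KV-cache sizing arithmetic. The batch size (4), context length (2,048 tokens), and fp16 precision are our assumptions for reproducing the rough figures above, not numbers from the discussion.

```python
def kv_cache_bytes(n_layers, hidden_size, seq_len, batch_size, bytes_per_value=2):
    # K and V each store one hidden_size vector per token, per layer
    return 2 * n_layers * hidden_size * seq_len * batch_size * bytes_per_value

# Assumed settings: batch size 4, 2,048-token context, fp16 (2 bytes per value)
llama_7b = kv_cache_bytes(n_layers=32, hidden_size=4096, seq_len=2048, batch_size=4)
llama_65b = kv_cache_bytes(n_layers=80, hidden_size=8192, seq_len=2048, batch_size=4)
print(f"Llama 7B : {llama_7b / 1e9:.1f} GB")   # ~4.3 GB
print(f"Llama 65B: {llama_65b / 1e9:.1f} GB")  # ~21.5 GB
```

The cache grows linearly with batch size and context length, which is why longer contexts and higher traffic make the memory bottleneck worse.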

Sparsity is King

Caches have long been used to accelerate data retrieval, storage, and transformation. LLMs are no different. KV caching is a critical speedup used in calculating the attention matrix for a given input at inference time. It caches important elements of the Key and Value matrices and removes redundant calculations by computing attention only for the new token in the sequence. There is one problem: caching all of these elements requires a massive amount of memory. As a result, LLM inference is “memory-bound,” or constrained by the memory available on the GPU, not by compute. We learned about the techniques researchers are using to tackle this, namely sparsity: removing as many redundant calculations as possible without harming accuracy.
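
To make the mechanism concrete, here is a minimal single-head decode step showing how a KV cache lets each new token reuse previous keys and values. This is a simplified sketch, not any particular library's implementation; the shapes and dimensions are illustrative.

```python
import torch

def decode_step(x_new, W_q, W_k, W_v, cache):
    """One decode step for a single attention head.
    x_new: (1, d_model) embedding of the newest token.
    cache: dict holding the K and V rows of every previous token."""
    q = x_new @ W_q                                 # query for the new token only
    k = x_new @ W_k                                 # its key ...
    v = x_new @ W_v                                 # ... and value
    # Append to the cache instead of recomputing K/V for the whole sequence
    cache["K"] = torch.cat([cache["K"], k], dim=0)  # (seq_len, d_head)
    cache["V"] = torch.cat([cache["V"], v], dim=0)
    scores = (q @ cache["K"].T) / cache["K"].shape[-1] ** 0.5
    return torch.softmax(scores, dim=-1) @ cache["V"]  # (1, d_head) output

d_model, d_head = 64, 64
W_q, W_k, W_v = (torch.randn(d_model, d_head) for _ in range(3))
cache = {"K": torch.zeros(0, d_head), "V": torch.zeros(0, d_head)}
for token_embedding in torch.randn(10, 1, d_model):   # generate 10 tokens
    out = decode_step(token_embedding, W_q, W_k, W_v, cache)
# The cache grows by one K row and one V row per token, per head, per layer;
# that growth, not FLOPs, is what becomes the bottleneck at inference time.
```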

Selective sharding reduces memory constraints in Llama models

The first paper we discussed is FastGen. The initial inspiration for this work came from BERT. Its insight was that only about 10% of the elements in the QKᵀ part of the attention calculation contribute meaningfully to the final attention value. The trouble is finding which 10% matters. In Llama, the selective importance of elements varies across attention heads but follows a pattern: some heads index higher on special tokens, some on punctuation, and others on local context. Instead of running all input tokens through every head, FastGen calculates accurate attention values by choosing only the input tokens relevant to a given attention head and preemptively evicting the others from the cache. The outcome is a 40% reduction in memory for Llama models with no loss in accuracy.
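
To illustrate the idea of per-head cache policies, here is a simplified sketch, not FastGen's actual algorithm: the token categories and hard-coded head "profiles" below are our assumptions for the example.

```python
SPECIAL = {"<s>", "</s>"}
PUNCTUATION = {".", ",", "!", "?", ";", ":"}

def keep_mask(tokens, head_profile, local_window=32):
    """Return True for the cache positions a given head keeps, False for evicted ones."""
    n = len(tokens)
    mask = []
    for i, tok in enumerate(tokens):
        if head_profile == "special":         # this head attends mostly to special tokens
            keep = tok in SPECIAL
        elif head_profile == "punctuation":   # this head attends mostly to punctuation
            keep = tok in SPECIAL or tok in PUNCTUATION
        else:                                 # "local": this head attends to recent context
            keep = i >= n - local_window
        mask.append(keep)
    return mask

tokens = ["<s>", "The", "cache", "grows", ",", "fast", "."]
print(keep_mask(tokens, "punctuation", local_window=4))
# [True, False, False, False, True, False, True]
```

Evicting the unmasked K/V rows is what turns the per-head attention pattern into a memory saving.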

Sharding can be generalized to reduce cache utilization in all models

The second paper is Sparse Sharding Attention (SSA), built on a similar premise: the input-sequence elements relevant to the attention computation vary across attention heads. But instead of building in sparsity unique to a specific model's head patterns, it proposes a more generalizable approach: an input-token selection invariant that holds across models. By splitting the input in a fixed pattern across attention heads (with each head operating on only a subset of input elements), SSA shows similarly reduced compute and memory while maintaining accuracy. The results held true across multiple Llama-type models, and improvements over Flash Attention 2 increase as the context window grows.
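
The sketch below only illustrates the general idea of a fixed, model-agnostic split of token positions across heads, here a simple strided assignment; SSA's actual selection invariant may differ.

```python
def head_token_positions(seq_len, n_heads):
    """Assign each head a fixed, strided subset of token positions,
    independent of any model-specific head behavior."""
    return {head: list(range(head, seq_len, n_heads)) for head in range(n_heads)}

positions = head_token_positions(seq_len=16, n_heads=4)
print(positions[0])  # [0, 4, 8, 12]
print(positions[1])  # [1, 5, 9, 13]
# Each head stores K/V for, and attends over, only ~seq_len / n_heads positions,
# cutting cache memory and attention compute by roughly that factor.
```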

Mixture of Experts models may become more popular as demand for larger context windows grows

Mixture of Experts (MoE) models use a collection of small “expert” models to generate output. For any given input, a central router chooses which subset of these experts is activated to compute the response. The total number of experts, E, divided by the number used for any given inference, K, is known as the E/K ratio. The E/K ratio measures sparsity, with a higher ratio suggesting increased efficiency at the cost of accuracy. Today, the industry generally centers on an E/K ratio of 4.
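
Here is a minimal sketch of a top-K router, assuming a simple softmax gate; the values of E, K, and the model dimension are illustrative, not drawn from any specific model.

```python
import torch

def route(x, gate_weights, k=2):
    """Score all E experts with a softmax gate, then keep only the top-k per token."""
    logits = x @ gate_weights                       # (n_tokens, E) gating scores
    probs = torch.softmax(logits, dim=-1)
    weights, expert_ids = torch.topk(probs, k, dim=-1)
    weights = weights / weights.sum(dim=-1, keepdim=True)  # renormalize over the top-k
    return expert_ids, weights                      # only these k expert FFNs run

E, K, d_model = 8, 2, 16                            # an E/K ratio of 4, as in the text
gate_weights = torch.randn(d_model, E)
expert_ids, weights = route(torch.randn(5, d_model), gate_weights, k=K)
print(expert_ids.shape, weights.shape)              # torch.Size([5, 2]) torch.Size([5, 2])
# Per token, only K of the E expert networks execute: the source of MoE sparsity.
```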

Longer context windows and higher batch sizes might push E/K ratios up and, in turn, make MoEs more popular. MoEs benefit from reduced prefill latency, which becomes especially pronounced at larger context windows. They are less GPU-efficient at lower batch sizes, but as traffic volume and batch sizes increase in real-world use cases, that is becoming less of an issue.


What’s Next?

Inference cost and efficiency are shaping how companies approach model development and downstream AI application development. Much of this comes from working around hardware constraints. So far, NVIDIA GPUs have balanced improvements in both memory size and bandwidth, taking a middle ground between training and inference workloads. As production demands grow, we could see more dedicated hardware or clusters optimized for one task or the other.

In the meantime, we’ll likely continue to see research pushing sparsity and efficient caching to their theoretical limits.