
Saving VRAM with Cut Cross Entropy by Apple

deep learning
math
triton

May 16, 2025 · 16 min read

This is the source code that's referenced throughout this article. It might be helpful to have this open in another tab: GitHub - Apple ML Cut Cross Entropy

Memory-Efficient Cross-Entropy is a Game Changer

Cross Entropy is a ubiquitous loss function in deep learning, including the training of LLMs (next token prediction is a classification task). However, computing the cross entropy loss can be memory-intensive, as vocab sizes in LLMs can be enormous. As noted in Apple's CCE paper, cross-entropy loss can account for 40–90% of total GPU memory usage. By reducing the memory footprint of cross-entropy, we can significantly decrease the hardware requirements for training LLMs!

Quick Recap: Cross Entropy

If you are unfamiliar with the math, check out derivation of cross-entropy loss. But here is the structure in brief:

$\mathcal{L} = - \log \frac{e^{z_{y_i}}}{\sum_{k} e^{z_k}} = - z_{y_i} + \log \sum_{k} e^{z_k}$

Cross-entropy can be broken into two terms:

  1. The negative dot product: $-z_{y_i}$
  2. The log-sum-exp (LSE): $\log \sum_{k} e^{z_k}$

The overall goal of the "cut cross entropy" approach is to avoid naively materializing the logits ($z_k$, the entries of $C^T E$): the linear classifier $C^T$ projects the embeddings $E$ onto an extremely large vocabulary space, which means the full logit matrix is enormous, and we still have to compute further terms from it, like the softmax.
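As a quick sanity check on this decomposition, here is a tiny PyTorch snippet (my own illustration, not from the repo) confirming that $-z_{y_i} + \log \sum_{k} e^{z_k}$ matches the framework's built-in cross-entropy:

python

import torch
import torch.nn.functional as F

torch.manual_seed(0)
z = torch.randn(4, 10)          # logits: 4 examples, a 10-entry "vocabulary"
y = torch.tensor([3, 1, 7, 0])  # target class indices

neg_dot = -z.gather(1, y[:, None]).squeeze(1)   # term 1: -z_{y_i}
lse = torch.logsumexp(z, dim=1)                 # term 2: log sum_k exp(z_k)

print(torch.allclose(neg_dot + lse, F.cross_entropy(z, y, reduction="none")))  # True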

Overview of Cut Cross Entropy

Cut Cross Entropy saves memory by doing two things:

  1. Chunking the logits into smaller, sorted chunks (which already reduces peak memory usage)
  2. "Filtering" (though I think this is better described as "skipping" or "early stopping") chunks whose gradient contributions are effectively zero during the gradient computation step (this is the "cutting", which saves even more memory and computation)

The core innovation is not a new loss function but a restructured computation: the logits are chunked, and unnecessary chunks are selectively skipped.

The “cutting” process can achieve memory savings of over 1000 times compared to traditional cross entropy implementations.
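To make the chunking idea concrete, here is a rough plain-PyTorch sketch of the forward computation (mine, not the repo's; the real thing is a fused Triton kernel). It never materializes the full batch × vocab logit matrix; it streams over vocabulary chunks, keeping only a running LSE and the target logit. Shapes are stated in the docstring.

python

import torch
import torch.nn.functional as F

def chunked_cross_entropy(E, C, targets, block_v=4096):
    """E: (batch, dim) embeddings; C: (vocab, dim) classifier; targets: (batch,) class indices."""
    B, V = E.shape[0], C.shape[0]
    lse = torch.full((B,), float("-inf"))   # running log-sum-exp; exp(-inf) = 0
    z_y = torch.empty(B)                    # will hold z_{y_i} for every row
    for start in range(0, V, block_v):
        logits = E @ C[start:start + block_v].T   # (batch, block_v): only one chunk alive at a time
        lse = torch.logaddexp(lse, torch.logsumexp(logits, dim=1))
        # pick out z_{y_i} for the targets that fall inside this vocabulary chunk
        in_chunk = (targets >= start) & (targets < start + block_v)
        z_y[in_chunk] = logits[in_chunk, targets[in_chunk] - start]
    return (-z_y + lse).mean()

E, C, t = torch.randn(8, 64), torch.randn(1000, 64), torch.randint(0, 1000, (8,))
print(torch.allclose(chunked_cross_entropy(E, C, t, block_v=128),
                     F.cross_entropy(E @ C.T, t), atol=1e-5))  # True

Peak extra memory here is on the order of batch × block_v instead of batch × vocab; the gradient-side "cutting" happens in the backward pass and is covered later.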


Breakdown: Forward Pass

Inputs:

  • Query embeddings $(\text{batch size} \times \text{dim})$
  • Classifier embeddings $(\text{dim} \times \text{vocab})$
  • Bias vector $(\text{vocab} \times 1)$
  • Filtering parameters

Notably, the logits have not been computed yet.
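For context, calling the library looks roughly like this. I am going from memory of the repo's README here, so treat the import path, the argument order, and the dtype/shape/device requirements as assumptions and check them against the repo:

python

# Assumed import path, call signature, and dtype/shape requirements (from memory
# of the repo README); verify against the repo before relying on this.
import torch
from cut_cross_entropy import linear_cross_entropy

embeddings = torch.randn(8, 4096, device="cuda", dtype=torch.bfloat16)        # (batch, dim)
classifier = torch.randn(128_000, 4096, device="cuda", dtype=torch.bfloat16)  # assuming the lm_head weight layout
labels = torch.randint(0, 128_000, (8,), device="cuda")

loss = linear_cross_entropy(embeddings, classifier, labels)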


Key operations:

  1. Compute logit statistics
    • Compute the mean value of each logit across the batch (logit_avg). This is later used for filtering (via the vocabulary sorting).
  2. Main kernels:
    • cce_lse_forward_kernel: a Triton kernel that computes the log-sum-exp over the selected logits, returning both the LSE and logit_avg.
    • indexed_dot_forward_kernel: a Triton kernel that computes the negative dot product with the correct class embedding.
  3. Aggregate loss
    • The two terms are summed to compute the final loss.
    • Supports standard reductions (mean, sum).
  4. Save tensors for backward pass
    • Includes the embeddings, targets, the LSE, and logit statistics needed for gradient computation (notably not the full logits, which the backward pass recomputes in chunks). A structural sketch of this whole flow follows below.
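Here is a stripped-down torch.autograd.Function skeleton of that flow (my own sketch with illustrative names; unlike the real implementation, it materializes the full logit matrix for clarity instead of calling the Triton kernels):

python

import torch

class CCESketch(torch.autograd.Function):
    """Structural sketch only. Shapes: e is (batch, dim), c is (vocab, dim), targets is (batch,).
    Unlike the real code, this materializes e @ c.T; the point is the overall flow."""

    @staticmethod
    def forward(ctx, e, c, targets):
        lse = torch.logsumexp(e @ c.T, dim=1)        # what cce_lse_forward_kernel produces
        neg_dot = -(e * c[targets]).sum(dim=1)       # what indexed_dot_forward_kernel produces
        loss = (neg_dot + lse).mean()                # aggregate the two terms (mean reduction)
        ctx.save_for_backward(e, c, targets, lse)    # no logits saved; they get recomputed later
        return loss

    @staticmethod
    def backward(ctx, grad_out):
        e, c, targets, lse = ctx.saved_tensors
        s = torch.exp(e @ c.T - lse[:, None])              # S = exp(C^T E - LSE), recomputed
        s[torch.arange(e.shape[0]), targets] -= 1.0        # S - Y
        s_hat = s * (grad_out / e.shape[0])                # scale by the upstream gradient
        return s_hat @ c, s_hat.T @ e, None                # dE, dC, no gradient for targets

e = torch.randn(4, 16, requires_grad=True)
c = torch.randn(100, 16, requires_grad=True)
CCESketch.apply(e, c, torch.randint(0, 100, (4,))).backward()   # fills e.grad and c.grad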

Forward Kernel

Of the two kernels, cce_lse_forward_kernel is the more interesting, so let's break down its Triton implementation.


Skipping over the offset and pointer initialisation, we hit the first math operation:

triton

accum = tl.zeros((BLOCK_B, BLOCK_V), dtype=tl.float32)
for d in range(0, tl.cdiv(D, BLOCK_D)):
   e_mask = offs_b[:, None] < BMax
   if not EVEN_D:
      e_mask = e_mask & (offs_d[None, :] < (D - d * BLOCK_D))

   e = tl.load(e_ptrs, mask=e_mask, other=0.0)

   c_mask = offs_v[None, :] < V
   if not EVEN_D:
      c_mask = c_mask & (offs_d[:, None] < (D - d * BLOCK_D))

   c = tl.load(c_ptrs, mask=c_mask, other=0.0)

   accum = tl.dot(e, c, accum, input_precision=DOT_PRECISION)

   e_ptrs += BLOCK_D * stride_ed
   c_ptrs += BLOCK_D * stride_cd

Here we are computing accum: the dot product of the query and class embeddings (represented generally by e and c). This is done in chunks of BLOCK_D. By chunking this computation of $C^T E$ and accumulating the partial products, we avoid loading the entire intermediate sum into memory. That's a footprint saving of $O(D)$.

Our kernel retrieves the value $x_i$, the $x_i$-th column from $C$, and the $i$-th column from $E$, and stores them in on-chip shared memory (SRAM). It then performs a dot product between $C_{x_i}$ and $E_i$ and writes the result into global memory. The kernel uses only on-chip SRAM throughout and does not allocate any additional GPU memory. For efficiency, we perform all operations blockwise to make the best use of the GPU cache structure.
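In plain PyTorch terms, the accumulation loop above is doing something like the following (an illustration of the access pattern over the whole matrices; the real kernel works on one (BLOCK_B, BLOCK_V) tile per program):

python

import torch

B, D, V, BLOCK_D = 8, 256, 1024, 64
e_full = torch.randn(B, D)   # the rows of E handled by this program
c_full = torch.randn(D, V)   # the classifier columns for this vocabulary chunk

accum = torch.zeros(B, V)
for d0 in range(0, D, BLOCK_D):
    e = e_full[:, d0:d0 + BLOCK_D]   # a (BLOCK_B, BLOCK_D) tile of E
    c = c_full[d0:d0 + BLOCK_D, :]   # a (BLOCK_D, BLOCK_V) tile of C
    accum += e @ c                   # the tl.dot(e, c, accum) accumulation

print(torch.allclose(accum, e_full @ c_full, atol=1e-4))  # True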

After that comes the computation of the log-sum-exp (LSE).

triton

this_mx = tl.max(logits, axis=1)
e = tl.exp(logits - this_mx[:, None])
this_lse = this_mx + tl.log(tl.sum(e, axis=1))

# ...

this_locks = Locks + (pid_b // tl.cdiv(B, BLOCK_B * num_locks))
while tl.atomic_cas(this_locks, 0, 1) == 1:
   pass

lse = tl.load(lse_ptrs, mask=o_mask, other=0.0, eviction_policy="evict_last")
lse = tl_logaddexp(lse, this_lse)
tl.store(lse_ptrs, lse, mask=o_mask, eviction_policy="evict_last")

tl.debug_barrier()
tl.atomic_xchg(this_locks, 0)

There are 3 main steps here:

  1. Compute the "partial LSE": the local LSE for each batch, using the (BLOCK_B, BLOCK_V) chunk of logits
  2. Use a spin lock to atomically update the global LSE
    • tl.atomic_cas is used to implement a spin lock.
    • The lock index, defined by (pid_b // tl.cdiv(B, BLOCK_B * num_locks)), causes multiple pid_b blocks to share the same lock. Not shown in this snippet is that pid_v blocks corresponding to the same pid_b (same batch block, but different vocabulary slices) also share that pid_b's lock.
    • Yeah, this might require writing it out or staring at the code to catch.
    • tl;dr: we only allow one block/"vocabulary chunk" to update the global LSE at a time.
  3. Store the global LSE back to global memory
    • lse = tl.load(lse_ptrs...) loads the current global value of the LSE. This is initialized to $-\infty$ (because $e^{-\infty} = 0$).
    • lse = tl_logaddexp(lse, this_lse) is the core aggregation step. By repeatedly computing $\log(\exp(\text{lse}) + \exp(\text{this\_lse}))$, different blocks/vocabulary chunks combine their partial LSEs into a global LSE.
    • tl.store(lse_ptrs, lse, mask=o_mask, eviction_policy="evict_last") stores the updated global LSE back to global memory.
    • tl.atomic_xchg(this_locks, 0) releases the lock, allowing other blocks to update the global LSE.

In short: to keep the updates synchronized, we use a spin lock so that only one vocabulary chunk can update the global LSE at a time. Each block computes a partial LSE, which is then merged into the global one. This aggregation step produces our final result, the $\text{LSE}$ term.
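Why is this merging valid? Because the LSE decomposes over a partition of the vocabulary: if $A$ and $B$ are disjoint chunks, then $\log \sum_{k \in A \cup B} e^{z_k} = \log\left(e^{\text{LSE}_A} + e^{\text{LSE}_B}\right) = \text{logaddexp}(\text{LSE}_A, \text{LSE}_B)$, which tl_logaddexp evaluates in a numerically stable way (e.g. as $\max(a, b) + \log(1 + e^{-|a - b|})$). Since logaddexp is associative and commutative, the blocks can grab the lock and merge in any order, and the $-\infty$ initialization acts as the identity element (an empty chunk contributes $e^{-\infty} = 0$).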

Breakdown: Backward Pass

The backward pass mirrors the forward structure, but with added complexity for gradient scaling and filtering. This backward pass needs to compute at least two gradient terms:

  • the gradient with respect to the query embeddings, $\nabla E^T$
  • the gradient with respect to the class embeddings, $\nabla C^T$
  • if present, the gradient with respect to the bias, $\partial b$. We're gonna ignore this.

Here is how the paper describes the backward pass (don't worry, you can skip this: it's just to help add some color and we'll be breaking things down anyway):

Formally, the gradient is defined as $\nabla E^T = (S \cdot \nabla \text{LSE})\, C \quad \text{and} \quad \nabla C^T = (S \cdot \nabla \text{LSE})^T E$, where $S = \text{softmax}(C^T E)$ and $\cdot$ refers to the row-by-row elementwise multiplication of the softmax $S$ and the gradient $\nabla \text{LSE}$: $\hat{S} = S \cdot \nabla \text{LSE}$.


Computationally, the backward pass is a double matrix multiplication $C^T E$ and $\hat{S} C$ or $\hat{S}^T E$ with intermediate matrices $S$ and $\hat{S}$ that do not fit into GPU memory and undergo a non-linear operation. We take a similar approach to the forward pass, recomputing the matrix $C^T E$ implicitly in the GPU's shared memory. For the backward pass, we do not need to compute the normalization constant of the softmax, since $S = \text{softmax}(C^T E) = \exp(C^T E - \text{LSE})$. This allows us to reuse the global synchronization of the forward pass, and compute $S$ efficiently in parallel.


We implement the second matrix multiplication in the main memory of the GPU, as a canonical blockwise implementation would require storing or synchronizing SS. Algorithm 3 and Fig. 2c summarize the computation and access patterns. A naive implementation of this algorithm requires zero additional memory but is slow due to repeated global memory load and store operations. We use two techniques to improve the memory access pattern: gradient filtering and vocabulary sorting.


Overall, this is what the backward pass needs to do:

  1. Retrieve saved tensors: embeddings, targets, valid entries, logit averages.
  2. Recompute the logits ($C^T E$) in chunks (like in the forward pass).
  3. Compute the softmax probabilities $S = \text{softmax}(C^T E) = \exp(C^T E - \text{LSE})$.
  4. Compute the gradient signal for the log-sum-exp, $\hat{S} = (S - Y) \cdot \nabla \text{LSE}$, using its upstream gradient $\nabla \text{LSE}$ (d_out in the code) and the targets $Y$.

triton

# Compute S
d_accum = tl.exp(accum - lse[:, None])
d_accum = tl.where(offs_v[None, :] < V, d_accum, 0.0)

# Compute S - Y
if HAS_TARGETS:
   if HAS_SHIFT:
      target_offs_b = offs_b + shift
   else:
      target_offs_b = offs_b

   targets = tl.load(Targets + target_offs_b, mask=target_offs_b < BMax, other=V + 1)
   is_target = targets[:, None] == offs_v[None, :]
   d_accum += tl.where(is_target, -1.0, 0.0) # <-- subtracting Y from S
else:
   is_target = None

# gradient filtering! Explained later
should_skip = False
if (FILTER_E_GRAD and COMPUTE_DE) and (FILTER_C_GRAD and COMPUTE_DC):
   if _block_is_filtered(tl.abs(d_accum), filter_eps):
      return
elif (FILTER_E_GRAD and COMPUTE_DE) or (FILTER_C_GRAD and COMPUTE_DC):
   should_skip = _block_is_filtered(tl.abs(d_accum), filter_eps)

if HAS_SOFTCAP:
   d_accum = tl_softcapping_grad(d_accum, accum, softcap)

if ITEM_DO:
   d_out = tl.load(dOut)
else:
   if HAS_SHIFT:
      d_out_offs_b = offs_b + shift
   else:
      d_out_offs_b = offs_b

   d_out = tl.load(dOut + d_out_offs_b, mask=d_out_offs_b < BMax, other=0.0)[:, None]

d_out = grad_scale * d_out

# The final sum we're looking for: S_hat = (S - Y) * upstream_grad
d_accum = d_accum * d_out

  5. Compute $\nabla E^T = \hat{S} C$, the gradient of the query embeddings.
  6. Compute $\nabla C^T = \hat{S}^T E$, the gradient of the class embeddings.

If you recall the chain rule, you can see why $\hat{S}$ shows up in the gradients of both the query and class embeddings: the logits are $C^T E$, and $\hat{S}$ is the gradient of the loss with respect to those logits, so one more matrix multiplication with $C$ or $E$ gives each gradient. It is the key intermediate product in this series of steps.


If this looks like a wall of mathy text, the main point is that we need to go through steps 1-4 to create the intermediate term we'll be using in steps 5-6. Steps 5-6 are matrix multiplications, which are expensive operations that we'd like to avoid if we can, and that brings us to the main optimization in the backward pass: gradient filtering.
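If you want to convince yourself that the closed-form gradients are right, here is a small autograd check (my own sketch, written against the formulas above rather than against the repo's code):

python

import torch
import torch.nn.functional as F

B, D, V = 4, 16, 32
e = torch.randn(B, D, requires_grad=True)
c = torch.randn(V, D, requires_grad=True)
y = torch.randint(0, V, (B,))

# reference gradients via autograd on the naive loss
F.cross_entropy(e @ c.T, y, reduction="sum").backward()

# closed form: S_hat = (S - Y) * upstream gradient (which is 1 for a sum reduction)
with torch.no_grad():
    s_hat = torch.softmax(e @ c.T, dim=1)
    s_hat[torch.arange(B), y] -= 1.0
    print(torch.allclose(e.grad, s_hat @ c, atol=1e-4))    # dE = S_hat @ C
    print(torch.allclose(c.grad, s_hat.T @ e, atol=1e-4))  # dC = S_hat^T @ E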




Before we dive into the gradient filtering, it helps to recall that $\hat{S}$ in the Triton kernel is really $\hat{S}_{\text{block}}$, the gradient signal for a single block/vocabulary chunk. The same applies for $S \to S_{\text{block}}$ and the other terms.


Then let's zoom in on the gradient filtering segment of the code:

triton

should_skip = False
if (FILTER_E_GRAD and COMPUTE_DE) and (FILTER_C_GRAD and COMPUTE_DC):
   if _block_is_filtered(tl.abs(d_accum), filter_eps):
      return
elif (FILTER_E_GRAD and COMPUTE_DE) or (FILTER_C_GRAD and COMPUTE_DC):
   should_skip = _block_is_filtered(tl.abs(d_accum), filter_eps)

# ... 

if COMPUTE_DE:
   if FILTER_E_GRAD:
      should_skip_e = should_skip
   # ...

if COMPUTE_DC:
   if FILTER_C_GRAD:
      should_skip_c = should_skip
   # ...

# elsewhere:

@triton.jit
def _block_is_filtered(check_val: tl.tensor, filter_eps: tl.tensor) -> tl.tensor:
    return tl.reduce(check_val < filter_eps, None, tl_and_reduce_fn)

_block_is_filtered checks if all values within d_accum (which is $\hat{S}_{\text{block}}$) are below our "near-zero" threshold. If so, we terminate the kernel early, or skip over some of the remaining matrix multiplications involved in calculating the gradient outputs $\nabla E^T$ and $\nabla C^T$.


As a whole, the gradient filtering process is as follows:

  1. Vocab sorting: sorting the vocabulary terms by their average logit to increase the number of "zero" gradient blocks we'll encounter in the gradient filtering step. This is a heuristic that reduces the number of partially filled non-zero blocks whose gradients we need to compute; we want all "significant" blocks to be as dense and contiguous as possible. This happens before the inputs are passed to the kernel.
  2. Gradient filtering: whenever $\hat{S}_{\text{block}}$ is entirely near zero, we can skip the expensive matrix multiplications in steps 5 and 6 and zero out the contributions from these blocks (a toy sketch of this follows below).
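Here is a toy, self-contained illustration of both ideas in plain PyTorch (the shapes, the eps value, and the synthetic "sorted" logits are all made up; the real kernel does this per (BLOCK_B, BLOCK_V) tile inside the Triton program):

python

import torch

torch.manual_seed(0)
B, D, V, BLOCK_V, eps = 8, 32, 4096, 256, 2e-5
c = torch.randn(V, D)

# Mimic a vocabulary already sorted by logit_avg: only an early slice of the vocab
# ever scores highly, so the tail blocks end up uniformly near zero.
logits = torch.randn(B, V)
logits[:, :512] += 10.0
y = torch.randint(0, 512, (B,))             # targets drawn from the "hot" region for simplicity

s_hat = torch.softmax(logits, dim=1)
s_hat[torch.arange(B), y] -= 1.0            # S - Y (upstream gradient taken as 1)

d_e, skipped = torch.zeros(B, D), 0
for v0 in range(0, V, BLOCK_V):
    block = s_hat[:, v0:v0 + BLOCK_V]
    if block.abs().max() < eps:             # what _block_is_filtered decides for a whole tile
        skipped += 1                        # skip the matmul for this vocabulary chunk
        continue
    d_e += block @ c[v0:v0 + BLOCK_V]

exact = s_hat @ c
print(f"skipped {skipped}/{V // BLOCK_V} blocks, "
      f"max abs error {(d_e - exact).abs().max().item():.1e}")

Because the high-scoring vocabulary entries are contiguous, most of the tail blocks fall entirely below eps and their matrix multiplications are skipped, at the cost of an error on the order of eps.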

Takeaways: Usage + Design

Cut cross entropy is likely to become the go-to implementation of cross-entropy. The memory savings it offers are too good to pass up.

As for its design, the techniques here, chunking, avoiding global memory loads, and even zeroing out or "early stopping" expensive computations, can be applied to other kernel optimization problems.

I'll definitely be writing more about kernel optimization and other similar breakdowns by the way, so if you found this interesting, you might want to follow me on any of my socials.




btw, I happen to also be looking for a job as a machine learning engineer now, so if you happen to know of any opportunities, please let me know! You can find my portfolio/resume here.
