This is the source code that's referenced throughout this article. It might be helpful to have this open in another tab: GitHub - Apple ML Cut Cross Entropy
Cross Entropy is a ubiquitous loss function in deep learning, including the training of LLMs (next token prediction is a classification task). However, computing the cross entropy loss can be memory-intensive, as vocab sizes in LLMs can be enormous. As noted in Apple's CCE paper, cross-entropy loss can account for 40–90% of total GPU memory usage. By reducing the memory footprint of cross-entropy, we can significantly decrease the hardware requirements for training LLMs!
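For a rough sense of scale (an illustrative back-of-the-envelope figure, not a number from the paper): with 8,192 tokens in a batch and a 256k-entry vocabulary, the float32 logit matrix alone is $8192 \times 256{,}000 \times 4$ bytes, roughly 8.4 GB, and that is before the softmax's intermediate values are accounted for.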
If you are unfamiliar with the math, check out the derivation of cross-entropy loss. But here is the structure in brief:
Cross-entropy can be broken into two terms: the negative logit of the correct class, $-E_i^\top C_{y_i}$, and the log-sum-exp over all the logits, $\log \sum_{v} \exp(E_i^\top C_v)$.

The general goal of the whole "cut cross entropy" approach is to avoid computing the logits $E C^\top$ naively: $C$, the linear classifier, projects the embeddings onto an extremely large vocabulary space, which means working with a very large logit matrix (and we still have to compute further terms from those logits, like the softmax).
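To make this concrete, here is a tiny PyTorch check (a throwaway sketch with made-up sizes, not code from the CCE repo) that cross-entropy really is just these two terms: the negative logit of the correct class plus the log-sum-exp over all logits.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, D, V = 4, 16, 32                   # made-up sizes
E = torch.randn(B, D)                 # token embeddings
C = torch.randn(V, D)                 # classifier ("unembedding") matrix
targets = torch.randint(0, V, (B,))

logits = E @ C.T                      # (B, V): the big matrix CCE avoids materializing
reference = F.cross_entropy(logits, targets)

# Term 1: negative logit of the correct class. Term 2: log-sum-exp over the vocabulary.
neg_correct = -logits[torch.arange(B), targets]
lse = torch.logsumexp(logits, dim=-1)
two_terms = (neg_correct + lse).mean()

print(torch.allclose(reference, two_terms))  # True
```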
Cut Cross Entropy saves memory by doing two things:
The core innovation is not a new loss function, but a restructured computation: (1) chunking the logits, and (2) selectively skipping over unnecessary chunks.
The “cutting” process can achieve memory savings of over 1000 times compared to traditional cross entropy implementations.
Inputs: the embeddings $E$ (one $D$-dimensional vector per token), the classifier matrix $C$ (one $D$-dimensional vector per vocabulary entry), and the target token indices.

Notably, the logits have not been computed yet.
Key operations:

- Computation of the average logit per vocabulary entry (`logit_avg`). This is later used for filtering.
- `cce_lse_forward_kernel`: a Triton kernel that computes the log-sum-exp over selected logits, returning both the LSE (and `logit_avg`).
- `indexed_dot_forward_kernel`: a Triton kernel that computes the negative dot product with the correct class embedding.

Between the kernels, the `cce_lse_forward_kernel` is the more interesting one, so let's break down its Triton implementation.
Skipping over the offset and pointer initialisation, we hit the first math operation:
# Accumulate the (BLOCK_B, BLOCK_V) block of logits in chunks along D
accum = tl.zeros((BLOCK_B, BLOCK_V), dtype=tl.float32)
for d in range(0, tl.cdiv(D, BLOCK_D)):
    # Load a (BLOCK_B, BLOCK_D) tile of the embeddings
    e_mask = offs_b[:, None] < BMax
    if not EVEN_D:
        e_mask = e_mask & (offs_d[None, :] < (D - d * BLOCK_D))
    e = tl.load(e_ptrs, mask=e_mask, other=0.0)

    # Load a (BLOCK_D, BLOCK_V) tile of the classifier
    c_mask = offs_v[None, :] < V
    if not EVEN_D:
        c_mask = c_mask & (offs_d[:, None] < (D - d * BLOCK_D))
    c = tl.load(c_ptrs, mask=c_mask, other=0.0)

    # Multiply-accumulate: accum += e @ c
    accum = tl.dot(e, c, accum, input_precision=DOT_PRECISION)

    # Advance the pointers to the next chunk along D
    e_ptrs += BLOCK_D * stride_ed
    c_ptrs += BLOCK_D * stride_cd
Here we are computing `accum`: the dot product of the query and class embeddings (represented generally by `e` and `c`). This is done in chunks of `BLOCK_D`.

By chunking this computation of $E C^\top$ and accumulating as we go, we never have to hold the full intermediate logit matrix in memory. That's a footprint saving on the order of $B \times |V|$ values.
Our kernel retrieves the value $x_i$, the $x_i$-th column from $C$ ($C_{x_i}$), and the $i$-th column from $E$ ($E_i$), and stores them in on-chip shared memory (SRAM). It then performs a dot product between $C_{x_i}$ and $E_i$ and writes the result into global memory. The kernel uses only on-chip SRAM throughout and does not allocate any GPU memory. For efficiency, we perform all operations blockwise to make the best use of GPU cache structure.
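Stripped of the Triton details, the accumulation loop above is doing something like the following (a plain PyTorch sketch with made-up sizes; the real kernel only ever holds a single `(BLOCK_B, BLOCK_V)` tile of `accum` in SRAM, whereas this toy version builds the whole matrix just to verify the chunked accumulation):

```python
import torch

torch.manual_seed(0)
B, D, V = 8, 64, 128
BLOCK_D = 16                         # made-up chunk size along the hidden dimension
E = torch.randn(B, D)
C = torch.randn(V, D)

# Mirror the kernel's `for d in range(...)` loop: accumulate the logits
# one BLOCK_D-wide slice of the hidden dimension at a time.
accum = torch.zeros(B, V)
for d0 in range(0, D, BLOCK_D):
    e = E[:, d0:d0 + BLOCK_D]        # (B, BLOCK_D) tile of the embeddings
    c = C[:, d0:d0 + BLOCK_D]        # (V, BLOCK_D) tile of the classifier
    accum += e @ c.T                 # running sum, analogous to tl.dot(e, c, accum)

print(torch.allclose(accum, E @ C.T, atol=1e-5))  # True
```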
After that comes the computation of the log-sum-exp (LSE).
# Local LSE for this (BLOCK_B, BLOCK_V) chunk of logits
this_mx = tl.max(logits, axis=1)
e = tl.exp(logits - this_mx[:, None])
this_lse = this_mx + tl.log(tl.sum(e, axis=1))

# ...

# Acquire the spin lock for this row-block
this_locks = Locks + (pid_b // tl.cdiv(B, BLOCK_B * num_locks))
while tl.atomic_cas(this_locks, 0, 1) == 1:
    pass

# Merge the partial LSE into the global LSE, then release the lock
lse = tl.load(lse_ptrs, mask=o_mask, other=0.0, eviction_policy="evict_last")
lse = tl_logaddexp(lse, this_lse)
tl.store(lse_ptrs, lse, mask=o_mask, eviction_policy="evict_last")

tl.debug_barrier()
tl.atomic_xchg(this_locks, 0)
There are 3 main steps here:

1. Compute the local LSE for this `(BLOCK_B, BLOCK_V)` chunk of logits.
2. Acquire the lock: `tl.atomic_cas` is used to implement a spin lock. The lock index, `(pid_b // tl.cdiv(B, BLOCK_B * num_locks))`, causes multiple `pid_b` blocks to share the same lock. Not shown in all this code is that `pid_v` blocks that correspond to the same `pid_b` (same batch block, but different vocabulary slices) will share the same `pid_b` lock.
3. Merge into the global LSE:
   - `lse = tl.load(lse_ptrs...)` loads the current global value of `LSE`. This is initialized as `-inf` (because $e^{-\infty} = 0$, the identity for log-add-exp).
   - `lse = tl_logaddexp(lse, this_lse)` is the core aggregation step. By repeatedly computing $\operatorname{logaddexp}(a, b) = \log(e^a + e^b)$, different blocks/vocabulary chunks combine their partial LSEs into a global LSE.
   - `tl.store(lse_ptrs, lse, mask=o_mask, eviction_policy="evict_last")` stores the updated global LSE back to global memory.
   - `tl.atomic_xchg(this_locks, 0)` releases the lock, allowing other threads to update the global LSE.

In short: to ensure correct synchronisation, we use a spin lock so that only one vocabulary chunk can update the global LSE at a time. Each block computes a partial LSE, which is then merged into the global one. This aggregation yields our final result, the $\log \sum_{v} \exp(E_i^\top C_v)$ term.
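The reason this merging works is that log-sum-exp can be combined associatively: the LSE over the full vocabulary equals `logaddexp` applied across the per-chunk LSEs. A quick PyTorch sanity check (sizes are arbitrary):

```python
import torch

torch.manual_seed(0)
B, V, BLOCK_V = 4, 1000, 128
logits = torch.randn(B, V)

# The global LSE starts at -inf: exp(-inf) = 0, so it's the identity for logaddexp.
lse = torch.full((B,), float("-inf"))
for v0 in range(0, V, BLOCK_V):
    this_lse = torch.logsumexp(logits[:, v0:v0 + BLOCK_V], dim=-1)  # partial LSE for one vocab chunk
    lse = torch.logaddexp(lse, this_lse)                            # merge, like tl_logaddexp per block

print(torch.allclose(lse, torch.logsumexp(logits, dim=-1)))  # True
```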
The backward pass mirrors the forward structure, but with added complexity for gradient scaling and filtering. This backward pass needs to compute at least two gradient terms: the gradient with respect to the embeddings, $\nabla E$, and the gradient with respect to the classifier, $\nabla C$.
Here is how the paper describes the backward pass (don't worry, you can skip this: it's just to help add some color and we'll be breaking things down anyway):
Formally, the gradient is defined as $\nabla_E = C\,(S \odot \nabla \mathrm{LSE})$ and $\nabla_C = E\,(S \odot \nabla \mathrm{LSE})^\top$, where $S = \operatorname{softmax}(C^\top E)$ and $\odot$ refers to the row-by-row elementwise multiplication of the softmax $S$ and the gradient $\nabla \mathrm{LSE}$: $(S \odot \nabla \mathrm{LSE})_{vi} = S_{vi}\,\nabla \mathrm{LSE}_i$.
Computationally, the backward pass is a double matrix multiplication $C\,(S \odot \nabla \mathrm{LSE})$ and/or $E\,(S \odot \nabla \mathrm{LSE})^\top$ with intermediate matrices $C^\top E$ and $S$ that do not fit into GPU memory and undergo a non-linear operation. We take a similar approach to the forward pass, recomputing the matrix $C^\top E$ implicitly in the GPU's shared memory. For the backward pass, we do not need to compute the normalization constant of the softmax, since $S_{vi} = \exp(C_v^\top E_i - \mathrm{LSE}_i)$. This allows us to reuse the global synchronization of the forward pass, and compute $S$ efficiently in parallel.
We implement the second matrix multiplication in the main memory of the GPU, as a canonical blockwise implementation would require storing or synchronizing the intermediate $S \odot \nabla \mathrm{LSE}$. Algorithm 3 and Fig. 2c summarize the computation and access patterns. A naive implementation of this algorithm requires zero additional memory but is slow due to repeated global memory load and store operations. We use two techniques to improve the memory access pattern: gradient filtering and vocabulary sorting.
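That trick of skipping the softmax's normalization constant works because the LSE saved from the forward pass already is that constant in log space. A small PyTorch illustration of the identity the backward pass relies on (not CCE code):

```python
import torch

torch.manual_seed(0)
B, V = 4, 32
logits = torch.randn(B, V)

lse = torch.logsumexp(logits, dim=-1)   # saved during the forward pass
S = torch.exp(logits - lse[:, None])    # softmax recovered with no extra normalization pass

print(torch.allclose(S, torch.softmax(logits, dim=-1)))  # True
```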
Overall, this is what the backward pass needs to do:

1. Load the upstream gradient (`d_out` in the code) and the targets $Y$.
2. Recompute the block of logits $E C^\top$ in SRAM, exactly as in the forward pass.
3. Compute the softmax $S = \exp(E C^\top - \mathrm{LSE})$ using the LSE saved from the forward pass.
4. Form the intermediate term $\hat{S} = (S - Y)\,\nabla\mathrm{LSE}$ (in the code: `d_accum * d_out`).
5. Multiply $\hat{S}$ by $C$ to produce $\nabla E$.
6. Multiply $\hat{S}$ by $E$ to produce $\nabla C$.

Here is the segment of the backward kernel that builds this intermediate term:
# Compute S
d_accum = tl.exp(accum - lse[:, None])
d_accum = tl.where(offs_v[None, :] < V, d_accum, 0.0)

# Compute S - Y
if HAS_TARGETS:
    if HAS_SHIFT:
        target_offs_b = offs_b + shift
    else:
        target_offs_b = offs_b
    targets = tl.load(Targets + target_offs_b, mask=target_offs_b < BMax, other=V + 1)
    is_target = targets[:, None] == offs_v[None, :]
    d_accum += tl.where(is_target, -1.0, 0.0)  # <-- subtracting Y from S
else:
    is_target = None

# gradient filtering! Explained later
should_skip = False
if (FILTER_E_GRAD and COMPUTE_DE) and (FILTER_C_GRAD and COMPUTE_DC):
    if _block_is_filtered(tl.abs(d_accum), filter_eps):
        return
elif (FILTER_E_GRAD and COMPUTE_DE) or (FILTER_C_GRAD and COMPUTE_DC):
    should_skip = _block_is_filtered(tl.abs(d_accum), filter_eps)

if HAS_SOFTCAP:
    d_accum = tl_softcapping_grad(d_accum, accum, softcap)

if ITEM_DO:
    d_out = tl.load(dOut)
else:
    if HAS_SHIFT:
        d_out_offs_b = offs_b + shift
    else:
        d_out_offs_b = offs_b
    d_out = tl.load(dOut + d_out_offs_b, mask=d_out_offs_b < BMax, other=0.0)[:, None]

d_out = grad_scale * d_out

# The final sum we're looking for: S_hat = (S - Y) * upstream_grad
d_accum = d_accum * d_out
If you recall the chain rule, you can see why $\hat{S} = (S - Y)\,\nabla\mathrm{LSE}$ makes up the gradient of both the query and class embeddings. It is the key intermediate product in this series of steps.
If this looks like a wall of mathy text, the main point is that we go through steps 1-4 to create the intermediate term $\hat{S}$ that we'll use in steps 5-6. Steps 5-6 are matrix multiplications, which are expensive operations we'd like to avoid when we can, and that brings us to the main optimization in the backward pass: gradient filtering.
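If you'd like to convince yourself of that chain-rule claim, here is a small autograd check (an illustrative sketch with arbitrary sizes, not CCE code): build $\hat{S}$ explicitly, multiply it by $C$ and by $E$, and compare against PyTorch's own gradients.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, D, V = 4, 16, 32
E = torch.randn(B, D, requires_grad=True)
C = torch.randn(V, D, requires_grad=True)
targets = torch.randint(0, V, (B,))

logits = E @ C.T
loss = F.cross_entropy(logits, targets)   # mean reduction -> upstream grad of 1/B per token
loss.backward()

with torch.no_grad():
    S = torch.softmax(logits, dim=-1)
    Y = F.one_hot(targets, V).float()
    d_out = 1.0 / B                        # upstream gradient from the mean reduction
    S_hat = (S - Y) * d_out                # the key intermediate: `d_accum * d_out` in the kernel

    dE = S_hat @ C                         # gradient w.r.t. the embeddings
    dC = S_hat.T @ E                       # gradient w.r.t. the classifier

print(torch.allclose(dE, E.grad, atol=1e-5))  # True
print(torch.allclose(dC, C.grad, atol=1e-5))  # True
```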
Before we dive into the gradient filtering, it helps to recall that $\hat{S}$ in the Triton kernel (`d_accum`) is really $\hat{S}_{\text{block}}$, the gradient signal for a single batch-block/vocabulary chunk. The same applies to $S$, $Y$, and the other terms.
Then let's zoom in on the gradient filtering segment of the code:
should_skip = False
if (FILTER_E_GRAD and COMPUTE_DE) and (FILTER_C_GRAD and COMPUTE_DC):
    if _block_is_filtered(tl.abs(d_accum), filter_eps):
        return
elif (FILTER_E_GRAD and COMPUTE_DE) or (FILTER_C_GRAD and COMPUTE_DC):
    should_skip = _block_is_filtered(tl.abs(d_accum), filter_eps)

# ...

if COMPUTE_DE:
    if FILTER_E_GRAD:
        should_skip_e = should_skip
    # ...

if COMPUTE_DC:
    if FILTER_C_GRAD:
        should_skip_c = should_skip
    # ...

# elsewhere:
@triton.jit
def _block_is_filtered(check_val: tl.tensor, filter_eps: tl.tensor) -> tl.tensor:
    return tl.reduce(check_val < filter_eps, None, tl_and_reduce_fn)
`_block_is_filtered` checks if all values within `d_accum` (which is $\hat{S}_{\text{block}}$) are below our "near-zero" threshold. If so, we terminate the kernel early, or skip over some of the remaining matrix multiplications involved in calculating the gradient outputs $\nabla E$ and $\nabla C$.
As a whole, the gradient filtering process is as follows:

1. For each block, compute $\hat{S}_{\text{block}}$ (`d_accum` in the code above).
2. Check whether every entry of $|\hat{S}_{\text{block}}|$ is below the "near-zero" threshold `filter_eps`.
3. If so, that block's contribution to $\nabla E$ and/or $\nabla C$ is numerically negligible, so the corresponding matrix multiplications are skipped (or the kernel returns early when both gradients can be filtered).
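To get a feel for why this pays off, here is a rough illustration (made-up sizes, block shapes, and threshold; plain PyTorch, not CCE code). Even with random embeddings, the overwhelming majority of individual entries of $S - Y$ are numerically negligible; the block-level skip rate on unsorted random data is more modest, which is where the vocabulary sorting mentioned in the paper quote above comes in: it groups the non-negligible columns together so that far more blocks can be skipped.

```python
import torch

torch.manual_seed(0)
B, D, V = 256, 128, 8192
BLOCK_B, BLOCK_V = 32, 256            # made-up tile sizes
filter_eps = 2.0 ** -12               # illustrative "numerically negligible" threshold

E, C = torch.randn(B, D), torch.randn(V, D)
targets = torch.randint(0, V, (B,))

S = torch.softmax(E @ C.T, dim=-1)
S[torch.arange(B), targets] -= 1.0    # S - Y, the quantity the kernel filters on

# Most individual entries are numerically negligible...
print(f"entries below eps: {(S.abs() < filter_eps).float().mean().item():.4f}")

# ...but a tile is only skipped if *every* entry in it is below eps.
skippable, tiles = 0, 0
for b0 in range(0, B, BLOCK_B):
    for v0 in range(0, V, BLOCK_V):
        tile = S[b0:b0 + BLOCK_B, v0:v0 + BLOCK_V]
        tiles += 1
        skippable += int((tile.abs() < filter_eps).all().item())  # cf. _block_is_filtered
print(f"skippable tiles: {skippable} / {tiles}")
```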
Cut Cross Entropy is likely to become the go-to implementation of cross-entropy going forward. The memory savings it offers are too good to pass up.
As for its design, the techniques used here (chunking, avoiding global memory traffic, and even zeroing out or "early-stopping" expensive computations whose contribution is negligible) can be applied to many other kernel optimization problems.
I'll definitely be writing more about kernel optimization and other similar breakdowns by the way, so if you found this interesting, you might want to follow me on any of my socials.
btw, I happen to also be looking for a job as a machine learning engineer now, so if you happen to know of any opportunities, please let me know! You can find my portfolio/resume here.