February 17, 2025 · 5 min read
Source code: GitHub - Apple ML Cut Cross Entropy
Cross-entropy is a commonly used loss function in deep learning, particularly in training Large Language Models (LLMs). However, its computation can be memory-intensive, as vocabulary sizes in LLMs can be enormous. As noted in Apple's CCE paper, cross-entropy loss can account for 40–90% of total GPU memory usage. By reducing the memory footprint of cross-entropy, we can significantly decrease the hardware requirements for training LLMs!
A quick recap of cross-entropy: given an embedding $e$, a classifier matrix $C$, and a target token $y$, the logits are $z = C^\top e$ and the loss is

$$\mathcal{L} = -\log \mathrm{softmax}(z)_y$$

(If this looks unfamiliar, check out this derivation of cross-entropy loss.)

The loss function can be broken into two terms:

$$\mathcal{L} = \underbrace{-e^\top c_y}_{\text{indexed dot product}} + \underbrace{\log \sum_{j} e^{e^\top c_j}}_{\text{log-sum-exp (LSE)}}$$
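To make the decomposition concrete, here's a quick sanity check in plain PyTorch (shapes and names are my own):

```python
import torch
import torch.nn.functional as F

N, D, V = 8, 64, 1000
e = torch.randn(N, D)          # embeddings
c = torch.randn(V, D)          # classifier weights
y = torch.randint(0, V, (N,))  # target tokens

z = e @ c.T                                      # logits: the big (N, V) matrix
loss = -z[torch.arange(N), y] + z.logsumexp(-1)  # -z_y + LSE(z), per token

# Matches the standard implementation:
assert torch.allclose(loss.mean(), F.cross_entropy(z, y), atol=1e-5)
```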
The general goal of the whole "cut cross entropy" approach is to avoid computing the logits $z = C^\top e$ naively: $C$, the linear classifier, projects the embeddings onto an extremely large vocabulary space, so the logit matrix is huge, and we still have to compute further terms from it (like the softmax).
The cut cross entropy method partitions the logits into smaller, ranked, and filtered chunks. This “cutting” process can achieve memory savings of over 1000 times compared to traditional cross entropy implementations.
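To get a feel for the scale (illustrative numbers of my own, not the paper's): with a batch of $8 \times 4096 = 32{,}768$ tokens and a 262k-token vocabulary, the fp32 logit matrix alone takes $32{,}768 \times 262{,}144 \times 4\,\text{B} \approx 34\,\text{GB}$, while the per-token LSE values that CCE actually keeps around take $32{,}768 \times 4\,\text{B} \approx 128\,\text{KB}$.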
The inputs to the CCE loss layer include:

- `e`: the token embeddings produced by the model,
- `c`: the classifier (unembedding) weight matrix,
- `bias`: an optional classifier bias,
- `targets`: the target token indices (with a `valids` mask for ignored positions).
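For context, usage looks roughly like the following (a sketch based on my reading of the repo; check the README for the exact, current signature of `linear_cross_entropy`):

```python
import torch
from cut_cross_entropy import linear_cross_entropy

# Toy shapes: 32k tokens, 2048-dim embeddings, 128k vocabulary.
e = torch.randn(32768, 2048, device="cuda", dtype=torch.bfloat16)
c = torch.randn(128256, 2048, device="cuda", dtype=torch.bfloat16)
targets = torch.randint(0, 128256, (32768,), device="cuda")

# The full (32768, 128256) logit matrix is never materialized.
loss = linear_cross_entropy(e, c, targets)
```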
At a high level, the forward pass does the following:

1. Compute `logit_avg` (optional): `logit_avg` is used to detect outlier logits.
2. Forward kernels:
   - `cce_lse_forward_kernel`: a Triton kernel that computes the log-sum-exp (LSE), $\log \sum_j e^{e^\top c_j}$, and (optionally) `logit_avg`.
   - `indexed_dot_forward_kernel`: a Triton kernel that calculates the negative dot product, $-e^\top c_y$.
3. Reduction (optional): applying the requested reduction (e.g., mean) over the per-token losses.
4. Saving tensors for the backward pass: `ctx.save_for_backward(e, c, bias, lse, params.targets, params.valids, logit_avg)`, i.e., the inputs, targets, and some stats.

Between the two kernels, I think `cce_lse_forward_kernel` is definitely more interesting, so let's break down its Triton implementation further.
```python
# Inside cce_lse_forward_kernel: accumulate a (BLOCK_B, BLOCK_V) tile of the
# logits E @ C^T by marching along the embedding dimension D in BLOCK_D steps.
accum = tl.zeros((BLOCK_B, BLOCK_V), dtype=tl.float32)
for d in range(0, tl.cdiv(D, BLOCK_D)):
    e = tl.load(e_ptrs, mask=e_mask, other=0.0)  # (BLOCK_B, BLOCK_D) tile of embeddings
    c = tl.load(c_ptrs, mask=c_mask, other=0.0)  # (BLOCK_D, BLOCK_V) tile of the classifier
    accum = tl.dot(e, c, accum, input_precision=DOT_PRECISION)  # accum += e @ c
    e_ptrs += BLOCK_D * stride_ed  # move to next tile
    c_ptrs += BLOCK_D * stride_cd
```

... (Work in progress)
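Once the full-depth dot products for a vocabulary block are accumulated, the kernel folds them into a running log-sum-exp. Here's a rough sketch of that streaming update in plain PyTorch (my simplification, not the actual Triton code):

```python
import torch

def streaming_lse(blocks: list[torch.Tensor]) -> torch.Tensor:
    """Numerically stable LSE over vocabulary blocks, one block at a time.
    Each z in blocks is an (N, BLOCK_V) tile of logits."""
    m = torch.full_like(blocks[0][:, 0], float("-inf"))  # running max, shape (N,)
    acc = torch.zeros_like(m)                            # running sum of exp(z - m)
    for z in blocks:
        m_new = torch.maximum(m, z.max(dim=-1).values)
        acc = acc * torch.exp(m - m_new) + torch.exp(z - m_new[:, None]).sum(-1)
        m = m_new
    return m + torch.log(acc)  # equals torch.cat(blocks, -1).logsumexp(-1)
```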
The backward pass is where `logit_avg` comes in, for potential filtering (the "cut"). `cce_backward_kernel` computes the gradients and is where the magic happens:

$$\frac{\partial \mathcal{L}}{\partial z_j} = \mathrm{softmax}(z)_j - \mathbb{1}[j = y] = e^{z_j - \mathrm{LSE}} - \mathbb{1}[j = y]$$

where:

- $z = C^\top e$ are the logits, recomputed block by block inside the kernel, and
- $\mathrm{LSE}$ is the log-sum-exp saved from the forward pass.

The gradients with respect to $e$ and $C$ then follow from the chain rule, and any block whose softmax values are negligibly small can be skipped entirely.
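To make that concrete, here's a plain-PyTorch sketch of a blockwise backward pass (the skipping rule below is my own stand-in; the real kernel makes this decision per tile using `logit_avg` and more careful thresholds):

```python
import math
import torch

def cce_backward_sketch(e, c, targets, lse, logit_avg, eps=2**-12):
    """Blockwise CE backward: recompute logits per vocabulary block, form the
    softmax via the saved LSE, and "cut" blocks whose gradient would vanish.
    e: (N, D), c: (V, D), targets: (N,), lse: (N,), logit_avg: (V,).
    Returns gradients of the summed loss (scale by 1/N for the mean)."""
    de, dc = torch.zeros_like(e), torch.zeros_like(c)
    BLOCK_V = 4096
    for start in range(0, c.shape[0], BLOCK_V):
        cb = c[start : start + BLOCK_V]
        # Hypothetical filter: exp(z - lse) < eps for the whole block => skip.
        if logit_avg[start : start + BLOCK_V].max() < lse.min() + math.log(eps):
            continue
        z = e @ cb.T                     # recompute this block's logits
        s = torch.exp(z - lse[:, None])  # softmax rows, via the saved LSE
        hit = (targets >= start) & (targets < start + BLOCK_V)
        s[hit, targets[hit] - start] -= 1.0     # subtract the one-hot target
        de += s @ cb                            # dL/de contribution
        dc[start : start + BLOCK_V] += s.T @ e  # dL/dc contribution
    return de, dc
```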