
Overview: Apple's smol Cut Cross-Entropy

deep learning
math

February 17, 2025 · 5 min read

Source code: GitHub - Apple ML Cut Cross Entropy

Context

Cross Entropy is a commonly used loss function in deep learning, particularly in training Large Language Models (LLMs). However, its computation can be memory-intensive, as vocab sizes in LLMs can be enormous. As noted in Apple's CCE paper, cross-entropy loss can account for 40–90% of total GPU memory usage. By reducing the memory footprint of cross-entropy, we can significantly decrease the hardware requirements for training LLMs!

A quick recap of cross-entropy:

$$\mathcal{L} = - \log \frac{e^{z_{y_i}}}{\sum_{k} e^{z_k}} = - z_{y_i} + \log \sum_{k} e^{z_k}$$

(If this looks unfamiliar, check out this derivation of cross-entropy loss.)

The loss function can be broken into two terms:

  1. The negative dot product: $-z_{y_i}$
  2. The log-sum-exp (LSE): $\log \sum_{k} e^{z_k}$

The general goal of the whole "cut cross-entropy" approach is to avoid computing the logits $z = C^T E$ naively: the linear classifier $C$ projects the embeddings $E$ onto an extremely large vocabulary space, so the full logit matrix is enormous, and we still have to compute further terms from it (like the softmax).
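
To get a feel for the scale, here is a rough back-of-the-envelope sketch (the shapes are illustrative, not taken from the paper):

import torch
import torch.nn.functional as F

# Illustrative sizes: 8k tokens in a batch, 4096-dim embeddings, ~128k vocab.
B, D, V = 8192, 4096, 128_256

E = torch.randn(B, D, dtype=torch.bfloat16)   # token embeddings (batch size x dim)
C = torch.randn(D, V, dtype=torch.bfloat16)   # classifier (dim x vocab)
targets = torch.randint(V, (B,))

# Naive approach: materialize the full (B x V) logit matrix.
logits = E @ C                                   # 8192 * 128256 * 2 bytes ≈ 2.1 GB
loss = F.cross_entropy(logits.float(), targets)  # the float32 upcast costs another ~4.2 GB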

The cut cross-entropy method partitions the logits into smaller, ranked, and filtered chunks. This "cutting" process can reduce the memory footprint of the loss computation by more than 1000× compared to traditional cross-entropy implementations.

Forward Pass Breakdown

The inputs to the CCE loss layer include:

  • Embeddings $(\text{batch size} \times \text{dim})$
  • Classifier embeddings $(\text{dim} \times \text{vocab})$
  • Bias $(\text{vocab} \times 1)$
  • Other CCE parameters
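
For orientation, here is a rough usage sketch with inputs of these shapes (the entry point and argument order are my assumptions based on the repo's README, and I pass the classifier in the usual lm_head.weight layout of vocab × dim; check the source for the exact signature):

import torch
from cut_cross_entropy import linear_cross_entropy  # assumed public entry point

B, D, V = 8192, 4096, 128_256
embeddings = torch.randn(B, D, dtype=torch.bfloat16, device="cuda")
classifier = torch.randn(V, D, dtype=torch.bfloat16, device="cuda")  # like lm_head.weight
targets = torch.randint(V, (B,), device="cuda")

# The full (B x V) logit matrix is never materialized.
loss = linear_cross_entropy(embeddings, classifier, targets)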

Key Steps

  1. Compute logit_avg (optional)

    • If filtering logits later, logit_avg is used to detect outlier logits.
  2. Forward kernels:

    • cce_lse_forward_kernel, a Triton kernel that computes the log-sum-exp (LSE):
      $\log \sum_{k} e^{z_k}$. It also outputs logit_avg.
    • indexed_dot_forward_kernel, a Triton kernel that computes the negative dot product:
      $\text{neg\_dot} = - e \cdot c_{y_i}$, where $e$ is the embedding and $c_{y_i}$ is the classifier embedding of the correct class; i.e., the negated logit of the correct class.
    • Summing these two terms yields the cross-entropy loss (negative log-likelihood); see the sketch after this list. Using the dot product makes this the more general form of cross-entropy loss rather than just negative log-likelihood, since $c(y)$ can be a target probability distribution over multiple classes.
  3. Reduction (optional):

    • Loss is aggregated using mean or sum.
  4. Saving tensors for the backward pass:

    • ctx.save_for_backward(e, c, bias, lse, params.targets, params.valids, logit_avg), i.e., the inputs, targets, and some stats.
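
Putting the two kernel outputs together, the forward pass is conceptually just the following dense PyTorch reference (my sketch of the math, not the actual chunked Triton implementation; the bias and other CCE options like soft-capping and filtering are omitted):

import torch

def cce_forward_reference(e, c, targets, reduction="mean"):
    # e: (B, D) embeddings, c: (D, V) classifier, targets: (B,) class indices
    logits = e @ c                                    # the matrix CCE never materializes in full
    lse = torch.logsumexp(logits.float(), dim=-1)     # cce_lse_forward_kernel
    neg_dot = -(e * c[:, targets].T).sum(-1).float()  # indexed_dot_forward_kernel
    loss = neg_dot + lse                              # -z_{y_i} + log sum_k exp(z_k)
    return loss.mean() if reduction == "mean" else loss.sum()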

The kernel magic

Between the two kernels, I think cce_lse_forward_kernel is the more interesting one, so let's break down its Triton implementation further.

# Accumulate a (BLOCK_B, BLOCK_V) tile of logits in fp32.
accum = tl.zeros((BLOCK_B, BLOCK_V), dtype=tl.float32)
# March along the embedding dimension D in chunks of BLOCK_D.
for d in range(0, tl.cdiv(D, BLOCK_D)):
    e = tl.load(e_ptrs, mask=e_mask, other=0.0)  # (BLOCK_B, BLOCK_D) tile of embeddings
    c = tl.load(c_ptrs, mask=c_mask, other=0.0)  # (BLOCK_D, BLOCK_V) tile of the classifier
    accum = tl.dot(e, c, accum, input_precision=DOT_PRECISION)  # accum += e @ c
    e_ptrs += BLOCK_D * stride_ed  # Move both pointers to the next tile along D
    c_ptrs += BLOCK_D * stride_cd
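
The loop above only builds one tile of logits; the rest of the kernel (not shown here) still has to fold each tile into a running log-sum-exp across vocabulary blocks. Conceptually that is the standard streaming logsumexp trick; here is a minimal NumPy-style sketch of the idea (my reconstruction, not the actual kernel code):

import numpy as np

def streaming_lse(logit_blocks):
    """Fold successive blocks of logits into a running log-sum-exp,
    so all logits for a row never need to be held at once."""
    running_max, running_sum = -np.inf, 0.0
    for z in logit_blocks:  # z: one block of logits for a single row
        m = max(running_max, z.max())
        # Rescale the previous partial sum to the new max, then add this block.
        running_sum = running_sum * np.exp(running_max - m) + np.exp(z - m).sum()
        running_max = m
    return running_max + np.log(running_sum)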

... (Work in progress)

Backward Pass Breakdown

  1. Retrieve the tensors that were saved for the backward step.
  2. Sort vocab terms based on logit_avg for potential filtering (the "cut").
  3. Obtain a scale factor for gradients based on the reduction method that was used to compute the final loss.
  4. cce_backward_kernel computes the gradients and is where the magic happens:
  • $\frac{\partial \mathcal{L}}{\partial e}$ (Gradient w.r.t. query embeddings)
  • $\frac{\partial \mathcal{L}}{\partial c}$ (Gradient w.r.t. candidate embeddings)
  • $\frac{\partial \mathcal{L}}{\partial b}$ (Gradient w.r.t. bias, if present)

Softmax gradient computation

$$\frac{\partial \mathcal{L}}{\partial z_i} = p(y_i) - \mathbb{1}_{i=y}$$

where: $p(y_i) = \frac{e^{z_i}}{\sum_j e^{z_j}}$
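
Chaining this softmax gradient through the logits $z = C^T E$ gives the three gradients above. A dense PyTorch reference of what the kernel computes block by block (again my sketch, not the repo's code):

import torch

def cce_backward_reference(e, c, bias, targets, grad_scale=1.0):
    # e: (B, D), c: (D, V), bias: (V,), targets: (B,)
    logits = e @ c + bias                        # dense logits, for reference only
    p = torch.softmax(logits.float(), dim=-1)    # p(y_i)
    p[torch.arange(e.shape[0]), targets] -= 1.0  # softmax gradient: p - 1_{i=y}
    p = p.to(e.dtype) * grad_scale               # scale from the loss reduction (e.g. 1/B for mean)
    d_e = p @ c.T                                # dL/de: (B, D)
    d_c = e.T @ p                                # dL/dc: (D, V)
    d_bias = p.sum(dim=0)                        # dL/db: (V,)
    return d_e, d_c, d_bias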

Key operations in the backward kernel

  • Filtering: gradient contributions whose softmax probability falls below $\text{filter\_eps}$ are skipped entirely (this is the actual "cut"); those values would underflow in low precision anyway, so dropping them is essentially free.
  • Soft-capping, to limit large logits (as used in Gemma 2): $z_i \to \text{softcap} \cdot \tanh\!\left(\frac{z_i}{\text{softcap}}\right)$
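
In the dense reference above, these two tweaks would look roughly like this (a sketch of the idea only; the threshold value here is illustrative, not the library default):

import torch

def apply_softcap(logits, softcap):
    # Smoothly squash logits into (-softcap, softcap) before the softmax.
    return softcap * torch.tanh(logits / softcap)

def filtered_softmax_grad(p, targets, filter_eps=1e-4):
    # Softmax gradient with the "cut": tiny probabilities contribute nothing.
    p = torch.where(p < filter_eps, torch.zeros_like(p), p)
    p[torch.arange(p.shape[0]), targets] -= 1.0
    return p

Because the vocabulary was sorted by logit_avg beforehand (step 2 above), the low-probability entries tend to land in the same blocks, which is what lets the backward kernel skip whole blocks of the gradient matrix multiplication rather than individual elements.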
