hushmoon avatar
A Glimpse of Torch & Einops
Overview

A Glimpse of Torch & Einops

May 18, 2026
3 min read

Introduction to Torch & Einops

Torch

在具体记录一些 Torch 的 common-use functions and classes 之前,我有必要确立 PyTorch 的心智模型 — 当我们在谈论 Torch 的时候,我们在谈论什么。

Einops

Classical Functions Reinvention

LogSumExp

为什么需要计算 LogSumExp ?

实际应用

Softmax 是神经网络中最常见的操作:

Softmax(xi)=exijexjSoftmax(x_i) = \frac{e^{x_i}}{\sum_j e^{x_j}}

两边都取 log:

logSoftmax(xi)=xilogjexj\log Softmax(x_i) = x_i - \log\sum_j e^{x_j}

LogSumExp 就是 Softmax 的分母部分,数值稳定的 Softmax 就依赖它。

Note (数值稳定是什么意思?)

假设我们现在要计算 logiexi\log\sum_i e^{x_i},直接计算的问题:

如果 xi=1000x_i = 1000e1000e^{1000}数值溢出(Inf)。

如果 xi=1000x_i = -1000e1000e^{-1000}下溢为 0,导致 log(0)=\log(0) = -\infty

数值稳定的数学技巧 —— 减去最大值

logiexi=c+logiexic\log\sum_i e^{x_i} = c + \log\sum_i e^{x_i - c}

其中 c=maxi(xi)c = \max_i(x_i),这是因为:

c+logiexic=logec+logiexic=log(eciexic)=logiexic + \log\sum_i e^{x_i - c} = \log e^c + \log\sum_i e^{x_i - c} = \log\left(e^c \cdot \sum_i e^{x_i-c}\right) = \log\sum_i e^{x_i}

为什么这样就稳定了?

  • xic0x_i - c \leq 0,所以 exic(0,1]e^{x_i-c} \in (0, 1]不会溢出
  • 至少有一个 xic=0x_i - c = 0,所以求和结果 1\geq 1log\log 不会是 -\infty

Code

def batched_logsumexp(matrix: Tensor) -> Tensor:
c = matrix.max(dim=-1, keepdim=True).values
# Tensor.max 返回 torch.return_types.max,有两个字段 values & indices
shifted = matrix - c
# matrix shape (batch, n)
# c shape (batch, 1)
# broadcasting c --> (batch, n)
return c.squeeze(-1) + shifted.exp().sum(dim=-1).log()

Softmax & LogSoftmax

我现在来推导一下数值稳定版的 Softmax 和 LogSoftmax。

LogSoftmax 推导和代码实现

数学推导

首先,Softmax 的原表达式为:

Softmax(xi)=exijexjSoftmax(x_i) = \frac{e^{x_i}}{\sum_j e^{x_j}}

数值稳定版本,我需要算出 LogSumExp,其中 c=maxj(xj)c = \max_j(x_j)

logjexj=c+logjexjc\log\sum_j e^{x_j} = c + \log\sum_j e^{x_j - c}

给 Softmax 原表达式两边加上 log:

logSoftmax(xi)=xilogjexj\log Softmax(x_i) = x_i - \log \sum_j e^{x_j}

最后将右侧代换,就得到了数值稳定版的 LogSoftmax:

logSoftmax(xi)=xiLogSumExp\log Softmax(x_i) = x_i - LogSumExp

代码实现

def batched_logsoftmax(matrix: Tensor) -> Tensor:
"""Compute log(softmax(row)) for each row of the matrix.
matrix: shape (batch, n)
Return: shape (batch, n)
"""
# 数值稳定:先减最大值
c = matrix.max(dim=-1, keepdim=True).values # (batch, 1)
# 对应 LogSumExp 表达式
log_sum_exp = c + (matrix - c).exp().sum(dim=-1, keepdim=True).log()
# keepdim=True 保持 shape 为 (batch, 1)
# 最终结果 LogSoftmax
# broadcasting: log_sum_exp 从 (batch, 1) 广播为 (batch, n)
return matrix - log_sum_exp

Softmax 推导和代码实现

数学推导

从上面已经得出了数值稳定的 LogSoftmax 表示:

logSoftmax(xi)=xiLogSumExp=xiclogjexjc\log Softmax(x_i) = x_i - LogSumExp = x_i - c - \log\sum_j e^{x_j - c}

此时,加上 exp 可以得到 Softmax 的另一种表达方式:

Softmax(xi)=e(xic)logjexjc=exicjexjc\begin{aligned} Softmax(x_i) &= e^{(x_i - c) - \log\sum_j e^{x_j - c}} \\ &= \frac{e^{x_i - c}}{\sum_j e^{x_j - c}} \end{aligned}

代码实现

def batched_softmax(matrix: Tensor) -> Tensor:
c = matrix.max(dim=-1, keepdim=True).values
shifted_exp = (matrix - c).exp()
return shifted_exp / shifted_exp.sum(dim=-1, keepdim=True)

Cross Entropy Loss

数学推导

交叉熵可以分别从统计角度和信息论的角度来解析和推导,具体的内容可以 refer to ML 的 Foundation 部分(还在施工中)。

Cross Entropy Loss 本质上是在惩罚模型没有把足够大的概率分配给真实分布(类别)。

假设真实的概率分布为 pp(通常是 one-hot,只有一个类别概率是 1,其余的概率都是 0);

假设模型学习到的概率分布为 qq

Information Theory 给的直觉 —— 概率越小的事件,发生时信息量越大,必然事件没有信息量。以此确定的自信息(Self-Information)和交叉熵分别如下所示:

I(x)=logp(x)CrossEntropy=plogq\begin{gather} I(x) = -\log p(x) \\ CrossEntropy = -\sum p\log q \end{gather}

代码实现

def batched_cross_entropy_loss(logits: Tensor, true_labels: Tensor) -> Tensor:
"""Compute the cross entropy loss for each example in the batch.
logits: shape (batch, classes). logits[i][j] is the unnormalized prediction for example i and class j.
true_labels: shape (batch, ). true_labels[i] is an integer index representing the true class for example i.
Return: shape (batch, ). out[i] is the loss for example i.
"""
log_prob = batched_logsoftmax(logits) # return (batch, classes)
return -log_prob[t.arange(len(true_labels)), true_labels]