Skip to content

引言:深入理解 PyTorch 中的 nn.Linear

Updated on

nn.Linear 是 PyTorch 中最基础的模块之一。
无论你在构建简单的多层感知机(MLP)、Transformer,还是大规模深度学习系统,线性层几乎无处不在——负责对数据执行快速的仿射变换。

在 PyTorch 2.x 中,线性层受益于更新后的内核实现、编译器优化(torch.compile)以及更好的 GPU 加速。本文是一次全面、现代的更新,从基础概念到当下主流深度学习场景的用法都会覆盖。

想要从 Python Pandas DataFrame 快速创建 零代码数据可视化?

PyGWalker 是一个用于可视化探索性数据分析的 Python 库。
它可以把你的 pandas 或 polars DataFrame 变成 Jupyter Notebook 中类似 Tableau 的交互界面。

👉 PyGWalker on GitHub (opens in a new tab)

PyGWalker for Data visualization (opens in a new tab)


理解 PyTorch 中的 nn.Linear

什么是 nn.Linear?

nn.Linear 对输入数据执行一次仿射变换:

[ y = xA^T + b ]

  • x: 输入张量
  • A: 权重矩阵(out_features × in_features
  • b: 偏置向量(out_features
  • y: 输出张量

该模块有两个必需参数:

nn.Linear(in_features, out_features, bias=True)

示例:

import torch
from torch import nn
 
linear_layer = nn.Linear(in_features=3, out_features=1)

这会创建:

  • 权重形状:1 × 3
  • 偏置形状:1

nn.Linear 是如何工作的?

一次前向传播会执行:

  • 使用权重进行矩阵乘法
  • 加上偏置(如果启用)
output = linear_layer(torch.tensor([1., 2., 3.]))
print(output)

📌 现代用法:批量输入(batched inputs)

PyTorch 会把 nn.Linear 应用在张量的最后一个维度上。
因此它可以无缝支持下面这些形状:

2D 输入(batch, features)

[batch, in_features]

3D 输入(batch, seq_len, features)

在 Transformer 中非常常见:

[batch, seq_len, in_features] → [batch, seq_len, out_features]

示例:

x = torch.randn(32, 128, 512)   # batch=32, seq_len=128, hidden=512
linear = nn.Linear(512, 1024)
y = linear(x)                   # output: [32, 128, 1024]

不需要额外 reshape —— 这对于注意力中的 Q、K、V 投影非常重要。


权重与偏置的初始化

PyTorch 默认使用 Kaiming Uniform 初始化权重,但你可以通过 torch.nn.init 自定义初始化方式。

import torch.nn.init as init
 
init.xavier_uniform_(linear_layer.weight)
init.zeros_(linear_layer.bias)

什么时候使用 bias=False

通常会在以下情况下关闭偏置:

  • 在线性层后面接有归一化层
    (例如 Transformer 中的 LayerNorm
  • 注意力模块中的 Q / K / V 投影层
  • 残差块(residual block)内部使用的线性层

示例:

nn.Linear(embed_dim, embed_dim, bias=False)

nn.Linear 和 nn.Conv2d 的对比

Layer适用场景运算类型关键参数
nn.Linear向量、展平后的输入、MLP、Transformer仿射变换(affine transform)in_features, out_features
nn.Conv2d图像数据、本地空间模式(local spatial pattern)滑动窗口卷积通道数、kernel size、stride、padding 等

主要区别:

  • Linear 层是全连接(dense connections),输入特征与输出特征之间是全连接。
  • Conv2D 使用局部空间卷积核,参数更少,同时引入对空间结构的归纳偏置(inductive bias)。

nn.Linear 在深度学习中的常见应用

1. 多层感知机(MLP)

layer = nn.Linear(256, 128)

2. Transformer 模型

标准的 Transformer 中,线性层用于:

  • Q、K、V 投影
  • MLP 前馈网络(feed-forward network)
  • 输出头(output head)

示例:

q = nn.Linear(d_model, d_model, bias=False)

3. 分类头(classification head)

classifier = nn.Linear(hidden_dim, num_classes)

4. 自编码器(Autoencoder)

编码器和解码器都会使用线性层来进行压缩与重构。


在 PyTorch 模型中使用 nn.Linear

import torch.nn.functional as F
from torch import nn
 
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 5)
        self.fc2 = nn.Linear(5, 2)
 
    def forward(self, x):
        x = F.relu(self.fc1(x))
        return self.fc2(x)
 
net = Net()

现代 PyTorch 2.x 使用建议(很重要!)

🚀 1. 使用 torch.compile 优化线性层

model = torch.compile(net)

编译后可以融合多个运算,在 GPU 上获得更高性能。


⚡ 2. 使用 AMP 提升 GPU 训练速度

with torch.cuda.amp.autocast():
    output = model(x)

📉 3. 推理阶段使用量化线性层(Quantized Linear)

import torch.ao.quantization as quant

可以减小模型大小并加速推理。


🔧 4. 自定义融合操作时使用 F.linear

import torch.nn.functional as F
y = F.linear(x, weight, bias)

当你需要手动组合多个操作并进行优化时非常有用。


常见错误及其解决方法

1. 形状不匹配(Shape Mismatch)

RuntimeError: mat1 and mat2 shapes cannot be multiplied

解决:确保输入张量的最后一维in_features 一致。

2. 3D 输入被错误地提前 reshape

避免不必要的 reshape —— nn.Linear 已经原生支持 3D 输入。

3. 整型张量未转换为浮点型

x = x.float()

确保与线性层参数类型一致。


总结

nn.Linear 看起来简单,却是 PyTorch 中最核心的模块之一。
从 MLP 到 Transformer 再到现代生成式模型,线性层往往占据了计算量的大头。

理解:

  • 它的计算方式
  • 它如何处理不同形状的输入
  • 如何在 PyTorch 2.x 中对其进行优化

会显著提升你在构建神经网络模型时的效率与能力。


常见问答(FAQs)

nn.Linear 中偏置向量的作用是什么?

偏置可以让模型在不依赖输入的情况下整体平移输出值,这在输入数据不是以零为中心时尤其重要,可以提高模型表达能力。

如何初始化权重和偏置?

PyTorch 会自动初始化,但你可以使用 torch.nn.init 进行自定义(例如 Xavier 初始化)。

nn.Linear 与 nn.Conv2d 有什么区别?

nn.Linear 对输入执行稠密的仿射变换;nn.Conv2d 执行空间卷积。Linear 层常用于 MLP 和 Transformer,而 Conv2D 常用于 CNN 的图像相关任务。