Introduction: nn.Linear in PyTorch, Clearly Explained
nn.Linear is one of the most fundamental building blocks in PyTorch.
Whether you are building a simple Multi-Layer Perceptron (MLP), a Transformer, or a large-scale deep learning system, linear layers are everywhere — performing fast affine transformations on your data.
With PyTorch 2.x, linear layers now benefit from updated kernels, compiler optimizations (torch.compile), and better GPU acceleration. This updated guide covers everything you need to know, from the basics to modern deep-learning usage.
Understanding nn.Linear in PyTorch
What is nn.Linear?
nn.Linear applies an affine transformation to input data:
y = xAᵀ + b

where:

- x: input tensor
- A: weight matrix (out_features × in_features)
- b: bias vector (out_features)
- y: output tensor
The module takes two required arguments:
nn.Linear(in_features, out_features, bias=True)

Example:
import torch
from torch import nn
linear_layer = nn.Linear(in_features=3, out_features=1)

This creates:
- weight shape: 1 × 3
- bias shape: 1
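To confirm that the module really computes y = xAᵀ + b, here is a minimal sanity check against the layer created above (the exact values depend on random initialization):

x = torch.tensor([1., 2., 3.])
manual = x @ linear_layer.weight.T + linear_layer.bias   # y = xAᵀ + b computed by hand
print(torch.allclose(linear_layer(x), manual))           # True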
How Does nn.Linear Work?
A forward pass performs:
- matrix multiplication with the weights
- addition of the bias (if enabled)
output = linear_layer(torch.tensor([1., 2., 3.]))
print(output)

📌 Modern Usage: Batched Inputs
PyTorch applies nn.Linear to the last dimension.
So it works seamlessly with:
- 2D inputs: [batch, in_features]
- 3D inputs: [batch, seq_len, in_features]

The 3D case is common in Transformers:

[batch, seq_len, in_features] → [batch, seq_len, out_features]

Example:
x = torch.randn(32, 128, 512) # batch=32, seq_len=128, hidden=512
linear = nn.Linear(512, 1024)
y = linear(x)   # output: [32, 128, 1024]

No reshape is needed, which is important for attention projections (Q, K, V).
Initializing Weights and Biases
PyTorch initializes weights using Kaiming Uniform by default, but you can customize this using torch.nn.init.
import torch.nn.init as init

init.xavier_uniform_(linear_layer.weight)
init.zeros_(linear_layer.bias)
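To re-initialize every linear layer in a larger network, a common pattern (shown here as a sketch, with a hypothetical model variable standing in for any nn.Module) is to pass a function to Module.apply:

def init_weights(module):
    # Apply Xavier init to every nn.Linear; leave other module types untouched
    if isinstance(module, nn.Linear):
        init.xavier_uniform_(module.weight)
        if module.bias is not None:
            init.zeros_(module.bias)

model.apply(init_weights)   # 'model' is any nn.Module you have built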
When to use bias=False

You often disable the bias when:

- a normalization layer follows the linear layer (e.g., LayerNorm in Transformers)
- the layer is a Q/K/V projection in an attention block
- the layer sits inside a residual block
Example:
nn.Linear(embed_dim, embed_dim, bias=False)

Comparing nn.Linear and nn.Conv2d
| Layer | Best for | Operation | Parameters |
|---|---|---|---|
| nn.Linear | vectors, flattened input, MLPs, Transformers | Affine transform | in_features, out_features |
| nn.Conv2d | image data, local spatial patterns | Sliding-window convolution | channels, kernel size, stride, padding |
Key differences:
- Linear layers use dense connections (all-to-all).
- nn.Conv2d uses spatially local kernels, reducing parameters and introducing an inductive bias (see the quick parameter comparison below).
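As a rough illustration (sizes chosen only for this example), compare the parameter counts of a dense layer and a convolution on a 3-channel 32×32 input:

flat = nn.Linear(3 * 32 * 32, 64)        # dense: every input connects to every output
conv = nn.Conv2d(3, 64, kernel_size=3)   # local: one 3x3 kernel per (in-channel, out-channel) pair

print(sum(p.numel() for p in flat.parameters()))   # 196,672  (3*32*32*64 + 64)
print(sum(p.numel() for p in conv.parameters()))   # 1,792    (64*3*3*3 + 64)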
Applications of nn.Linear in Deep Learning
1. Multi-Layer Perceptron (MLP)
layer = nn.Linear(256, 128)

2. Transformer Models
A standard Transformer uses linear layers for:
- Q, K, V projections
- MLP feed-forward networks
- output heads
Example:
q = nn.Linear(d_model, d_model, bias=False)
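The feed-forward (MLP) block is likewise just two linear layers with an expansion in between; a minimal sketch, assuming d_model is defined and using the common 4× expansion factor:

ffn = nn.Sequential(
    nn.Linear(d_model, 4 * d_model),   # expand
    nn.GELU(),
    nn.Linear(4 * d_model, d_model),   # project back
)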
3. Classification Heads

classifier = nn.Linear(hidden_dim, num_classes)

4. Autoencoders
Both encoder and decoder use linear layers to compress and reconstruct.
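A minimal sketch (illustrative sizes, hypothetical encoder/decoder names) of a linear autoencoder:

encoder = nn.Sequential(nn.Linear(784, 64), nn.ReLU())   # compress 784 → 64
decoder = nn.Linear(64, 784)                             # reconstruct 64 → 784

x = torch.randn(16, 784)
reconstruction = decoder(encoder(x))                     # [16, 784]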
Using nn.Linear in a PyTorch Model
import torch.nn.functional as F
from torch import nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 5)
        self.fc2 = nn.Linear(5, 2)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        return self.fc2(x)

net = Net()
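Calling the model works like any other nn.Module; for example, with a batch of 4 ten-dimensional inputs:

batch = torch.randn(4, 10)
logits = net(batch)
print(logits.shape)   # torch.Size([4, 2])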
Modern PyTorch 2.x Tips (Important!)

🚀 1. torch.compile optimizes linear layers

model = torch.compile(net)

Compiling fuses operations for faster GPU performance.
⚡ 2. AMP for faster training on GPUs
with torch.amp.autocast(device_type="cuda"):   # torch.cuda.amp.autocast is deprecated in PyTorch 2.x
    output = model(x)
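In a real training loop, autocast is usually paired with a gradient scaler. A minimal sketch, assuming model, optimizer, loss_fn, x, and target already exist and a recent PyTorch 2.x release:

scaler = torch.amp.GradScaler("cuda")

optimizer.zero_grad()
with torch.amp.autocast(device_type="cuda"):
    loss = loss_fn(model(x), target)     # forward pass runs in mixed precision
scaler.scale(loss).backward()            # scale the loss so small fp16 gradients do not underflow
scaler.step(optimizer)
scaler.update()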
📉 3. Quantized Linear for inference

import torch.ao.quantization as quant

Reduces model size and speeds up inference.
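For example, dynamic quantization converts the Linear layers of a model (such as the net defined above) to int8 for CPU inference; a minimal sketch:

from torch.ao.quantization import quantize_dynamic

quantized_net = quantize_dynamic(net, {nn.Linear}, dtype=torch.qint8)
print(quantized_net)   # fc1/fc2 are replaced by dynamically quantized Linear modules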
🔧 4. Use F.linear for custom fused operations
import torch.nn.functional as F
y = F.linear(x, weight, bias)

Useful when combining multiple operations manually.
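One common case is weight tying, where the functional form lets you reuse an existing weight matrix instead of creating a second nn.Linear; a sketch with hypothetical names:

# Reuse an embedding matrix as the output projection (classic weight tying)
embedding = nn.Embedding(1000, 64)
hidden = torch.randn(2, 64)
logits = F.linear(hidden, embedding.weight)   # [2, 1000], no separate output weight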
Common Errors and How to Fix Them
1. Shape Mismatch
RuntimeError: mat1 and mat2 shapes cannot be multiplied

Fix: ensure the last dimension of the input matches in_features.
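A minimal reproduction of the error and its fix (illustrative sizes):

layer = nn.Linear(10, 5)
bad = torch.randn(4, 8)
# layer(bad)                 # RuntimeError: mat1 and mat2 shapes cannot be multiplied
good = torch.randn(4, 10)    # last dimension == in_features
y = layer(good)              # works: shape [4, 5]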
2. Passing 3D inputs without letting nn.Linear handle shape
Avoid reshaping unnecessarily — Linear already supports 3D.
3. Forgetting .float() when mixing integer tensors
x = x.float()
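For example (a minimal sketch), integer inputs fail because the layer's weights are float32:

layer = nn.Linear(3, 1)
x_int = torch.tensor([1, 2, 3])   # int64
# layer(x_int)                    # raises a RuntimeError about mismatched dtypes
y = layer(x_int.float())          # works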
Conclusion

nn.Linear is a deceptively simple but foundational module in PyTorch.
From MLPs to Transformers to modern generative models, linear layers make up the majority of computation in today’s deep learning systems.
Understanding:
- how they work
- how they handle different input shapes
- how to optimize them with PyTorch 2.x
will make you significantly more effective when building neural models.
FAQs
What is the purpose of a bias vector in nn.Linear?
The bias allows the model to shift output values independently of the input. This improves flexibility, especially when inputs are not zero-centered.
How do you initialize weights and biases?
PyTorch initializes them automatically, but you can override using torch.nn.init (e.g., Xavier initialization).
Difference between nn.Linear and nn.Conv2d?
nn.Linear applies a dense affine transform; nn.Conv2d applies spatial convolution. Linear layers are used in MLPs and Transformers, while Conv2D layers are used in CNNs.
