
Introduction: nn.Linear in PyTorch, Clearly Explained

nn.Linear is one of the most fundamental building blocks in PyTorch.
Whether you are building a simple Multi-Layer Perceptron (MLP), a Transformer, or a large-scale deep learning system, linear layers are everywhere — performing fast affine transformations on your data.

With PyTorch 2.x, linear layers now benefit from updated kernels, compiler optimizations (torch.compile), and better GPU acceleration. This updated guide covers everything you need to know, from the basics to modern deep-learning usage.

Want to quickly create data visualizations from a Python pandas DataFrame with no code?

PyGWalker is a Python library for exploratory data analysis with visualization.
It turns your pandas or polars DataFrame into a Tableau-like UI inside Jupyter Notebook.

👉 PyGWalker on GitHub



Understanding nn.Linear in PyTorch

What is nn.Linear?

nn.Linear applies an affine transformation to input data:

y = xAᵀ + b

  • x: input tensor
  • A: weight matrix (out_features × in_features)
  • b: bias vector (out_features)
  • y: output tensor

The module takes two required arguments (bias is optional and defaults to True):

nn.Linear(in_features, out_features, bias=True)

Example:

import torch
from torch import nn
 
linear_layer = nn.Linear(in_features=3, out_features=1)

This creates:

  • weight shape: 1 × 3
  • bias shape: 1
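You can verify these shapes directly on the layer's parameters:

print(linear_layer.weight.shape)  # torch.Size([1, 3])
print(linear_layer.bias.shape)    # torch.Size([1])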

How Does nn.Linear Work?

A forward pass performs:

  • matrix multiplication with the weights
  • addition of the bias (if enabled)

output = linear_layer(torch.tensor([1., 2., 3.]))
print(output)
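To see that the module really is just y = xAᵀ + b, you can reproduce the result by hand (a quick sanity check, not something you would do in practice):

import torch
from torch import nn

linear_layer = nn.Linear(in_features=3, out_features=1)
x = torch.tensor([1., 2., 3.])

out = linear_layer(x)                                    # forward pass
manual = x @ linear_layer.weight.T + linear_layer.bias   # y = xA^T + b, written out

print(torch.allclose(out, manual))  # True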

📌 Modern Usage: Batched Inputs

PyTorch applies nn.Linear to the last dimension. So it works seamlessly with:

2D inputs (batch, features)

[batch, in_features]

3D inputs (batch, seq_len, features)

Common in Transformers:

[batch, seq_len, in_features] → [batch, seq_len, out_features]

Example:

x = torch.randn(32, 128, 512)   # batch=32, seq_len=128, hidden=512
linear = nn.Linear(512, 1024)
y = linear(x)                   # output: [32, 128, 1024]

No reshape needed — this is important for attention projections (Q, K, V).


Initializing Weights and Biases

PyTorch initializes weights using Kaiming Uniform by default, but you can customize this using torch.nn.init.

import torch.nn.init as init
 
init.xavier_uniform_(linear_layer.weight)
init.zeros_(linear_layer.bias)
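To apply a custom scheme to every linear layer in a larger model, a common pattern is nn.Module.apply; here is a minimal sketch, where model stands for any module you have built:

def init_linear(module):
    # Xavier-initialize every nn.Linear; leave other layers untouched
    if isinstance(module, nn.Linear):
        init.xavier_uniform_(module.weight)
        if module.bias is not None:
            init.zeros_(module.bias)

model.apply(init_linear)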

When to use bias=False

You often disable bias when:

  • a normalization layer follows the linear layer (e.g., LayerNorm in Transformers)
  • the layer is a Q/K/V projection in an attention block
  • the layer sits inside a residual block

Example:

nn.Linear(embed_dim, embed_dim, bias=False)

Comparing nn.Linear and nn.Conv2d

| Layer | Best for | Operation | Parameters |
| --- | --- | --- | --- |
| nn.Linear | vectors, flattened input, MLPs, Transformers | Affine transform | in_features, out_features |
| nn.Conv2d | image data, local spatial patterns | Sliding-window convolution | channels, kernel size, stride, padding |

Key differences:

  • Linear layers use dense connections (all-to-all).
  • Conv2D uses spatially local kernels, reducing parameters and introducing inductive bias.
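To make the parameter difference concrete, here is a quick comparison (the shapes are arbitrary, chosen only for illustration):

from torch import nn

fc = nn.Linear(32 * 32 * 3, 64)                                   # dense layer over a flattened 32×32 RGB image
conv = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3)   # 3×3 conv with 64 filters

print(sum(p.numel() for p in fc.parameters()))     # 196672
print(sum(p.numel() for p in conv.parameters()))   # 1792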

Applications of nn.Linear in Deep Learning

1. Multi-Layer Perceptron (MLP)

layer = nn.Linear(256, 128)
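A small two-layer MLP is usually just linear layers stacked with a nonlinearity in between (the sizes here are arbitrary):

mlp = nn.Sequential(
    nn.Linear(256, 128),
    nn.ReLU(),
    nn.Linear(128, 10),
)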

2. Transformer Models

A standard Transformer uses linear layers for:

  • Q, K, V projections
  • MLP feed-forward networks
  • output heads

Example:

q = nn.Linear(d_model, d_model, bias=False)
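A minimal sketch of how the three projections are typically wired up (d_model = 512 is an arbitrary choice here):

import torch
from torch import nn

d_model = 512
q_proj = nn.Linear(d_model, d_model, bias=False)
k_proj = nn.Linear(d_model, d_model, bias=False)
v_proj = nn.Linear(d_model, d_model, bias=False)

x = torch.randn(32, 128, d_model)            # [batch, seq_len, d_model]
q, k, v = q_proj(x), k_proj(x), v_proj(x)    # each [32, 128, 512]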

3. Classification Heads

classifier = nn.Linear(hidden_dim, num_classes)

4. Autoencoders

Both encoder and decoder use linear layers to compress and reconstruct.
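A minimal linear autoencoder could look like this (the 784 → 32 bottleneck is just an illustrative choice, e.g. for flattened 28×28 images):

class AutoEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(784, 32), nn.ReLU())  # compress
        self.decoder = nn.Linear(32, 784)                            # reconstruct

    def forward(self, x):
        return self.decoder(self.encoder(x))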


Using nn.Linear in a PyTorch Model

import torch.nn.functional as F
from torch import nn
 
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 5)
        self.fc2 = nn.Linear(5, 2)
 
    def forward(self, x):
        x = F.relu(self.fc1(x))
        return self.fc2(x)
 
net = Net()
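Passing a batch of 4 ten-dimensional inputs through the network produces 2 outputs per example:

x = torch.randn(4, 10)
logits = net(x)
print(logits.shape)  # torch.Size([4, 2])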

Modern PyTorch 2.x Tips (Important!)

🚀 1. torch.compile optimizes linear layers

model = torch.compile(net)

Compiling fuses operations for faster GPU performance.


⚡ 2. AMP for faster training on GPUs

with torch.amp.autocast("cuda"):   # torch.cuda.amp.autocast() is deprecated in newer releases
    output = model(x)
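In a full training step, autocast is usually paired with a gradient scaler so fp16 gradients don't underflow; a minimal sketch (loader, optimizer, and loss_fn are assumed to exist):

scaler = torch.amp.GradScaler("cuda")     # torch.cuda.amp.GradScaler() on older 2.x releases

for x, y in loader:
    optimizer.zero_grad()
    with torch.amp.autocast("cuda"):
        loss = loss_fn(model(x), y)       # forward pass runs in mixed precision
    scaler.scale(loss).backward()         # scale the loss before backward
    scaler.step(optimizer)
    scaler.update()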

📉 3. Quantized Linear for inference

import torch.ao.quantization as quant

Quantizing linear layers reduces model size and speeds up inference.
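For example, dynamic quantization converts the weights of every nn.Linear to int8 for inference; a minimal sketch using the Net model from above:

import torch
from torch import nn
import torch.ao.quantization as quant

quantized_net = quant.quantize_dynamic(
    net,               # the float model defined earlier
    {nn.Linear},       # module types to quantize
    dtype=torch.qint8,
)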


🔧 4. Use F.linear for custom fused operations

import torch.nn.functional as F
y = F.linear(x, weight, bias)

Useful when combining multiple operations manually.
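One common case is reusing an existing weight matrix without creating a second module, for example tying an output projection to an embedding table (a sketch, not a prescribed API):

import torch
import torch.nn.functional as F
from torch import nn

embedding = nn.Embedding(10000, 512)         # vocab of 10,000 tokens, 512-dim vectors
hidden = torch.randn(32, 512)

# Reuse the embedding matrix as the output projection (weight tying)
logits = F.linear(hidden, embedding.weight)  # [32, 10000]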


Common Errors and How to Fix Them

1. Shape Mismatch

RuntimeError: mat1 and mat2 shapes cannot be multiplied

Fix: ensure input last dimension matches in_features.
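For example:

layer = nn.Linear(512, 128)

x_bad = torch.randn(32, 256)   # last dim 256 != in_features 512; layer(x_bad) raises the error above
x_ok = torch.randn(32, 512)    # last dim matches in_features
y = layer(x_ok)                # works: [32, 128]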

2. Reshaping 3D inputs unnecessarily

Avoid flattening and reshaping before the layer; nn.Linear already operates on the last dimension of 3D inputs.

3. Forgetting .float() when passing integer tensors

nn.Linear expects floating-point inputs, so convert integer tensors first:

x = x.float()

Conclusion

nn.Linear is a deceptively simple but foundational module in PyTorch. From MLPs to Transformers to modern generative models, linear layers make up the majority of computation in today’s deep learning systems.

Understanding:

  • how they work
  • how they handle different input shapes
  • how to optimize them with PyTorch 2.x

will make you significantly more effective when building neural models.


FAQs

What is the purpose of a bias vector in nn.Linear?

The bias allows the model to shift output values independently of the input. This improves flexibility, especially when inputs are not zero-centered.

How do you initialize weights and biases?

PyTorch initializes them automatically, but you can override using torch.nn.init (e.g., Xavier initialization).

Difference between nn.Linear and nn.Conv2d?

nn.Linear applies a dense affine transform; nn.Conv2d applies spatial convolution. Linear layers are used in MLPs and Transformers, while Conv2D layers are used in CNNs.