nn.Linear in PyTorch: Complete Guide to Linear Layers
You are building a neural network in PyTorch and every tutorial mentions nn.Linear, but the documentation is dense and the shape errors are cryptic. You pass a tensor with the wrong dimensions and get RuntimeError: mat1 and mat2 shapes cannot be multiplied. You are not sure when to set bias=False, how weight initialization affects training, or why Transformer models use linear layers everywhere instead of convolutions.
These confusions slow down debugging and model development. A mismatched in_features can silently produce garbage outputs before you notice. Choosing the wrong initialization scheme can cause gradients to vanish or explode. And without understanding how nn.Linear handles batched and 3D inputs, you end up writing unnecessary reshape operations that clutter your code.
This guide explains nn.Linear from the ground up: what it computes, how shapes work, how to initialize weights properly, when to disable bias, how it fits into modern architectures, and how to optimize it with PyTorch 2.x features like torch.compile and automatic mixed precision.
What is nn.Linear?
nn.Linear is a module that applies an affine (linear) transformation to input data. It is the core building block for fully connected layers in neural networks.
The mathematical operation is:
$$y = xW^T + b$$
Where:
- x is the input tensor
- W is the weight matrix with shape (out_features, in_features)
- b is the bias vector with shape (out_features,)
- y is the output tensor
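A tiny worked example makes the formula concrete (the weight and bias values below are chosen by hand purely for illustration):

```python
import torch
from torch import nn

linear = nn.Linear(in_features=3, out_features=1)

# Overwrite the learnable parameters with known values
with torch.no_grad():
    linear.weight.copy_(torch.tensor([[1.0, 2.0, 3.0]]))  # W: (1, 3)
    linear.bias.copy_(torch.tensor([0.5]))                # b: (1,)

x = torch.tensor([1.0, 1.0, 1.0])  # one sample with 3 features
y = linear(x)

# y = xW^T + b = (1*1 + 1*2 + 1*3) + 0.5 = 6.5
print(y)
```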
Creating a Linear Layer
The constructor takes two required arguments and a few optional ones:
import torch
from torch import nn
# Create a linear layer: 3 inputs -> 1 output
linear = nn.Linear(in_features=3, out_features=1)
# Inspect the parameters
print(linear.weight.shape) # torch.Size([1, 3])
print(linear.bias.shape) # torch.Size([1])

The in_features parameter defines how many input values the layer expects. The out_features parameter defines the output dimension. PyTorch automatically creates the weight matrix and bias vector with the correct shapes.
torch.nn.Linear Constructor Parameters
| Parameter | Type | Default | Description |
|---|---|---|---|
| in_features | int | required | Size of each input sample |
| out_features | int | required | Size of each output sample |
| bias | bool | True | If False, the layer does not learn an additive bias |
| device | device | None | Device for parameter tensors |
| dtype | dtype | None | Data type for parameter tensors |
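A short sketch of the optional arguments in action (the same pattern works with device='cuda' when a GPU is available):

```python
import torch
from torch import nn

# A float64 layer without a bias term
layer = nn.Linear(4, 2, bias=False, dtype=torch.float64)

print(layer.weight.dtype)  # torch.float64
print(layer.bias)          # None: no bias parameter was created

# Input dtype must match the parameter dtype
x = torch.randn(5, 4, dtype=torch.float64)
print(layer(x).shape)      # torch.Size([5, 2])
```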
How nn.Linear Works: Forward Pass
When you call the layer on an input tensor, it performs matrix multiplication followed by bias addition:
import torch
from torch import nn
linear = nn.Linear(3, 2)
# Single sample: shape [3]
x = torch.tensor([1.0, 2.0, 3.0])
output = linear(x)
print(output.shape) # torch.Size([2])
# Verify manually
manual = x @ linear.weight.T + linear.bias
print(torch.allclose(output, manual)) # True

The key insight is that nn.Linear operates on the last dimension of the input tensor. This means it handles batched inputs automatically without any reshaping.
2D Input: Batched Samples
# Batch of 32 samples, each with 3 features
x = torch.randn(32, 3)
linear = nn.Linear(3, 2)
output = linear(x)
print(output.shape) # torch.Size([32, 2])

3D Input: Sequence Data (Transformers)
This is critical for Transformer models where inputs have shape (batch, sequence_length, hidden_dim):
# Batch of 8, sequence length 128, hidden dimension 512
x = torch.randn(8, 128, 512)
linear = nn.Linear(512, 256)
output = linear(x)
print(output.shape) # torch.Size([8, 128, 256])

No reshape needed. The linear layer applies the same transformation to every position in the sequence independently. This is exactly how Q, K, V projection layers work in attention mechanisms.
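You can confirm the per-position behavior by comparing the 3D call against an explicit loop over sequence positions:

```python
import torch
from torch import nn

linear = nn.Linear(512, 256)
x = torch.randn(8, 128, 512)  # (batch, seq_len, hidden)

out = linear(x)  # (8, 128, 256)

# Manually apply the same layer to one position at a time
manual = torch.stack([linear(x[:, t, :]) for t in range(x.shape[1])], dim=1)

print(torch.allclose(out, manual, atol=1e-6))
```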
Arbitrary Dimensions
nn.Linear works with any number of leading dimensions:
# 4D input
x = torch.randn(2, 4, 8, 16)
linear = nn.Linear(16, 32)
output = linear(x)
print(output.shape) # torch.Size([2, 4, 8, 32])

Weight Initialization
PyTorch initializes nn.Linear weights using Kaiming Uniform initialization by default (with a = sqrt(5)), which in practice draws weights from a uniform distribution bounded by 1 / sqrt(in_features). The bias is initialized from a uniform distribution with the same 1 / sqrt(in_features) bound. This default works reasonably well for layers followed by ReLU activations.
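The 1 / sqrt(in_features) bound can be checked empirically; with in_features = 256 it equals 0.0625, and both the weight and bias values should fall inside it:

```python
import math
import torch
from torch import nn

linear = nn.Linear(256, 128)
bound = 1 / math.sqrt(256)  # 0.0625

# Every sampled weight and bias value lies within [-bound, bound]
print(linear.weight.abs().max() <= bound)  # tensor(True)
print(linear.bias.abs().max() <= bound)    # tensor(True)
```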
Default Initialization
linear = nn.Linear(256, 128)
# Default: Kaiming uniform
print(linear.weight.min().item(), linear.weight.max().item())
# Approximately -0.0625 to 0.0625

Custom Initialization
Use torch.nn.init to apply a different initialization scheme:
import torch.nn.init as init
linear = nn.Linear(256, 128)
# Xavier uniform — good for sigmoid/tanh activations
init.xavier_uniform_(linear.weight)
init.zeros_(linear.bias)
# Xavier normal
init.xavier_normal_(linear.weight)
# Kaiming normal — good for ReLU activations
init.kaiming_normal_(linear.weight, mode='fan_out', nonlinearity='relu')
# Constant initialization
init.constant_(linear.weight, 0.01)
init.zeros_(linear.bias)
# Normal distribution
init.normal_(linear.weight, mean=0.0, std=0.02)

Initialization Comparison Table
| Method | Best For | Formula | When to Use |
|---|---|---|---|
| Kaiming Uniform (default) | ReLU networks | U(-bound, bound) | MLPs, CNNs with ReLU |
| Xavier Uniform | Sigmoid/Tanh | U(-a, a) where a = sqrt(6/(fan_in + fan_out)) | Networks with sigmoid/tanh |
| Xavier Normal | Sigmoid/Tanh | N(0, std) where std = sqrt(2/(fan_in + fan_out)) | Alternative to Xavier Uniform |
| Kaiming Normal | ReLU networks | N(0, std) where std = sqrt(2/fan_in) | ResNets, deeper networks |
| Normal(0, 0.02) | Transformers | N(0, 0.02) | GPT-style models |
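The Xavier Uniform bound from the table can be reproduced numerically; for a layer with fan_in = 256 and fan_out = 128, a = sqrt(6/384) works out to exactly 0.125:

```python
import math
import torch
from torch import nn
import torch.nn.init as init

linear = nn.Linear(256, 128)
init.xavier_uniform_(linear.weight)

# a = sqrt(6 / (fan_in + fan_out))
a = math.sqrt(6 / (256 + 128))
print(round(a, 4))  # 0.125

# All sampled weights lie within [-a, a]
print(linear.weight.abs().max() <= a)  # tensor(True)
```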
Initialization in a Model
Apply initialization during __init__ or with apply():
class MLP(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        self._init_weights()

    def _init_weights(self):
        for module in self.modules():
            if isinstance(module, nn.Linear):
                init.kaiming_normal_(module.weight, nonlinearity='relu')
                if module.bias is not None:
                    init.zeros_(module.bias)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        return self.fc2(x)

When to Use bias=False
The bias parameter adds a learnable constant to the output. You can disable it when it is redundant or harmful:
Disable bias when:
- A normalization layer follows immediately. BatchNorm and LayerNorm have their own bias terms, making the linear layer's bias redundant:
# Bias is absorbed by LayerNorm
layer = nn.Sequential(
    nn.Linear(512, 512, bias=False),
    nn.LayerNorm(512)
)

- Q/K/V projections in attention. Many Transformer implementations disable bias for projections:
self.q_proj = nn.Linear(d_model, d_model, bias=False)
self.k_proj = nn.Linear(d_model, d_model, bias=False)
self.v_proj = nn.Linear(d_model, d_model, bias=False)

- Embedding projections. When projecting embedding vectors, bias can add unnecessary parameters:
self.output_proj = nn.Linear(hidden_dim, vocab_size, bias=False)

Keep bias when:
- The layer is the final output layer of a regression model
- No normalization follows the layer
- You need the model to learn a constant offset
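A quick sketch of both sides of the trade-off, the parameter cost of bias and the constant-offset limitation (layer sizes here are illustrative):

```python
import torch
from torch import nn

with_bias = nn.Linear(512, 512)
no_bias = nn.Linear(512, 512, bias=False)

count = lambda m: sum(p.numel() for p in m.parameters())
print(count(with_bias))  # 262656  (512*512 + 512)
print(count(no_bias))    # 262144  (512*512)

# Without a bias, the zero vector always maps to zero,
# so the layer cannot learn a constant offset
z = torch.zeros(1, 512)
print(no_bias(z).abs().sum())  # tensor(0.)
```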
nn.Linear vs nn.Conv2d
Both are linear operations, but they differ in how they connect inputs to outputs:
| Feature | nn.Linear | nn.Conv2d |
|---|---|---|
| Input shape | (*, in_features) | (N, C_in, H, W) |
| Connection type | Dense (all-to-all) | Local (kernel-sized window) |
| Parameter count | in * out + out | C_out * C_in * kH * kW + C_out |
| Spatial awareness | None | Yes (preserves spatial structure) |
| Best for | MLPs, Transformers, classification heads | Image processing, spatial data |
| Weight sharing | No | Yes (same kernel across all positions) |
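The parameter-count formulas in the table are easy to verify directly:

```python
import torch
from torch import nn

linear = nn.Linear(64, 128)
conv = nn.Conv2d(64, 128, kernel_size=3)

count = lambda m: sum(p.numel() for p in m.parameters())

# Linear: in * out + out
print(count(linear))  # 64*128 + 128 = 8320

# Conv2d: C_out * C_in * kH * kW + C_out
print(count(conv))    # 128*64*3*3 + 128 = 73856
```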
A nn.Conv2d with kernel_size=1 is mathematically equivalent to applying nn.Linear across spatial positions:
# These produce the same result for 1x1 convolution
conv = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=1, bias=False)
linear = nn.Linear(64, 128, bias=False)
# Copy weights: conv weight is [128, 64, 1, 1], linear needs [128, 64]
linear.weight.data = conv.weight.data.squeeze(-1).squeeze(-1)
x = torch.randn(1, 64, 8, 8) # image input
out_conv = conv(x) # [1, 128, 8, 8]
out_linear = linear(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
print(torch.allclose(out_conv, out_linear, atol=1e-6)) # True

F.linear: The Functional API
torch.nn.functional.linear (commonly imported as F.linear) performs the same operation as nn.Linear but without a module wrapper. Use it when you need more control or are building custom layers:
import torch.nn.functional as F
x = torch.randn(32, 64)
weight = torch.randn(128, 64)
bias = torch.randn(128)
# Functional API — same as nn.Linear forward pass
output = F.linear(x, weight, bias)
print(output.shape) # torch.Size([32, 128])

When to Use F.linear vs nn.Linear
| Scenario | Use |
|---|---|
| Standard model building | nn.Linear (manages parameters automatically) |
| Weight sharing between layers | F.linear (pass the same weight tensor) |
| Custom forward logic | F.linear (more flexible) |
| Dynamic weight generation (hypernetworks) | F.linear (weights computed on-the-fly) |
Example of weight sharing:
class SharedWeightNetwork(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(dim, dim))
        self.bias = nn.Parameter(torch.zeros(dim))

    def forward(self, x):
        # Apply the same linear transformation twice
        x = F.linear(x, self.weight, self.bias)
        x = torch.relu(x)
        x = F.linear(x, self.weight, self.bias)
        return x

Applications in Deep Learning
1. Multi-Layer Perceptron (MLP)
The simplest neural network architecture, stacking linear layers with activations:
class MLP(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, output_dim),
        )

    def forward(self, x):
        return self.layers(x)

model = MLP(784, 256, 10) # MNIST classifier

2. Transformer Attention Projections
Every attention layer in a Transformer uses four linear layers:
class SelfAttention(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        self.q = nn.Linear(d_model, d_model, bias=False)
        self.k = nn.Linear(d_model, d_model, bias=False)
        self.v = nn.Linear(d_model, d_model, bias=False)
        self.out = nn.Linear(d_model, d_model, bias=False)
        self.n_heads = n_heads

    def forward(self, x):
        B, T, C = x.shape
        q = self.q(x).view(B, T, self.n_heads, C // self.n_heads).transpose(1, 2)
        k = self.k(x).view(B, T, self.n_heads, C // self.n_heads).transpose(1, 2)
        v = self.v(x).view(B, T, self.n_heads, C // self.n_heads).transpose(1, 2)
        att = (q @ k.transpose(-2, -1)) * (C // self.n_heads) ** -0.5
        att = torch.softmax(att, dim=-1)
        out = (att @ v).transpose(1, 2).contiguous().view(B, T, C)
        return self.out(out)

3. Classification Head
The final layer that maps hidden representations to class predictions:
# ResNet classification head
classifier = nn.Linear(2048, 1000) # 2048 features -> 1000 ImageNet classes
# BERT classification head
class BertClassifier(nn.Module):
    def __init__(self, hidden_size, num_classes):
        super().__init__()
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(hidden_size, num_classes)

    def forward(self, pooled_output):
        return self.classifier(self.dropout(pooled_output))

4. Autoencoder
Linear layers compress data into a bottleneck and reconstruct it:
class Autoencoder(nn.Module):
    def __init__(self, input_dim, latent_dim):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.ReLU(),
            nn.Linear(128, latent_dim),
        )
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 128),
            nn.ReLU(),
            nn.Linear(128, input_dim),
        )

    def forward(self, x):
        z = self.encoder(x)
        return self.decoder(z)

Optimizing nn.Linear in PyTorch 2.x
torch.compile for Faster Execution
torch.compile fuses operations and generates optimized GPU kernels. Linear layers benefit significantly from compiled execution:
model = MLP(784, 256, 10).cuda()
# Compile the model for faster execution
compiled_model = torch.compile(model)
x = torch.randn(64, 784).cuda()
output = compiled_model(x) # First call triggers compilation

Typical speedups for linear-heavy models range from 1.3x to 2x on GPU. The compiler can fuse elementwise operations such as the bias addition and activation into fewer kernels around the matrix multiplication.
Automatic Mixed Precision (AMP)
AMP uses float16 for linear layer computations on GPU, reducing memory usage and improving throughput:
model = MLP(784, 256, 10).cuda()
optimizer = torch.optim.Adam(model.parameters())
scaler = torch.amp.GradScaler()
x = torch.randn(64, 784).cuda()
target = torch.randint(0, 10, (64,)).cuda()
with torch.amp.autocast('cuda'):
    output = model(x)
    loss = torch.nn.functional.cross_entropy(output, target)

scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()

Quantized Linear Layers for Inference
Quantization converts float32 weights to int8, reducing model size by 4x and speeding up inference:
import torch.ao.quantization as quant

# Post-training dynamic quantization
model = MLP(784, 256, 10).eval()
quantized_model = quant.quantize_dynamic(
    model,
    {nn.Linear},  # Quantize only linear layers
    dtype=torch.qint8
)
# Compare sizes
import os, tempfile
def model_size(model):
    with tempfile.NamedTemporaryFile() as f:
        torch.save(model.state_dict(), f.name)
        return os.path.getsize(f.name)

print(f"Original: {model_size(model) / 1024:.1f} KB")
print(f"Quantized: {model_size(quantized_model) / 1024:.1f} KB")

Performance Comparison
| Optimization | Speed Gain | Memory Reduction | Use Case |
|---|---|---|---|
| None (baseline) | 1x | 0% | Development/debugging |
| torch.compile | 1.3-2x | 0% | Training and inference |
| AMP (float16) | 1.5-3x | ~50% | GPU training |
| Dynamic quantization (int8) | 1.5-4x | ~75% | CPU inference |
| Static quantization (int8) | 2-4x | ~75% | Optimized CPU inference |
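Actual gains depend heavily on hardware, batch size, and model shape, so it is worth measuring on your own workload. Below is a minimal CPU timing sketch for dynamic quantization; the model and sizes are illustrative, and the reported times will vary by machine:

```python
import time
import torch
from torch import nn

# A small linear-heavy model (sizes chosen for illustration)
model = nn.Sequential(
    nn.Linear(784, 256), nn.ReLU(),
    nn.Linear(256, 256), nn.ReLU(),
    nn.Linear(256, 10),
).eval()

quantized = torch.ao.quantization.quantize_dynamic(
    model, {nn.Linear}, dtype=torch.qint8
)

x = torch.randn(64, 784)

def bench(m, iters=50):
    with torch.no_grad():
        m(x)  # warmup
        start = time.perf_counter()
        for _ in range(iters):
            m(x)
        return (time.perf_counter() - start) / iters

print(f"float32: {bench(model) * 1e3:.3f} ms/iter")
print(f"int8:    {bench(quantized) * 1e3:.3f} ms/iter")
```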
Common Errors and How to Fix Them
1. Shape Mismatch Error
RuntimeError: mat1 and mat2 shapes cannot be multiplied (32x10 and 20x5)

The input's last dimension (10) does not match in_features (20). Fix by ensuring dimensions align:
# Wrong: input has 10 features, but layer expects 20
linear = nn.Linear(20, 5)
x = torch.randn(32, 10)
# output = linear(x) # RuntimeError!
# Fix: match in_features to input dimension
linear = nn.Linear(10, 5)
output = linear(x) # Works: [32, 5]

2. Integer Tensor Input
nn.Linear requires floating-point input. Passing integers raises an error:
# Wrong
x = torch.tensor([1, 2, 3])
# linear(x) # RuntimeError: expected scalar type Float
# Fix
x = torch.tensor([1, 2, 3], dtype=torch.float32)
# or
x = x.float()

3. Forgetting to Flatten Before a Linear Layer
When transitioning from conv layers to linear layers, you need to flatten the spatial dimensions:
class CNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 16, 3, padding=1)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(16, 10)

    def forward(self, x):
        x = self.conv(x)            # [B, 16, H, W]
        x = self.pool(x)            # [B, 16, 1, 1]
        x = x.flatten(start_dim=1)  # [B, 16] — flatten required!
        return self.fc(x)           # [B, 10]

4. Device Mismatch
Input and layer must be on the same device:
linear = nn.Linear(10, 5).cuda()
x = torch.randn(32, 10) # CPU tensor
# linear(x) # RuntimeError: Expected all tensors to be on the same device
x = x.cuda() # Move to same device
output = linear(x)

Inspecting Linear Layer Parameters
Understanding what is inside a linear layer helps with debugging and model analysis:
linear = nn.Linear(4, 3)
# Access weight and bias directly
print(linear.weight) # Parameter tensor [3, 4]
print(linear.bias) # Parameter tensor [3]
# Count parameters
total_params = sum(p.numel() for p in linear.parameters())
print(f"Parameters: {total_params}") # 4*3 + 3 = 15
# Freeze weights (for transfer learning)
for param in linear.parameters():
    param.requires_grad = False

# Check gradient computation
print(linear.weight.requires_grad) # False after freezing

For data scientists working in Jupyter notebooks, RunCell provides an AI agent that can help you inspect model architectures, debug tensor shapes, and run experiments interactively — useful when building and testing PyTorch models.
Conclusion
nn.Linear is the most fundamental module in PyTorch. It performs a simple affine transformation, but understanding its behavior — shape handling, initialization, bias control, and optimization — is essential for building any neural network.
Key takeaways:
- nn.Linear transforms the last dimension of any input tensor, handling batches and sequences automatically
- Default Kaiming initialization works well for ReLU networks; use Xavier for sigmoid/tanh
- Disable bias when normalization layers follow immediately
- Use torch.compile, AMP, and quantization to optimize linear-heavy models in production
- Use F.linear for weight sharing and custom forward logic
Whether you are building a simple MLP, a Transformer, or a large language model, linear layers make up the majority of learnable parameters in your network. Getting them right is the foundation of effective deep learning in PyTorch.
Frequently Asked Questions
What does nn.Linear do in PyTorch?
nn.Linear applies an affine transformation y = xW^T + b to the input tensor. It multiplies the input by a learnable weight matrix and adds a learnable bias vector. This is the standard fully connected layer used in neural networks.
What is the difference between nn.Linear and F.linear?
nn.Linear is a module that manages its own weight and bias parameters. F.linear is a function that takes weight and bias as explicit arguments. Use nn.Linear for standard model building and F.linear when you need weight sharing or custom forward logic.
How does nn.Linear handle batched inputs?
nn.Linear applies the transformation to the last dimension of the input tensor, regardless of how many leading dimensions exist. An input of shape (batch, seq_len, features) produces output (batch, seq_len, out_features) with no reshaping needed.
What initialization does nn.Linear use by default?
PyTorch uses Kaiming Uniform initialization for weights and a uniform distribution bounded by 1/sqrt(in_features) for bias. You can override this using functions from torch.nn.init.
When should I set bias=False in nn.Linear?
Set bias=False when the linear layer is immediately followed by a normalization layer (BatchNorm, LayerNorm) that has its own bias, or in Q/K/V projections where bias is typically unnecessary. Keep bias enabled for final output layers and layers without normalization.