nn.Linear in PyTorch: Complete Guide to Linear Layers

You are building a neural network in PyTorch and every tutorial mentions nn.Linear, but the documentation is dense and the shape errors are cryptic. You pass a tensor with the wrong dimensions and get RuntimeError: mat1 and mat2 shapes cannot be multiplied. You are not sure when to set bias=False, how weight initialization affects training, or why Transformer models use linear layers everywhere instead of convolutions.

These confusions slow down debugging and model development. A mismatched in_features can silently produce garbage outputs before you notice. Choosing the wrong initialization scheme can cause gradients to vanish or explode. And without understanding how nn.Linear handles batched and 3D inputs, you end up writing unnecessary reshape operations that clutter your code.

This guide explains nn.Linear from the ground up: what it computes, how shapes work, how to initialize weights properly, when to disable bias, how it fits into modern architectures, and how to optimize it with PyTorch 2.x features like torch.compile and automatic mixed precision.

What is nn.Linear?

nn.Linear is a module that applies an affine (linear) transformation to input data. It is the core building block for fully connected layers in neural networks.

The mathematical operation is:

$$y = xW^T + b$$

Where:

  • x is the input tensor
  • W is the weight matrix with shape (out_features, in_features)
  • b is the bias vector with shape (out_features,)
  • y is the output tensor

Creating a Linear Layer

The constructor takes two required arguments and three optional ones (bias, device, and dtype):

import torch
from torch import nn
 
# Create a linear layer: 3 inputs -> 1 output
linear = nn.Linear(in_features=3, out_features=1)
 
# Inspect the parameters
print(linear.weight.shape)  # torch.Size([1, 3])
print(linear.bias.shape)    # torch.Size([1])

The in_features parameter defines how many input values the layer expects. The out_features parameter defines the output dimension. PyTorch automatically creates the weight matrix and bias vector with the correct shapes.

torch.nn.Linear Constructor Parameters

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| in_features | int | required | Size of each input sample |
| out_features | int | required | Size of each output sample |
| bias | bool | True | If False, the layer does not learn an additive bias |
| device | device | None | Device for parameter tensors |
| dtype | dtype | None | Data type for parameter tensors |
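As a quick sketch of the optional arguments, the snippet below builds one layer with a non-default dtype and one without a bias term:

```python
import torch
from torch import nn

# dtype (and device) can be fixed at construction time — here float64 on CPU
linear = nn.Linear(3, 1, dtype=torch.float64)
print(linear.weight.dtype)  # torch.float64

# With bias=False the layer carries no additive parameter at all
no_bias = nn.Linear(3, 1, bias=False)
print(no_bias.bias)  # None
```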

How nn.Linear Works: Forward Pass

When you call the layer on an input tensor, it performs matrix multiplication followed by bias addition:

import torch
from torch import nn
 
linear = nn.Linear(3, 2)
 
# Single sample: shape [3]
x = torch.tensor([1.0, 2.0, 3.0])
output = linear(x)
print(output.shape)  # torch.Size([2])
 
# Verify manually
manual = x @ linear.weight.T + linear.bias
print(torch.allclose(output, manual))  # True

The key insight is that nn.Linear operates on the last dimension of the input tensor. This means it handles batched inputs automatically without any reshaping.

2D Input: Batched Samples

# Batch of 32 samples, each with 3 features
x = torch.randn(32, 3)
linear = nn.Linear(3, 2)
output = linear(x)
print(output.shape)  # torch.Size([32, 2])

3D Input: Sequence Data (Transformers)

This is critical for Transformer models where inputs have shape (batch, sequence_length, hidden_dim):

# Batch of 8, sequence length 128, hidden dimension 512
x = torch.randn(8, 128, 512)
linear = nn.Linear(512, 256)
output = linear(x)
print(output.shape)  # torch.Size([8, 128, 256])

No reshape needed. The linear layer applies the same transformation to every position in the sequence independently. This is exactly how Q, K, V projection layers work in attention mechanisms.
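To confirm that the 3D case really is a per-position application of the same transformation, you can compare it against an explicit flatten-apply-reshape:

```python
import torch
from torch import nn

linear = nn.Linear(512, 256)
x = torch.randn(8, 128, 512)

# Direct application to the 3D tensor...
out_3d = linear(x)

# ...matches flattening positions into the batch, applying, and reshaping back
out_2d = linear(x.reshape(-1, 512)).reshape(8, 128, 256)
print(torch.allclose(out_3d, out_2d, atol=1e-6))  # True
```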

Arbitrary Dimensions

nn.Linear works with any number of leading dimensions:

# 4D input
x = torch.randn(2, 4, 8, 16)
linear = nn.Linear(16, 32)
output = linear(x)
print(output.shape)  # torch.Size([2, 4, 8, 32])

Weight Initialization

PyTorch initializes nn.Linear weights using Kaiming uniform initialization by default (with a = sqrt(5), which works out to a uniform distribution bounded by 1 / sqrt(in_features)). This works well for layers followed by ReLU activations. The bias is initialized from a uniform distribution with the same bound, 1 / sqrt(in_features).

Default Initialization

linear = nn.Linear(256, 128)
 
# Default: Kaiming uniform
print(linear.weight.min().item(), linear.weight.max().item())
# Approximately -0.0625 to 0.0625
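You can check the 1 / sqrt(in_features) bound empirically; both weight and bias of a freshly constructed layer should stay inside it:

```python
import math
import torch
from torch import nn

linear = nn.Linear(256, 128)
bound = 1 / math.sqrt(linear.in_features)  # 0.0625 for in_features=256

# Default-initialized parameters fall within the uniform bound
print(linear.weight.abs().max().item() <= bound)  # True
print(linear.bias.abs().max().item() <= bound)    # True
```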

Custom Initialization

Use torch.nn.init to apply a different initialization scheme:

import torch.nn.init as init
 
linear = nn.Linear(256, 128)
 
# Xavier uniform — good for sigmoid/tanh activations
init.xavier_uniform_(linear.weight)
init.zeros_(linear.bias)
 
# Xavier normal
init.xavier_normal_(linear.weight)
 
# Kaiming normal — good for ReLU activations
init.kaiming_normal_(linear.weight, mode='fan_out', nonlinearity='relu')
 
# Constant initialization
init.constant_(linear.weight, 0.01)
init.zeros_(linear.bias)
 
# Normal distribution
init.normal_(linear.weight, mean=0.0, std=0.02)

Initialization Comparison Table

| Method | Best For | Formula | When to Use |
| --- | --- | --- | --- |
| Kaiming Uniform (default) | ReLU networks | U(-bound, bound) | MLPs, CNNs with ReLU |
| Xavier Uniform | Sigmoid/Tanh | U(-a, a), a = sqrt(6/(fan_in + fan_out)) | Networks with sigmoid/tanh |
| Xavier Normal | Sigmoid/Tanh | N(0, std), std = sqrt(2/(fan_in + fan_out)) | Alternative to Xavier Uniform |
| Kaiming Normal | ReLU networks | N(0, std), std = sqrt(2/fan_in) | ResNets, deeper networks |
| Normal(0, 0.02) | Transformers | N(0, 0.02) | GPT-style models |
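The formulas in the table can be verified empirically. For example, after kaiming_normal_ the weight standard deviation should sit very close to the theoretical sqrt(2/fan_in) (a quick sanity-check sketch):

```python
import math
import torch
from torch import nn
import torch.nn.init as init

linear = nn.Linear(1024, 1024)
init.kaiming_normal_(linear.weight, nonlinearity='relu')

# Empirical std over ~1M samples should match sqrt(2/fan_in) closely
theoretical = math.sqrt(2 / 1024)          # fan_in = 1024
empirical = linear.weight.std().item()
print(abs(empirical - theoretical) < 1e-3)  # True
```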

Initialization in a Model

Apply initialization during __init__ or with apply():

class MLP(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        self._init_weights()
 
    def _init_weights(self):
        for module in self.modules():
            if isinstance(module, nn.Linear):
                init.kaiming_normal_(module.weight, nonlinearity='relu')
                if module.bias is not None:
                    init.zeros_(module.bias)
 
    def forward(self, x):
        x = torch.relu(self.fc1(x))
        return self.fc2(x)

When to Use bias=False

The bias parameter adds a learnable constant to the output. You can disable it when it is redundant or harmful:

Disable bias when:

  1. A normalization layer follows immediately. BatchNorm and LayerNorm have their own bias terms, making the linear layer's bias redundant:

# Bias is absorbed by LayerNorm
layer = nn.Sequential(
    nn.Linear(512, 512, bias=False),
    nn.LayerNorm(512)
)

  2. Q/K/V projections in attention. Many Transformer implementations disable bias for these projections:

self.q_proj = nn.Linear(d_model, d_model, bias=False)
self.k_proj = nn.Linear(d_model, d_model, bias=False)
self.v_proj = nn.Linear(d_model, d_model, bias=False)

  3. Embedding projections. When projecting embedding vectors, bias can add unnecessary parameters:

self.output_proj = nn.Linear(hidden_dim, vocab_size, bias=False)

Keep bias when:

  • The layer is the final output layer of a regression model
  • No normalization follows the layer
  • You need the model to learn a constant offset
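As a quick check, a layer built with bias=False simply has no bias parameter, which shows up directly in its parameter count:

```python
from torch import nn

with_bias = nn.Linear(512, 512)
no_bias = nn.Linear(512, 512, bias=False)

count = lambda m: sum(p.numel() for p in m.parameters())
print(count(with_bias))  # 262656 = 512*512 + 512
print(count(no_bias))    # 262144 = 512*512
print(no_bias.bias)      # None — the attribute exists but holds no parameter
```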

nn.Linear vs nn.Conv2d

Both are linear operations, but they differ in how they connect inputs to outputs:

| Feature | nn.Linear | nn.Conv2d |
| --- | --- | --- |
| Input shape | (*, in_features) | (N, C_in, H, W) |
| Connection type | Dense (all-to-all) | Local (kernel-sized window) |
| Parameter count | in * out + out | C_out * C_in * kH * kW + C_out |
| Spatial awareness | None | Yes (preserves spatial structure) |
| Best for | MLPs, Transformers, classification heads | Image processing, spatial data |
| Weight sharing | No | Yes (same kernel across all positions) |

An nn.Conv2d with kernel_size=1 is mathematically equivalent to applying nn.Linear across spatial positions:

# These produce the same result for 1x1 convolution
conv = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=1, bias=False)
linear = nn.Linear(64, 128, bias=False)
 
# Copy weights: 1x1 conv kernel [128, 64, 1, 1] -> linear weight [128, 64]
with torch.no_grad():
    linear.weight.copy_(conv.weight.squeeze(-1).squeeze(-1))
 
x = torch.randn(1, 64, 8, 8)  # image input
out_conv = conv(x)  # [1, 128, 8, 8]
out_linear = linear(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
print(torch.allclose(out_conv, out_linear, atol=1e-6))  # True

F.linear: The Functional API

torch.nn.functional.linear (commonly imported as F.linear) performs the same operation as nn.Linear but without a module wrapper. Use it when you need more control or are building custom layers:

import torch.nn.functional as F
 
x = torch.randn(32, 64)
weight = torch.randn(128, 64)
bias = torch.randn(128)
 
# Functional API — same as nn.Linear forward pass
output = F.linear(x, weight, bias)
print(output.shape)  # torch.Size([32, 128])

When to Use F.linear vs nn.Linear

| Scenario | Use |
| --- | --- |
| Standard model building | nn.Linear (manages parameters automatically) |
| Weight sharing between layers | F.linear (pass the same weight tensor) |
| Custom forward logic | F.linear (more flexible) |
| Dynamic weight generation (hypernetworks) | F.linear (weights computed on-the-fly) |

Example of weight sharing:

class SharedWeightNetwork(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(dim, dim))
        self.bias = nn.Parameter(torch.zeros(dim))
 
    def forward(self, x):
        # Apply the same linear transformation twice
        x = F.linear(x, self.weight, self.bias)
        x = torch.relu(x)
        x = F.linear(x, self.weight, self.bias)
        return x

Applications in Deep Learning

1. Multi-Layer Perceptron (MLP)

The simplest neural network architecture, stacking linear layers with activations:

class MLP(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, output_dim),
        )
 
    def forward(self, x):
        return self.layers(x)
 
model = MLP(784, 256, 10)  # MNIST classifier

2. Transformer Attention Projections

Every attention layer in a Transformer uses four linear layers:

class SelfAttention(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        self.q = nn.Linear(d_model, d_model, bias=False)
        self.k = nn.Linear(d_model, d_model, bias=False)
        self.v = nn.Linear(d_model, d_model, bias=False)
        self.out = nn.Linear(d_model, d_model, bias=False)
        self.n_heads = n_heads
 
    def forward(self, x):
        B, T, C = x.shape
        q = self.q(x).view(B, T, self.n_heads, C // self.n_heads).transpose(1, 2)
        k = self.k(x).view(B, T, self.n_heads, C // self.n_heads).transpose(1, 2)
        v = self.v(x).view(B, T, self.n_heads, C // self.n_heads).transpose(1, 2)
 
        att = (q @ k.transpose(-2, -1)) * (C // self.n_heads) ** -0.5
        att = torch.softmax(att, dim=-1)
        out = (att @ v).transpose(1, 2).contiguous().view(B, T, C)
        return self.out(out)

3. Classification Head

The final layer that maps hidden representations to class predictions:

# ResNet classification head
classifier = nn.Linear(2048, 1000)  # 2048 features -> 1000 ImageNet classes
 
# BERT classification head
class BertClassifier(nn.Module):
    def __init__(self, hidden_size, num_classes):
        super().__init__()
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(hidden_size, num_classes)
 
    def forward(self, pooled_output):
        return self.classifier(self.dropout(pooled_output))

4. Autoencoder

Linear layers compress data into a bottleneck and reconstruct it:

class Autoencoder(nn.Module):
    def __init__(self, input_dim, latent_dim):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.ReLU(),
            nn.Linear(128, latent_dim),
        )
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 128),
            nn.ReLU(),
            nn.Linear(128, input_dim),
        )
 
    def forward(self, x):
        z = self.encoder(x)
        return self.decoder(z)

Optimizing nn.Linear in PyTorch 2.x

torch.compile for Faster Execution

torch.compile fuses operations and generates optimized GPU kernels. Linear layers benefit significantly from compiled execution:

model = MLP(784, 256, 10).cuda()
 
# Compile the model for faster execution
compiled_model = torch.compile(model)
 
x = torch.randn(64, 784).cuda()
output = compiled_model(x)  # First call triggers compilation

Typical speedups for linear-heavy models range from 1.3x to 2x on GPU. The compiler fuses the matrix multiplication, bias addition, and activation into a single kernel.

Automatic Mixed Precision (AMP)

AMP uses float16 for linear layer computations on GPU, reducing memory usage and improving throughput:

model = MLP(784, 256, 10).cuda()
optimizer = torch.optim.Adam(model.parameters())
scaler = torch.amp.GradScaler()
 
x = torch.randn(64, 784).cuda()
target = torch.randint(0, 10, (64,)).cuda()
 
with torch.amp.autocast('cuda'):
    output = model(x)
    loss = torch.nn.functional.cross_entropy(output, target)
 
optimizer.zero_grad()
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()

Quantized Linear Layers for Inference

Quantization converts float32 weights to int8, reducing model size by 4x and speeding up inference:

import torch.ao.quantization as quant
 
# Post-training dynamic quantization
model = MLP(784, 256, 10).eval()
quantized_model = quant.quantize_dynamic(
    model,
    {nn.Linear},  # Quantize only linear layers
    dtype=torch.qint8
)
 
# Compare sizes
import os, tempfile
def model_size(model):
    with tempfile.NamedTemporaryFile() as f:
        torch.save(model.state_dict(), f.name)
        return os.path.getsize(f.name)
 
print(f"Original: {model_size(model) / 1024:.1f} KB")
print(f"Quantized: {model_size(quantized_model) / 1024:.1f} KB")

Performance Comparison

| Optimization | Speed Gain | Memory Reduction | Use Case |
| --- | --- | --- | --- |
| None (baseline) | 1x | 0% | Development/debugging |
| torch.compile | 1.3-2x | 0% | Training and inference |
| AMP (float16) | 1.5-3x | ~50% | GPU training |
| Dynamic quantization (int8) | 1.5-4x | ~75% | CPU inference |
| Static quantization (int8) | 2-4x | ~75% | Optimized CPU inference |

Common Errors and How to Fix Them

1. Shape Mismatch Error

RuntimeError: mat1 and mat2 shapes cannot be multiplied (32x10 and 20x5)

The input's last dimension (10) does not match in_features (20). Fix by ensuring dimensions align:

# Wrong: input has 10 features, but layer expects 20
linear = nn.Linear(20, 5)
x = torch.randn(32, 10)
# output = linear(x)  # RuntimeError!
 
# Fix: match in_features to input dimension
linear = nn.Linear(10, 5)
output = linear(x)  # Works: [32, 5]
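When the mismatch is buried inside a larger model, forward hooks are a convenient way to print every linear layer's input and output shapes as data flows through (a debugging sketch, not part of the original example):

```python
import torch
from torch import nn

model = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 5))

# Print each linear layer's input/output shapes during the forward pass
for name, module in model.named_modules():
    if isinstance(module, nn.Linear):
        module.register_forward_hook(
            lambda m, inp, out, n=name: print(f"{n}: {tuple(inp[0].shape)} -> {tuple(out.shape)}")
        )

output = model(torch.randn(32, 10))
# 0: (32, 10) -> (32, 20)
# 2: (32, 20) -> (32, 5)
```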

2. Integer Tensor Input

nn.Linear requires floating-point input. Passing integers raises an error:

# Wrong: integer (Long) input
linear = nn.Linear(3, 2)
x = torch.tensor([1, 2, 3])
# linear(x)  # RuntimeError: expected scalar type Float
 
# Fix: create a float tensor...
x = torch.tensor([1.0, 2.0, 3.0])
# ...or convert an existing integer tensor
x = x.float()
output = linear(x)  # Works: shape [2]

3. Forgetting to Flatten Before a Linear Layer

When transitioning from conv layers to linear layers, you need to flatten the spatial dimensions:

class CNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 16, 3, padding=1)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(16, 10)
 
    def forward(self, x):
        x = self.conv(x)           # [B, 16, H, W]
        x = self.pool(x)           # [B, 16, 1, 1]
        x = x.flatten(start_dim=1) # [B, 16] — flatten required!
        return self.fc(x)          # [B, 10]
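If the model is built with nn.Sequential, nn.Flatten expresses the same reshape declaratively, which avoids the manual flatten call entirely:

```python
import torch
from torch import nn

# nn.Flatten (start_dim=1 by default) collapses [B, 16, 1, 1] into [B, 16]
model = nn.Sequential(
    nn.Conv2d(3, 16, 3, padding=1),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(16, 10),
)

print(model(torch.randn(4, 3, 32, 32)).shape)  # torch.Size([4, 10])
```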

4. Device Mismatch

Input and layer must be on the same device:

linear = nn.Linear(10, 5).cuda()
x = torch.randn(32, 10)  # CPU tensor
# linear(x)  # RuntimeError: Expected all tensors to be on the same device
 
x = x.cuda()  # Move to same device
output = linear(x)

Inspecting Linear Layer Parameters

Understanding what is inside a linear layer helps with debugging and model analysis:

linear = nn.Linear(4, 3)
 
# Access weight and bias directly
print(linear.weight)       # Parameter tensor [3, 4]
print(linear.bias)         # Parameter tensor [3]
 
# Count parameters
total_params = sum(p.numel() for p in linear.parameters())
print(f"Parameters: {total_params}")  # 4*3 + 3 = 15
 
# Freeze weights (for transfer learning)
for param in linear.parameters():
    param.requires_grad = False
 
# Check gradient computation
print(linear.weight.requires_grad)  # False after freezing
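Freezing is usually applied selectively; counting trainable versus total parameters confirms it took effect (a sketch using a hypothetical two-layer model):

```python
import torch
from torch import nn

model = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 5))

# Freeze only the first linear layer — a typical transfer-learning pattern
for param in model[0].parameters():
    param.requires_grad = False

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(trainable, total)  # 105 325
```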

For data scientists working in Jupyter notebooks, RunCell provides an AI agent that can help you inspect model architectures, debug tensor shapes, and run experiments interactively — useful when building and testing PyTorch models.


Conclusion

nn.Linear is the most fundamental module in PyTorch. It performs a simple affine transformation, but understanding its behavior — shape handling, initialization, bias control, and optimization — is essential for building any neural network.

Key takeaways:

  • nn.Linear transforms the last dimension of any input tensor, handling batches and sequences automatically
  • Default Kaiming initialization works well for ReLU networks; use Xavier for sigmoid/tanh
  • Disable bias when normalization layers follow immediately
  • Use torch.compile, AMP, and quantization to optimize linear-heavy models in production
  • Use F.linear for weight sharing and custom forward logic

Whether you are building a simple MLP, a Transformer, or a large language model, linear layers make up the majority of learnable parameters in your network. Getting them right is the foundation of effective deep learning in PyTorch.


Frequently Asked Questions

What does nn.Linear do in PyTorch?

nn.Linear applies an affine transformation y = xW^T + b to the input tensor. It multiplies the input by a learnable weight matrix and adds a learnable bias vector. This is the standard fully connected layer used in neural networks.

What is the difference between nn.Linear and F.linear?

nn.Linear is a module that manages its own weight and bias parameters. F.linear is a function that takes weight and bias as explicit arguments. Use nn.Linear for standard model building and F.linear when you need weight sharing or custom forward logic.

How does nn.Linear handle batched inputs?

nn.Linear applies the transformation to the last dimension of the input tensor, regardless of how many leading dimensions exist. An input of shape (batch, seq_len, features) produces output (batch, seq_len, out_features) with no reshaping needed.

What initialization does nn.Linear use by default?

PyTorch uses Kaiming Uniform initialization for weights and a uniform distribution bounded by 1/sqrt(in_features) for bias. You can override this using functions from torch.nn.init.

When should I set bias=False in nn.Linear?

Set bias=False when the linear layer is immediately followed by a normalization layer (BatchNorm, LayerNorm) that has its own bias, or in Q/K/V projections where bias is typically unnecessary. Keep bias enabled for final output layers and layers without normalization.
