Skip to content

Introdução: nn.Linear em PyTorch, Explicado com Clareza

Updated on

nn.Linear é um dos blocos de construção mais fundamentais em PyTorch.
Seja você estiver implementando um simples Multi-Layer Perceptron (MLP), um Transformer ou um sistema de deep learning em larga escala, camadas lineares estão em todo lugar — realizando transformações afins rápidas nos seus dados.

Com o PyTorch 2.x, camadas lineares agora se beneficiam de kernels atualizados, otimizações do compilador (torch.compile) e melhor aceleração em GPU. Este guia atualizado cobre tudo o que você precisa saber, do básico ao uso moderno em deep learning.

Quer criar rapidamente Data Visualization a partir de um DataFrame do Python Pandas com zero código?

PyGWalker é uma biblioteca Python para Exploratory Data Analysis com Visualização.
Ela transforma seu DataFrame de pandas ou polars em uma interface similar ao Tableau dentro do Jupyter Notebook.

👉 PyGWalker on GitHub (opens in a new tab)

PyGWalker for Data visualization (opens in a new tab)


Entendendo nn.Linear em PyTorch

O que é nn.Linear?

nn.Linear aplica uma transformação afim aos dados de entrada:

[ y = xA^T + b ]

  • x: tensor de entrada
  • A: matriz de pesos (out_features × in_features)
  • b: vetor de bias (out_features)
  • y: tensor de saída

O módulo recebe dois argumentos obrigatórios:

nn.Linear(in_features, out_features, bias=True)

Exemplo:

import torch
from torch import nn
 
linear_layer = nn.Linear(in_features=3, out_features=1)

Isso cria:

  • shape dos pesos: 1 × 3
  • shape do bias: 1

Como o nn.Linear funciona?

O forward pass realiza:

  • multiplicação de matrizes com os pesos
  • adição do bias (se estiver habilitado)
output = linear_layer(torch.tensor([1., 2., 3.]))
print(output)

📌 Uso moderno: entradas em batch

PyTorch aplica nn.Linear na última dimensão.
Então ele funciona de forma transparente com:

Entradas 2D (batch, features)

[batch, in_features]

Entradas 3D (batch, seq_len, features)

Comum em Transformers:

[batch, seq_len, in_features] → [batch, seq_len, out_features]

Exemplo:

x = torch.randn(32, 128, 512)   # batch=32, seq_len=128, hidden=512
linear = nn.Linear(512, 1024)
y = linear(x)                   # output: [32, 128, 1024]

Não é necessário fazer reshape — isso é importante para projeções de atenção (Q, K, V).


Inicializando pesos e biases

PyTorch inicializa os pesos com Kaiming Uniform por padrão, mas você pode customizar essa etapa usando torch.nn.init.

import torch.nn.init as init
 
init.xavier_uniform_(linear_layer.weight)
init.zeros_(linear_layer.bias)

Quando usar bias=False

Você geralmente desabilita o bias quando:

  • camadas de normalização vêm logo após a camada linear
    (por exemplo, LayerNorm em Transformers)
  • camadas de projeção Q/K/V em blocos de atenção
  • camadas lineares usadas dentro de blocos residuais

Exemplo:

nn.Linear(embed_dim, embed_dim, bias=False)

Comparando nn.Linear e nn.Conv2d

CamadaMelhor usoOperaçãoParâmetros
nn.Linearvetores, entrada achatada, MLPs, TransformersTransformação afimin_features, out_features
nn.Conv2ddados de imagem, padrões espaciais locaisConvolução com janela móvelchannels, kernel size, stride, padding

Principais diferenças:

  • Camadas lineares usam conexões densas (tudo-para-tudo).
  • Conv2D usa kernels espacialmente locais, reduzindo parâmetros e introduzindo viés indutivo.

Aplicações de nn.Linear em Deep Learning

1. Multi-Layer Perceptron (MLP)

layer = nn.Linear(256, 128)

2. Modelos Transformer

Um Transformer padrão usa camadas lineares para:

  • projeções Q, K, V
  • redes feed-forward (MLP)
  • heads de saída

Exemplo:

q = nn.Linear(d_model, d_model, bias=False)

3. Heads de classificação

classifier = nn.Linear(hidden_dim, num_classes)

4. Autoencoders

Tanto o encoder quanto o decoder usam camadas lineares para comprimir e reconstruir.


Usando nn.Linear em um modelo PyTorch

import torch.nn.functional as F
from torch import nn
 
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 5)
        self.fc2 = nn.Linear(5, 2)
 
    def forward(self, x):
        x = F.relu(self.fc1(x))
        return self.fc2(x)
 
net = Net()

Dicas modernas para PyTorch 2.x (Importante!)

🚀 1. torch.compile otimiza camadas lineares

model = torch.compile(net)

O compile faz fusão de operações para melhorar o desempenho em GPU.


⚡ 2. AMP para treinamento mais rápido em GPUs

with torch.cuda.amp.autocast():
    output = model(x)

📉 3. Linear quantizado para inferência

import torch.ao.quantization as quant

Reduz o tamanho do modelo e acelera a inferência.


🔧 4. Use F.linear para operações customizadas fundidas

import torch.nn.functional as F
y = F.linear(x, weight, bias)

Útil ao combinar manualmente múltiplas operações.


Erros comuns e como corrigi-los

1. Incompatibilidade de shapes

RuntimeError: mat1 and mat2 shapes cannot be multiplied

Correção: garanta que a última dimensão da entrada corresponda a in_features.

2. Passar entradas 3D sem deixar nn.Linear lidar com o shape

Evite fazer reshape desnecessariamente — Linear já oferece suporte a tensores 3D.

3. Esquecer de usar .float() ao misturar tensores inteiros

x = x.float()

Conclusão

nn.Linear é um módulo aparentemente simples, mas fundamental em PyTorch.
De MLPs a Transformers e modelos generativos modernos, camadas lineares respondem pela maior parte da computação nos sistemas de deep learning atuais.

Entender:

  • como elas funcionam
  • como lidam com diferentes shapes de entrada
  • como otimizá-las com PyTorch 2.x

vai te tornar significativamente mais eficiente ao construir modelos neurais.


FAQs

Qual é o propósito do vetor de bias em nn.Linear?

O bias permite que o modelo desloque os valores de saída independentemente da entrada. Isso aumenta a flexibilidade, especialmente quando as entradas não são centradas em zero.

Como inicializar pesos e biases?

PyTorch inicializa automaticamente, mas você pode sobrescrever usando torch.nn.init (por exemplo, inicialização Xavier).

Qual a diferença entre nn.Linear e nn.Conv2d?

nn.Linear aplica uma transformação afim densa; nn.Conv2d aplica convolução espacial. Camadas lineares são usadas em MLPs e Transformers, enquanto camadas Conv2D são usadas em CNNs.