Introdução: nn.Linear em PyTorch, Explicado com Clareza
Updated on
nn.Linear é um dos blocos de construção mais fundamentais em PyTorch.
Seja você estiver implementando um simples Multi-Layer Perceptron (MLP), um Transformer ou um sistema de deep learning em larga escala, camadas lineares estão em todo lugar — realizando transformações afins rápidas nos seus dados.
Com o PyTorch 2.x, camadas lineares agora se beneficiam de kernels atualizados, otimizações do compilador (torch.compile) e melhor aceleração em GPU. Este guia atualizado cobre tudo o que você precisa saber, do básico ao uso moderno em deep learning.
Quer criar rapidamente Data Visualization a partir de um DataFrame do Python Pandas com zero código?
PyGWalker é uma biblioteca Python para Exploratory Data Analysis com Visualização.
Ela transforma seu DataFrame de pandas ou polars em uma interface similar ao Tableau dentro do Jupyter Notebook.
Entendendo nn.Linear em PyTorch
O que é nn.Linear?
nn.Linear aplica uma transformação afim aos dados de entrada:
[ y = xA^T + b ]
x: tensor de entradaA: matriz de pesos (out_features × in_features)b: vetor de bias (out_features)y: tensor de saída
O módulo recebe dois argumentos obrigatórios:
nn.Linear(in_features, out_features, bias=True)Exemplo:
import torch
from torch import nn
linear_layer = nn.Linear(in_features=3, out_features=1)Isso cria:
- shape dos pesos:
1 × 3 - shape do bias:
1
Como o nn.Linear funciona?
O forward pass realiza:
- multiplicação de matrizes com os pesos
- adição do bias (se estiver habilitado)
output = linear_layer(torch.tensor([1., 2., 3.]))
print(output)📌 Uso moderno: entradas em batch
PyTorch aplica nn.Linear na última dimensão.
Então ele funciona de forma transparente com:
Entradas 2D (batch, features)
[batch, in_features]Entradas 3D (batch, seq_len, features)
Comum em Transformers:
[batch, seq_len, in_features] → [batch, seq_len, out_features]Exemplo:
x = torch.randn(32, 128, 512) # batch=32, seq_len=128, hidden=512
linear = nn.Linear(512, 1024)
y = linear(x) # output: [32, 128, 1024]Não é necessário fazer reshape — isso é importante para projeções de atenção (Q, K, V).
Inicializando pesos e biases
PyTorch inicializa os pesos com Kaiming Uniform por padrão, mas você pode customizar essa etapa usando torch.nn.init.
import torch.nn.init as init
init.xavier_uniform_(linear_layer.weight)
init.zeros_(linear_layer.bias)Quando usar bias=False
Você geralmente desabilita o bias quando:
- camadas de normalização vêm logo após a camada linear
(por exemplo,LayerNormem Transformers) - camadas de projeção Q/K/V em blocos de atenção
- camadas lineares usadas dentro de blocos residuais
Exemplo:
nn.Linear(embed_dim, embed_dim, bias=False)Comparando nn.Linear e nn.Conv2d
| Camada | Melhor uso | Operação | Parâmetros |
|---|---|---|---|
nn.Linear | vetores, entrada achatada, MLPs, Transformers | Transformação afim | in_features, out_features |
nn.Conv2d | dados de imagem, padrões espaciais locais | Convolução com janela móvel | channels, kernel size, stride, padding |
Principais diferenças:
- Camadas lineares usam conexões densas (tudo-para-tudo).
- Conv2D usa kernels espacialmente locais, reduzindo parâmetros e introduzindo viés indutivo.
Aplicações de nn.Linear em Deep Learning
1. Multi-Layer Perceptron (MLP)
layer = nn.Linear(256, 128)2. Modelos Transformer
Um Transformer padrão usa camadas lineares para:
- projeções Q, K, V
- redes feed-forward (MLP)
- heads de saída
Exemplo:
q = nn.Linear(d_model, d_model, bias=False)3. Heads de classificação
classifier = nn.Linear(hidden_dim, num_classes)4. Autoencoders
Tanto o encoder quanto o decoder usam camadas lineares para comprimir e reconstruir.
Usando nn.Linear em um modelo PyTorch
import torch.nn.functional as F
from torch import nn
class Net(nn.Module):
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(10, 5)
self.fc2 = nn.Linear(5, 2)
def forward(self, x):
x = F.relu(self.fc1(x))
return self.fc2(x)
net = Net()Dicas modernas para PyTorch 2.x (Importante!)
🚀 1. torch.compile otimiza camadas lineares
model = torch.compile(net)O compile faz fusão de operações para melhorar o desempenho em GPU.
⚡ 2. AMP para treinamento mais rápido em GPUs
with torch.cuda.amp.autocast():
output = model(x)📉 3. Linear quantizado para inferência
import torch.ao.quantization as quantReduz o tamanho do modelo e acelera a inferência.
🔧 4. Use F.linear para operações customizadas fundidas
import torch.nn.functional as F
y = F.linear(x, weight, bias)Útil ao combinar manualmente múltiplas operações.
Erros comuns e como corrigi-los
1. Incompatibilidade de shapes
RuntimeError: mat1 and mat2 shapes cannot be multipliedCorreção: garanta que a última dimensão da entrada corresponda a in_features.
2. Passar entradas 3D sem deixar nn.Linear lidar com o shape
Evite fazer reshape desnecessariamente — Linear já oferece suporte a tensores 3D.
3. Esquecer de usar .float() ao misturar tensores inteiros
x = x.float()Conclusão
nn.Linear é um módulo aparentemente simples, mas fundamental em PyTorch.
De MLPs a Transformers e modelos generativos modernos, camadas lineares respondem pela maior parte da computação nos sistemas de deep learning atuais.
Entender:
- como elas funcionam
- como lidam com diferentes shapes de entrada
- como otimizá-las com PyTorch 2.x
vai te tornar significativamente mais eficiente ao construir modelos neurais.
FAQs
Qual é o propósito do vetor de bias em nn.Linear?
O bias permite que o modelo desloque os valores de saída independentemente da entrada. Isso aumenta a flexibilidade, especialmente quando as entradas não são centradas em zero.
Como inicializar pesos e biases?
PyTorch inicializa automaticamente, mas você pode sobrescrever usando torch.nn.init (por exemplo, inicialização Xavier).
Qual a diferença entre nn.Linear e nn.Conv2d?
nn.Linear aplica uma transformação afim densa; nn.Conv2d aplica convolução espacial. Camadas lineares são usadas em MLPs e Transformers, enquanto camadas Conv2D são usadas em CNNs.
