Introducción: nn.Linear en PyTorch, explicado claramente
Updated on
nn.Linear es uno de los bloques de construcción más fundamentales en PyTorch.
Tanto si estás construyendo un Perceptrón Multicapa (MLP) sencillo, un Transformer o un sistema de deep learning a gran escala, las capas lineales están en todas partes, realizando transformaciones afines rápidas sobre tus datos.
Con PyTorch 2.x, las capas lineales se benefician de kernels actualizados, optimizaciones del compilador (torch.compile) y una mejor aceleración en GPU. Esta guía actualizada cubre todo lo que necesitas saber, desde los conceptos básicos hasta el uso en deep learning moderno.
¿Quieres crear rápidamente visualizaciones de datos desde un DataFrame de Python Pandas sin código?
PyGWalker es una librería de Python para análisis exploratorio de datos con visualización.
Convierte tu DataFrame de pandas o polars en una interfaz similar a Tableau dentro de Jupyter Notebook.
Entendiendo nn.Linear en PyTorch
¿Qué es nn.Linear?
nn.Linear aplica una transformación afín a los datos de entrada:
[ y = xA^T + b ]
x: tensor de entradaA: matriz de pesos (out_features × in_features)b: vector de sesgo (out_features)y: tensor de salida
El módulo recibe dos argumentos obligatorios:
nn.Linear(in_features, out_features, bias=True)Ejemplo:
import torch
from torch import nn
linear_layer = nn.Linear(in_features=3, out_features=1)Esto crea:
- forma de
weight:1 × 3 - forma de
bias:1
¿Cómo funciona nn.Linear?
En el paso hacia adelante realiza:
- multiplicación matricial con los pesos
- suma del sesgo (si está habilitado)
output = linear_layer(torch.tensor([1., 2., 3.]))
print(output)📌 Uso moderno: entradas con batch
PyTorch aplica nn.Linear a la última dimensión.
Así que funciona de manera transparente con:
Entradas 2D (batch, features)
[batch, in_features]Entradas 3D (batch, seq_len, features)
Común en Transformers:
[batch, seq_len, in_features] → [batch, seq_len, out_features]Ejemplo:
x = torch.randn(32, 128, 512) # batch=32, seq_len=128, hidden=512
linear = nn.Linear(512, 1024)
y = linear(x) # output: [32, 128, 1024]No hace falta hacer reshape — esto es importante para las proyecciones de atención (Q, K, V).
Inicialización de pesos y sesgos
PyTorch inicializa los pesos usando Kaiming Uniform por defecto, pero puedes personalizar esto usando torch.nn.init.
import torch.nn.init as init
init.xavier_uniform_(linear_layer.weight)
init.zeros_(linear_layer.bias)Cuándo usar bias=False
A menudo se desactiva el sesgo cuando:
- se usan capas de normalización después de la capa lineal
(por ejemplo,LayerNormen Transformers) - capas de proyección Q/K/V en bloques de atención
- capas lineales usadas dentro de bloques residuales
Ejemplo:
nn.Linear(embed_dim, embed_dim, bias=False)Comparando nn.Linear y nn.Conv2d
| Capa | Mejor para | Operación | Parámetros |
|---|---|---|---|
nn.Linear | vectores, entrada aplanada, MLPs, Transformers | Transformación afín | in_features, out_features |
nn.Conv2d | datos de imagen, patrones espaciales locales | Convolución de ventana deslizante | channels, kernel size, stride, padding |
Diferencias clave:
- Las capas lineales usan conexiones densas (todos-contra-todos).
- Conv2D usa kernels espaciales locales, reduciendo parámetros e introduciendo sesgos inductivos.
Aplicaciones de nn.Linear en Deep Learning
1. Perceptrón Multicapa (MLP)
layer = nn.Linear(256, 128)2. Modelos Transformer
Un Transformer estándar usa capas lineales para:
- proyecciones Q, K, V
- redes feed-forward tipo MLP
- cabezas de salida
Ejemplo:
q = nn.Linear(d_model, d_model, bias=False)3. Cabeceras de clasificación
classifier = nn.Linear(hidden_dim, num_classes)4. Autoencoders
Tanto el encoder como el decoder usan capas lineales para comprimir y reconstruir.
Uso de nn.Linear en un modelo de PyTorch
import torch.nn.functional as F
from torch import nn
class Net(nn.Module):
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(10, 5)
self.fc2 = nn.Linear(5, 2)
def forward(self, x):
x = F.relu(self.fc1(x))
return self.fc2(x)
net = Net()Consejos modernos para PyTorch 2.x (¡Importante!)
🚀 1. torch.compile optimiza las capas lineales
model = torch.compile(net)Compilar fusiona operaciones para un rendimiento más rápido en GPU.
⚡ 2. AMP para entrenamiento más rápido en GPUs
with torch.cuda.amp.autocast():
output = model(x)📉 3. Linear cuantizado para inferencia
import torch.ao.quantization as quantReduce el tamaño del modelo y acelera la inferencia.
🔧 4. Usa F.linear para operaciones fusionadas personalizadas
import torch.nn.functional as F
y = F.linear(x, weight, bias)Útil cuando combinas manualmente varias operaciones.
Errores comunes y cómo solucionarlos
1. Desajuste de formas
RuntimeError: mat1 and mat2 shapes cannot be multipliedSolución: asegúrate de que la última dimensión de la entrada coincide con in_features.
2. Pasar entradas 3D sin dejar que nn.Linear gestione la forma
Evita hacer reshape innecesariamente — Linear ya soporta entradas 3D.
3. Olvidar .float() al mezclar tensores enteros
x = x.float()Conclusión
nn.Linear es un módulo aparentemente simple pero fundamental en PyTorch.
Desde MLPs hasta Transformers y modelos generativos modernos, las capas lineales constituyen la mayor parte del cómputo en los sistemas de deep learning actuales.
Entender:
- cómo funcionan
- cómo manejan distintas formas de entrada
- cómo optimizarlas con PyTorch 2.x
te hará mucho más eficaz a la hora de construir modelos neuronales.
Preguntas frecuentes (FAQs)
¿Cuál es el propósito del vector de sesgo en nn.Linear?
El sesgo permite al modelo desplazar los valores de salida independientemente de la entrada. Esto mejora la flexibilidad, especialmente cuando las entradas no están centradas en cero.
¿Cómo se inicializan pesos y sesgos?
PyTorch los inicializa automáticamente, pero puedes sobrescribir esta inicialización usando torch.nn.init (por ejemplo, inicialización Xavier).
¿Diferencia entre nn.Linear y nn.Conv2d?
nn.Linear aplica una transformación afín densa; nn.Conv2d aplica una convolución espacial. Las capas lineales se usan en MLPs y Transformers, mientras que las capas Conv2D se usan en CNNs.
