Skip to content

Introduction: nn.Linear in PyTorch, Clearly Explained

Updated on

nn.Linear는 PyTorch에서 가장 기본이 되는 블록 중 하나입니다.
간단한 Multi-Layer Perceptron(MLP), Transformer, 대규모 딥러닝 시스템까지 무엇을 만들든, 선형 계층은 데이터에 빠른 affine 변환을 수행하는 핵심 요소로 거의 모든 곳에 사용됩니다.

PyTorch 2.x에서는 linear layer가 업데이트된 커널, 컴파일러 최적화(torch.compile), 더 나은 GPU 가속의 이점을 누립니다. 이 업데이트된 가이드는 기초부터 현대 딥러닝에서의 활용까지, 알아야 할 모든 내용을 다룹니다.

Python Pandas DataFrame으로부터 노코드로 빠르게 데이터 시각화를 만들고 싶나요?

PyGWalker는 시각화를 활용한 탐색적 데이터 분석(EDA)을 위한 Python 라이브러리입니다.
pandas 또는 polars DataFrame을 Jupyter Notebook 안에서 Tableau와 비슷한 UI로 바꿔 줍니다.

👉 PyGWalker on GitHub (opens in a new tab)

PyGWalker for Data visualization (opens in a new tab)


Understanding nn.Linear in PyTorch

What is nn.Linear?

nn.Linear는 입력 데이터에 affine 변환을 적용합니다:

[ y = xA^T + b ]

  • x: 입력 텐서
  • A: 가중치 행렬 (out_features × in_features)
  • b: bias 벡터 (out_features)
  • y: 출력 텐서

이 모듈은 두 개의 필수 인자를 받습니다:

nn.Linear(in_features, out_features, bias=True)

예시:

import torch
from torch import nn
 
linear_layer = nn.Linear(in_features=3, out_features=1)

이렇게 생성됩니다:

  • weight shape: 1 × 3
  • bias shape: 1

How Does nn.Linear Work?

forward pass에서는 다음을 수행합니다:

  • weight와의 행렬 곱셈
  • bias 덧셈(활성화된 경우)
output = linear_layer(torch.tensor([1., 2., 3.]))
print(output)

📌 Modern Usage: Batched Inputs

PyTorch의 nn.Linear마지막 차원에 대해 연산을 적용합니다.
그래서 다음과 같은 배치 형태에서 별도 처리 없이 자연스럽게 동작합니다.

2D inputs (batch, features)

[batch, in_features]

3D inputs (batch, seq_len, features)

Transformer에서 흔히 사용되는 형태:

[batch, seq_len, in_features] → [batch, seq_len, out_features]

예시:

x = torch.randn(32, 128, 512)   # batch=32, seq_len=128, hidden=512
linear = nn.Linear(512, 1024)
y = linear(x)                   # output: [32, 128, 1024]

별도의 reshape가 필요 없습니다 — 이는 attention에서 Q, K, V projection을 구현할 때 특히 중요합니다.


Initializing Weights and Biases

PyTorch는 기본적으로 Kaiming Uniform 초기화를 사용하지만, torch.nn.init을 이용해 자유롭게 커스터마이즈할 수 있습니다.

import torch.nn.init as init
 
init.xavier_uniform_(linear_layer.weight)
init.zeros_(linear_layer.bias)

When to use bias=False

다음과 같은 경우에는 bias를 끄는 일이 자주 있습니다:

  • linear layer 뒤에 normalization layer가 바로 따라올 때
    (예: Transformer의 LayerNorm)
  • attention block의 Q/K/V projection layer
  • residual block 내부의 linear layer

예시:

nn.Linear(embed_dim, embed_dim, bias=False)

Comparing nn.Linear and nn.Conv2d

LayerBest forOperationParameters
nn.Linear벡터, 펼친(flattened) 입력, MLP, TransformerAffine transformin_features, out_features
nn.Conv2d이미지 데이터, 국소 공간 패턴Sliding window convolutionchannels, kernel size, stride, padding

핵심 차이점:

  • Linear layer는 dense connection(모든 입력-출력 간 완전 연결)을 사용합니다.
  • Conv2D는 공간적으로 국소적인 커널을 사용해 파라미터 수를 줄이고, 강한 inductive bias(공간 구조에 대한 사전 가정)를 제공합니다.

Applications of nn.Linear in Deep Learning

1. Multi-Layer Perceptron (MLP)

layer = nn.Linear(256, 128)

2. Transformer Models

표준 Transformer에서 linear layer는 다음에 사용됩니다:

  • Q, K, V projection
  • MLP feed-forward network
  • 출력 head

예시:

q = nn.Linear(d_model, d_model, bias=False)

3. Classification Heads

classifier = nn.Linear(hidden_dim, num_classes)

4. Autoencoders

Encoder와 decoder 모두에서 linear layer를 사용해 정보를 압축하고 복원합니다.


Using nn.Linear in a PyTorch Model

import torch.nn.functional as F
from torch import nn
 
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 5)
        self.fc2 = nn.Linear(5, 2)
 
    def forward(self, x):
        x = F.relu(self.fc1(x))
        return self.fc2(x)
 
net = Net()

Modern PyTorch 2.x Tips (Important!)

🚀 1. torch.compile로 linear layer 최적화

model = torch.compile(net)

컴파일을 통해 연산을 fusion하여 GPU 성능을 높일 수 있습니다.


⚡ 2. GPU에서 더 빠른 학습을 위한 AMP

with torch.cuda.amp.autocast():
    output = model(x)

📉 3. 추론용 Quantized Linear

import torch.ao.quantization as quant

모델 크기를 줄이고 추론 속도를 높이는 데 유용합니다.


🔧 4. 커스텀 fused 연산에는 F.linear 사용

import torch.nn.functional as F
y = F.linear(x, weight, bias)

여러 연산을 수동으로 결합해 커스텀 연산을 만들 때 유용합니다.


Common Errors and How to Fix Them

1. Shape Mismatch

RuntimeError: mat1 and mat2 shapes cannot be multiplied

해결: 입력의 마지막 차원이 in_features와 일치하는지 확인합니다.

2. 3D 입력을 nn.Linear가 처리하도록 두지 않고 reshape하는 경우

불필요한 reshape는 피하세요 — nn.Linear는 이미 3D 입력을 지원합니다.

3. 정수 텐서를 섞어 쓰면서 .float()를 잊는 경우

x = x.float()

Conclusion

nn.Linear는 겉보기에는 단순하지만, PyTorch에서 매우 기본적이고 중요한 모듈입니다.
MLP에서 Transformer, 최신 생성 모델에 이르기까지, 오늘날의 딥러닝 시스템에서 대부분의 연산량은 linear layer에서 발생한다고 해도 과언이 아닙니다.

다음 내용을 이해하면:

  • 동작 원리
  • 다양한 입력 shape 처리 방식
  • PyTorch 2.x에서의 최적화 방법

신경망 모델을 설계하고 디버깅하는 데 훨씬 더 효율적으로 대응할 수 있습니다.


FAQs

What is the purpose of a bias vector in nn.Linear?

bias는 입력과 무관하게 출력 값을 평행 이동시킬 수 있게 해 줍니다.
특히 입력 데이터가 0을 중심으로 분포하지 않을 때, 모델의 표현력을 높이는 데 도움이 됩니다.

How do you initialize weights and biases?

PyTorch가 자동으로 초기화를 수행하지만, 필요하다면 torch.nn.init(예: Xavier initialization)을 사용해 원하는 방식으로 재초기화할 수 있습니다.

Difference between nn.Linear and nn.Conv2d?

nn.Linear는 dense한 affine 변환을 적용하며, MLP나 Transformer에 주로 사용됩니다.
nn.Conv2d는 공간적인 convolution을 적용하며, 이미지 데이터를 다루는 CNN에서 주로 사용됩니다.