はじめに: nn.Linear をわかりやすく理解する
Updated on
nn.Linear は PyTorch で最も基本的なブロックの1つです。
シンプルな Multi-Layer Perceptron (MLP)、Transformer、あるいは大規模なディープラーニングシステムを構築する場合でも、線形層はあらゆる場所で利用され、データに対して高速なアフィン変換を行っています。
PyTorch 2.x では、線形層は更新されたカーネルやコンパイラ最適化(torch.compile)、より良い GPU アクセラレーションの恩恵を受けています。本稿では、基礎から最新のディープラーニングでの使い方まで、知っておくべき内容を網羅的に解説します。
Python の Pandas DataFrame から ノーコード で素早くデータ可視化を行いたいですか?
PyGWalker は、可視化付き Exploratory Data Analysis を行うための Python ライブラリです。
pandas や polars の DataFrame を、Jupyter Notebook 内で Tableau ライクな UI に変換してくれます。
PyTorch における nn.Linear を理解する
nn.Linear とは?
nn.Linear は入力データにアフィン変換を適用します:
[ y = xA^T + b ]
x: 入力テンソルA: 重み行列(out_features × in_features)b: バイアスベクトル(out_features)y: 出力テンソル
このモジュールは2つの必須引数を取ります:
nn.Linear(in_features, out_features, bias=True)例:
import torch
from torch import nn
linear_layer = nn.Linear(in_features=3, out_features=1)これにより次のパラメータが作られます:
- weight の形状:
1 × 3 - bias の形状:
1
nn.Linear はどう動くのか?
forward パスでは以下を行います:
- 重みとの行列積
- バイアスの加算(有効な場合)
output = linear_layer(torch.tensor([1., 2., 3.]))
print(output)📌 現代的な使い方: バッチ入力
PyTorch は nn.Linear を「最後の次元」に対して適用します。
そのため、以下のような形状に対してシームレスに動作します。
2次元入力 (batch, features)
[batch, in_features]3次元入力 (batch, seq_len, features)
Transformer でよく使われる形です:
[batch, seq_len, in_features] → [batch, seq_len, out_features]例:
x = torch.randn(32, 128, 512) # batch=32, seq_len=128, hidden=512
linear = nn.Linear(512, 1024)
y = linear(x) # output: [32, 128, 1024]reshape は不要です — これは Attention の Q, K, V の射影において特に重要です。
重みとバイアスの初期化
PyTorch はデフォルトで Kaiming Uniform による初期化を行いますが、torch.nn.init を使ってカスタマイズできます。
import torch.nn.init as init
init.xavier_uniform_(linear_layer.weight)
init.zeros_(linear_layer.bias)bias=False を使うべき場面
次のような場合にはバイアスをオフにすることがよくあります:
- 線形層の直後に正規化層が続く場合
(例: Transformer のLayerNorm) - Attention ブロックにおける Q/K/V の射影層
- residual block 内で使われる線形層
例:
nn.Linear(embed_dim, embed_dim, bias=False)nn.Linear と nn.Conv2d の比較
| Layer | 適した用途 | 演算内容 | 主なパラメータ |
|---|---|---|---|
nn.Linear | ベクトル、flatten 済み入力、MLP、Transformer | アフィン変換 | in_features, out_features |
nn.Conv2d | 画像データ、局所的な空間パターン | スライディングウィンドウ畳み込み | channels, kernel size, stride, padding |
主な違い:
- Linear 層は 全結合 (dense connections) を用います(全ニューロン同士が接続)。
- Conv2d は 空間的に局所なカーネル を使うことで、パラメータ数を削減しつつ、画像のようなデータに対する帰納的バイアスを導入します。
nn.Linear のディープラーニングでの用途
1. Multi-Layer Perceptron (MLP)
layer = nn.Linear(256, 128)2. Transformer モデル
標準的な Transformer では、線形層は次の用途で使われます:
- Q, K, V の射影
- MLP による feed-forward ネットワーク
- 出力ヘッド
例:
q = nn.Linear(d_model, d_model, bias=False)3. 分類ヘッド
classifier = nn.Linear(hidden_dim, num_classes)4. Autoencoder
Encoder と Decoder のどちらも、線形層を用いて情報の圧縮と再構成を行います。
PyTorch モデルでの nn.Linear の使い方
import torch.nn.functional as F
from torch import nn
class Net(nn.Module):
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(10, 5)
self.fc2 = nn.Linear(5, 2)
def forward(self, x):
x = F.relu(self.fc1(x))
return self.fc2(x)
net = Net()現代的な PyTorch 2.x の Tips(重要)
🚀 1. torch.compile で線形層を最適化する
model = torch.compile(net)compile により演算が fuse され、GPU 上で高速化されます。
⚡ 2. AMP で GPU 訓練を高速化する
with torch.cuda.amp.autocast():
output = model(x)📉 3. 推論では Quantized Linear を使う
import torch.ao.quantization as quantモデルサイズを削減し、推論を高速化できます。
🔧 4. カスタムの fused operation には F.linear を使う
import torch.nn.functional as F
y = F.linear(x, weight, bias)複数の演算を手動で組み合わせるときに便利です。
よくあるエラーと対処法
1. 形状の不一致 (Shape Mismatch)
RuntimeError: mat1 and mat2 shapes cannot be multiplied対処: 入力テンソルの最後の次元が in_features と一致しているか確認します。
2. 3次元入力に対して自分で reshape してしまう
不要な reshape は避けてください — nn.Linear はもともと 3D 入力をサポートしています。
3. 整数テンソルをそのまま渡してしまう(.float() を忘れる)
x = x.float()まとめ
nn.Linear は一見シンプルですが、PyTorch における基礎的かつ非常に重要なモジュールです。
MLP から Transformer、最新の生成モデルに至るまで、今日のディープラーニングシステムでは計算の大部分が線形層によって占められています。
次の点を理解しておくと、ニューラルネットワークを設計・実装する際に大きな力になります:
- どのように動作するか
- さまざまな入力形状をどう扱うか
- PyTorch 2.x の機能でどのように最適化できるか
FAQs
nn.Linear におけるバイアスベクトルの役割は何ですか?
バイアスは、入力に依存しない形で出力値をシフトさせることを可能にします。
これにより、特に入力データがゼロ中心でない場合に、モデルの表現力が向上します。
重みとバイアスはどのように初期化しますか?
PyTorch が自動的に初期化してくれますが、必要に応じて torch.nn.init(例: Xavier 初期化)を用いて上書きすることができます。
nn.Linear と nn.Conv2d の違いは何ですか?
nn.Linear は密なアフィン変換を適用し、MLP や Transformer で使われます。
nn.Conv2d は空間方向の畳み込みを行い、主に画像データを扱う CNN で用いられます。
