번개 파이토치에 대해 알아보자~
파이토치 라이트닝은 기존 파이토치에서 학습/추론 및 데이터 로드 부분을 모듈화하여 반복되는 코드를 구현하지 않아도 되는 패키지다.
기존 파이토치는 모델 클래스, 데이터셋 클래스, 학습 과정을 따로 구현했으나 라이트닝은 한 클래스 안에 전부 구현한다. Gru를 이용한 오토인코더 모델을 예시로 구현해봤다.
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
class GruAutoencoder(pl.LightningModule):
def __init__(self, config, chord_token, batch_size):
super().__init__()
self.config = config
self.chord_token = chord_token
self.batch_size = batch_size
self.learning_rate = config["learning_rate"]
self.enc_emb = nn.Embedding(
self.config["n_enc_vocab"], self.config["d_hidn"], padding_idx=0
)
self.encoder = nn.GRU(
input_size=self.config["d_hidn"],
hidden_size=self.config["d_hidn"],
num_layers=1,
dropout=config["dropout"],
)
self.decoder = nn.Conv1d(
in_channels=self.config["latent_dim"],
out_channels=self.config["data_seq"],
kernel_size=1,)
def forward(self, inputs) -> tensor:
embedding_inputs = self.enc_emb(inputs)
embedding_input = embedding_inputs.transpose(0, 1)
_, output = self.encoder(embedding_input)
output = self.decoder(output.view(-1, self.config["latent_dim"], self.config["d_hidn"]))
return output, embedding_inputs
def train_dataloader(self):
train_dataset = ChordProgressionSet(self.chord_token)
train_loader = torch.utils.data.DataLoader(train_dataset, self.batch_size, shuffle=True)
return train_loader
def training_step(self, batch, idx):
outputs, embedding_input = self(batch)
loss = F.mse_loss(outputs, embedding_input)
tensor_board_logs = {"train_loss": loss, "lr": self.learning_rate}
return {"loss": loss, "log": tensor_board_logs}
def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)
return [optimizer], [scheduler]
클래스 생성자에 하이퍼 파라미터와 사용할 모델을 선언하고, forward 함수에서 연결해준다. train_dataloader 에서 토치 데이터셋을 불러와 Dataloader 형식으로 return 하고, training_step 에서 모델 output과 target의 loss를 계산해 return 한다. 끝으로 configure_optimizers 에서 최적화 방법과 lr_scheduler 등을 설정할 수 있다.
이후 학습 및 추론을 위한 main 함수만 구현하면 끝!
def train(args, config, chord_token):
trainer = pytorch_lightning.Trainer(
gpus=4,
max_epochs=1500,
default_root_dir=args.ckpt_dir,
)
models = model.GruAutoencoder(config, chord_token, args.batch_size)
trainer.fit(models)
pl trainer를 선언하고 앞서 구현한 모델을 호출하고 fit 하면 학습이 진행된다.
결론
1. 반복적인 train 및 inference 코드를 길게 쓰지 않아도 되서 좋다.
2. 모델 구조와 데이터 로더, 학습 방법등을 한 클래스에서 선언해 가독성이 좋고, 추후 모델 수정에도 용이하다.
3. trainer에서 gpu 수, epoch, ckpt 경로 등 기본적인 학습 파라미터를 직관적으로 입력할 수 있다.
'ML & DL' 카테고리의 다른 글
Hugging Face 기초 😇 (2) | 2021.04.23 |
---|---|
[NLP] ELMo, Contextual Word Representation (0) | 2021.02.10 |
[NLP] Language Model, Seq2Seq, Attention (0) | 2021.02.10 |
[Deep learning] 4장 수치 계산(Numerical Computation) (0) | 2021.01.18 |
[Deep learning] 8장 최적화 ① (1) | 2021.01.14 |