본문 바로가기

ML & DL

[DL] Pytorch Lightning ⚡

번개 파이토치에 대해 알아보자~

파이토치 라이트닝은 기존 파이토치에서 학습/추론 및 데이터 로드 부분을 모듈화하여 반복되는 코드를 구현하지 않아도 되는 패키지다. 

기존 파이토치는 모델 클래스, 데이터셋 클래스, 학습 과정을 따로 구현했으나 라이트닝은 한 클래스 안에 전부 구현한다. Gru를 이용한 오토인코더 모델을 예시로 구현해봤다.

import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F

class GruAutoencoder(pl.LightningModule):
    def __init__(self, config, chord_token, batch_size):
        super().__init__()
        self.config = config
        self.chord_token = chord_token
        self.batch_size = batch_size
        self.learning_rate = config["learning_rate"]

        self.enc_emb = nn.Embedding(
            self.config["n_enc_vocab"], self.config["d_hidn"], padding_idx=0
        )
        self.encoder = nn.GRU(
            input_size=self.config["d_hidn"],
            hidden_size=self.config["d_hidn"],
            num_layers=1,
            dropout=config["dropout"],
        )
        self.decoder = nn.Conv1d(
            in_channels=self.config["latent_dim"],
            out_channels=self.config["data_seq"],
            kernel_size=1,)
        
    def forward(self, inputs) -> tensor:
        embedding_inputs = self.enc_emb(inputs)
        embedding_input = embedding_inputs.transpose(0, 1)
        _, output = self.encoder(embedding_input)
        output = self.decoder(output.view(-1, self.config["latent_dim"], self.config["d_hidn"]))
        return output, embedding_inputs
        
    def train_dataloader(self):
        train_dataset = ChordProgressionSet(self.chord_token)
        train_loader = torch.utils.data.DataLoader(train_dataset, self.batch_size, shuffle=True)
        return train_loader
        
    def training_step(self, batch, idx):
        outputs, embedding_input = self(batch)
        loss = F.mse_loss(outputs, embedding_input)
        tensor_board_logs = {"train_loss": loss, "lr": self.learning_rate}
        return {"loss": loss, "log": tensor_board_logs}

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)
        return [optimizer], [scheduler]

클래스 생성자에 하이퍼 파라미터와 사용할 모델을 선언하고, forward 함수에서 연결해준다. train_dataloader 에서 토치 데이터셋을 불러와  Dataloader 형식으로 return 하고, training_step 에서 모델 output과 target의 loss를 계산해 return 한다. 끝으로 configure_optimizers 에서 최적화 방법과 lr_scheduler 등을 설정할 수 있다. 

이후 학습 및 추론을 위한 main 함수만 구현하면 끝!

def train(args, config, chord_token):
    trainer = pytorch_lightning.Trainer(
        gpus=4,
        max_epochs=1500,
        default_root_dir=args.ckpt_dir,
    )
    models = model.GruAutoencoder(config, chord_token, args.batch_size)
    trainer.fit(models)

pl trainer를 선언하고 앞서 구현한 모델을 호출하고 fit 하면 학습이 진행된다. 

결론

1. 반복적인 train 및 inference 코드를 길게 쓰지 않아도 되서 좋다.

2. 모델 구조와 데이터 로더, 학습 방법등을 한 클래스에서 선언해 가독성이 좋고, 추후 모델 수정에도 용이하다.

3. trainer에서 gpu 수, epoch, ckpt 경로 등 기본적인 학습 파라미터를 직관적으로 입력할 수 있다.