본문 바로가기

ML & DL

[Pytorch] Pix2Pix 구현하기

Pytorch Architecture Practice #3

generatice model 중 하나인 pix2pix(Image-to-Image Translation with Conditional Adversarial Networks)를 구현해봅니다. DCGAN과 달라진 점들을 체크하고 구현하면서 알게된 점이나, 궁금한 점 위주로 정리했습니다. 한요섭님의 코드를 참고하여 공부했음을 밝힙니다. 전체 코드는 github에 있습니다.

original paper : https://arxiv.org/pdf/1611.07004.pdf

my github : https://github.com/HyunLee103/Pytorch_practice


Pix2Pix

pix2pix는 conditional GAN에 기반을 둔 image-to-image translation model이다. 즉 입력 이미지를 특정 domain의 이미지로 변환시켜준다. DCGAN과 다른점은 Generator(G)의 인풋이 random vector가 아니라 condition image 라는 점이다. 

출처 : Image-to-Image Translation with Conditional Adversarial Networks

pix2pix는 위 그림처럼 서로 연관된 pair dataset이 필요하다. 이런 pair set은 구하기 어렵고, 이 문제를 해결하기 위해 pair dataset이 없이도 domain transfer를 할 수 있는 CycleGAN이 등장한다.

DCGAN과 다르게 random vector z 대신 condition input x가 G에 들어가 이미지 G(x)를 생성한다. 이후 Dicriminator(D)에 input x와 output G(x)가 같이 들어간다. 목표하는 domain에 target label y가 있고 y는 x와 함께 D에 들어간다. D는 생성된 이미지 (G(x), x) 를 가짜로 label 이미지 (y, x)를 진짜로 잘 분류하게 학습한다. 수식으로 나타내면 다음과 같다.

$L_{cGAN}(G,D) = E_{x,y}[\log D(x,y)]+E_{x,z}[\log (1-D(x,G(x)))]$

cGAN loss의 이해를 위해 도식화를 해봤다.

CrossEntropy for Conditional GAN Loss

여기에 추가로 생성된 이미지 G(x)와 label y 사이의 유클리디안 거리를 반영하는 L1 Loss를 사용한다. 수식은 다음과 같다.

$L_{L1}(G) = E_{x,y}\left \| y-G(x) \right \|_1$

L1 loss에 람다만큼 weight를 주고 두 loss를 더한 최종 object function은 다음과 같다. 

$G = \arg min_Gmax_D L_{GAN}(G,D) + \lambda L_{L1}(G)$

L1 loss로 image의 low-frequency content를 학습하고 GAN loss로 high-frequency content를 학습하여 L1 loss만 사용했을 때 보다 좀 더 sharp하고 진짜같은 이미지를 생성할 수 있게된다. 위 함수들에서 random vector z가 빠진 이유는 다음과 같이 z 대신 dropout을 사용했기 때문이다.

' Instead, for our final models, we provide noise only in the form of dropout, applied on several layers of our generator at both training and test time ' - paper_p.3

Generator

pix2pix의 G는 encoder-decoder 구조를 가진다. 특히 encoder step의 feature map을 skip connection을 통해 연결해 decoding하는 U-Net 구조가 성능이 좋았다. 

Discriminator

D는 PatchGAN을 사용한다. DCGAN에서는 D에서 이미지 전체 영역을 보고 진짜/가짜 판별을 했다면, PatchGAN은 특정 크기의 patch 단위로 이미지의 진위 여부를 판별한다.  

patch는 위 그림에서 파란 점선 부분이 되고 이를 receptive field라고 정의한다. 앞서 언급했듯이 D에는 256 x 256 크기의 입력이미지 x와 동일한 크기의 생성된 fake image G(x)가 concat해서 들어간다. 1/2 downsampling Conv layer를 5번 거쳐 16 x 16 feature map을 뽑아내고 이 값을 sigmoid에 통과시켜 최종적으로 1-D scalar 값을 뽑아낸다. 이 때, 최종 feature map 한 픽셀에 해당하는 인풋 이미지 size가 receptive field가 되고, pix2pix에서는 70 x 70을 사용한다.

Normalization

이 논문에서는 batch normalization 대신 사용가능한 instance normalization을 시도해본다. 

찾아보니 norm 방법이 다양하게 있었고, 'Batch-Instance Normalization for Adaptively Style-Invariant Neural Networks' 라는 논문을 통해 얻은 간단한 insight만 정리해보았다. 

IN has been widely adopted as an alternative to Batch Normalization (BN) in style transfer [15, 4] and generative adversarial networks (GANs) [14, 30]. It is a reasonable assumption that IN would be beneficial not only in generative tasks but also in discriminative tasks for addressing unnecessary style variations. However, directly applying IN to a classification problem degrades the performance [25], probably because styles often serve as useful discriminative features for classification. Note that IN dilutes the information carried by the global statistics of feature responses while leaving their spatial configuration only, which can be undesirable depending on the task at hand and the information encoded by a feature map.

Instance normalization은 content 이미지의 feature 통계량을 직접 normalize 하므로 style variation을 제거하는 효과를 가져와, style transfer에서 BN 대신 사용된다. 하지만 이러한 IN을 object detection등 다른 vision task에 적용하기는 어렵다. 따라서 둘을 섞은 Batch-Instance Normalization을 제안하는게 이 논문의 골자이다.


Implement

Framework은 이전 DCGAN 구현과 거의 동일하므로 주요한 코드만 살펴보겠다.

model.py

generator와 discriminator를 구현한 파일이다.

class Pix2Pix_generator(nn.Module):
    def __init__(self, in_channels,out_channels,nker=64,norm='bnorm'):
        super(Pix2Pix, self).__init__()

        # encoder
        # Leaky relu 사용, 첫번째 encoder는 batchnorm X
        self.enc1 = CBR2d(in_channels,1*nker,kernel_size=4, padding=1,stride=2,
        norm = None,relu=0.2)
        self.enc2 = CBR2d(1*nker,2*nker,kernel_size=4, padding=1,stride=2,
        norm = norm ,relu=0.2)
def forward(self, x):
        enc1 = self.enc1(x)
        enc2 = self.enc2(enc1)
        enc3 = self.enc3(enc2)
        enc4 = self.enc4(enc3)
        enc5 = self.enc5(enc4)
        enc6 = self.enc6(enc5)
        enc7 = self.enc7(enc6)
        enc8 = self.enc8(enc7)

        dec1 = self.dec1(enc8)
        drop1 = self.drop1(dec1)

        cat2 = torch.cat((drop1,enc7),dim=1)
        dec2 = self.dec2(cat2)
        drop2 = self.drop2(dec2)

 

generator class에 forward method에서 skip-connection을 해줘야 하므로 torch.cat 을 활용해 채널 방향으로 feature map을 합쳐주었다. 

 

layer.py

여기에는 반복해서 사용되는 네트워크 layer들을 연결해 하나의 class로 선언했다. 

lass CBR2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=True, norm="bnorm", relu=0.0):
        super().__init__()

        layers = []
        layers += [nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                             kernel_size=kernel_size, stride=stride, padding=padding,
                             bias=bias)]

        if not norm is None:
            if norm == "bnorm":
                layers += [nn.BatchNorm2d(num_features=out_channels)]
            elif norm == "inorm":
                layers += [nn.InstanceNorm2d(num_features=out_channels)]

        if not relu is None and relu >= 0.0:
            layers += [nn.ReLU() if relu == 0 else nn.LeakyReLU(relu)]

        self.cbr = nn.Sequential(*layers)

    def forward(self, x):
        return self.cbr(x)

인코더 파트에 사용되는 Conv-batch-Relu을 합친 CBR2d class, norm과 relu 파라미터를 통해 BN/IN, Relu/LeakyRelu를 선택할 수 있게 구현했다.

 

dataset.py

 

training에 사용할 데이터가 위와 같이 생겼기 때문에, label-image와 input-image를 지정해 분리해서 data를 불러야 한다. 다음과 같이 Dataset class의 __getitem__ method를 이용했다.

 

class Dataset(torch.utils.data.Dataset):
    def __init__(self, data_dir, transform=None, task=None, opts=None):
        self.data_dir = data_dir
        self.transform = transform
        self.task = task
        self.opts = opts
        self.to_tensor = ToTensor()

        lst_data = os.listdir(self.data_dir)
        lst_data = [f for f in lst_data if f.endswith('jpg') | f.endswith('jpeg') | f.endswith('png')]

        lst_data.sort()
        self.lst_data = lst_data

    def __len__(self):
        return len(self.lst_data)

    def __getitem__(self, index):  # iterator 만들기
        img = plt.imread(os.path.join(self.data_dir, self.lst_data[index]))
        sz = img.shape

        if img.ndim == 2:
            img = img[:, :, np.newaxis]

        if img.dtype == np.uint8:
            img = img / 255.0

        # 이미지의 label(y), input(x)을 결정해, 학습 방향을 정하는 옵션
        if self.opts[0] == 'direction':
            if self.opts[1] == 0: # label : left, input : right
                data = {'label':img[:, :sz[1]//2, :] ,'input':img[: , sz[1]//2:, :]}
            elif self.opts[1] == 1: # 반대
                data = {'label':img[:, sz[1]//2: ,:], 'input':img[:, :sz[1]//2,:]}
        else:
            data = {'label': img}

        if self.transform:
            data = self.transform(data)

        data = self.to_tensor(data)

        return data

 

self.opts라는 파라미터를 통해 raw image에서 어떤 부분을 label로, input으로 할 지 결정해서 인덱싱 해준다. default는 opts[1] == 0 으로 위 이미지의 왼쪽 부분이 label(y), 오른쪽 부분이 G의 인풋 x가 된다.

train.py

# 데이터 부르기
    if mode == 'train':
        transform_train = transforms.Compose([Resize(shape=(286,286,nch)), RandomCrop((ny,nx))
        , Normalization(mean=0.5,std=0.5)]) # jitter technic이 사용됨(data augumentation)
                                            # Random jitter was applied by resizing the 256×256 input
                                            # images to 286 × 286 and then randomly cropping back to size 256 × 256.
        dataset_train = Dataset(data_dir=os.path.join(data_dir,'train'), transform=transform_train, task=task, opts=opts)
        loader_train = DataLoader(dataset_train, batch_size=batch_size, shuffle=True, num_workers=8)

 

논문을 보면, jittering을 사용해 데이터 augumentation 효과를 준다. jitter는 256 x 256 인풋 이미지를 286 x 286으로 늘린뒤 다시 256 x 256으로 RandomCrop하는 테크닉이다. DataLoader를 통해 batch_size 만큼 데이터를 yield하는 generator(loader_train)를 만들어준다.

 

# TRAIN MODE
    if mode == 'train':
        if train_continue == "on": # 이어서 학습할때 on, 처음 학습 시킬때는 off
            netG,netD, optimG,optimD, st_epoch = load(ckpt_dir=ckpt_dir, netG=netG,netD=netD, optimG=optimG,optimD=optimD)

        for epoch in range(st_epoch + 1, num_epoch + 1):
            netG.train() # 사용할 네트워크를 train 모드로 설정
            netD.train() # 사용할 네트워크를 train 모드로 설정

            for batch, data in enumerate(loader_train, 1): # loader_train에는 input, label pair data.
                # forward pass
                label = data['label'].to(device) # label : 실제이미지 data, device로 올리기 > D의 real 파트 인풋 y
                input = data['input'].to(device) # input : G의 condition 입력 이미지 data, x 
                output = netG(input) # G(x)

                # backward pass
                # backprop도 두 네트워크 각각 해줘야 한다.

                # Discriminator backprop
                set_requires_grad(netD,True) # netD(Discriminator)의 모든 파라미터 연산을 추적해 gradient를 계산한다.
                optimD.zero_grad() # gradient 초기화

                real = torch.cat([input,label], dim=1)
                fake = torch.cat([output,input],dim=1)

                pred_real = netD(real)  # D(x,y)
                pred_fake = netD(fake.detach()) # D(G(x),x) , detach는 D의 gradient가 G까지 흘러가지 않게 하기위해서

                loss_D_real = loss_gan(pred_real,torch.ones_like(pred_real)) 
                loss_D_fake = loss_gan(pred_fake,torch.zeros_like(pred_fake))
                loss_D = 0.5 * (loss_D_fake + loss_D_real)

                loss_D.backward() # gradient backprop
                optimD.step() # gradient update

                # Generator backprop
                set_requires_grad(netD,False) # generator를 학습할땐, discriminator는 고정한다. 따라서 required_grad = False
                optimG.zero_grad() # gradient 초기화

                fake = torch.cat([input,output],dim = 1) 
                pred_fake = netD(fake) # D(G(x),x)

                loss_G_gan = loss_gan(pred_fake,torch.ones_like(pred_fake))  # generator는 생성한 가짜 이미지가 진짜(1)로 분류되게 학습해야한다.
                loss_G_L1 = loss_L1(output, label) # G(x), y의 L1 loss , 생성된 이미지와 real-image 사이의 유클리디안 거리
                loss_G = loss_G_gan + wgt * loss_G_L1

                loss_G.backward() # gradient backprop
                optimG.step() # gradient update

 

Result

실제 학습은 colab에서 GPU를 사용해서 진행했다. 논문에는 300 epoch를 기준으로 설명하고 있는데 일단 20 epoch 까지만 돌리고 매 10번 batch마다 output을 저장해보았다. 

학습이 진행되면서 조금 나아지는 것 같지만 GAN loss쪽이 제대로 학습되지 않은 듯 하다. 어쩐지 L1 loss만 뚝뚝 떨어지더라..ㅠㅠ 아니면 epoch을 논문처럼 300번 돌려봐야되는건지 째든 결과물을 확인하니까 기분이 좋다! 


reference

Batch-Instance Normalization for Adaptively Style-Invariant Neural Network(2018) ,https://arxiv.org/pdf/1805.07925.pdf 

https://wewinserv.tistory.com/71 

 

Pix2Pix

Pix2Pix Image-to-Image Translation with Conditional Adversarial Networks Isola, Phillip, et al. "Image-to-image translation with conditional adversarial networks." arXiv preprint (2017). cycleGAN,..

wewinserv.tistory.com

https://brstar96.github.io/mldlstudy/what-is-patchgan-D/

 

(NN Methodology) PatchGAN Discriminator 뽀개기

본 글은 개인적으로 스터디하며 정리한 자료입니다. 간혹 레퍼런스를 찾지 못해 빈 곳이 있으므로 양해 부탁드립니다.

brstar96.github.io

https://github.com/hanyoseob/pytorch-pix2pix

 

hanyoseob/pytorch-pix2pix

Image-to-Image Translation with Conditional Adversarial Networks - hanyoseob/pytorch-pix2pix

github.com