-
[GAN]Generative Adversarial Networks(NIPS 2016)Artificial Intelligence/Deep learning 2021. 9. 10. 13:31
딥러닝의 3대 석학이라고 불리는 얀 르쿤(Yann LeCun) 교수가 GAN(Generative Adversarial Network)을 가리켜 최근 10년간 머신러닝 분야에서 가장 혁신적인 아이디어라고 할 만큼 GAN은 가장 많은 관심을 받고 있는 기술 중에 하나이고 재미있는 연구 분야입니다. 오늘은 Generative Adversarial Networks에 대해서 포스팅해보도록 하겠습니다.
GAN에서 다루고자 하는 모든 데이터는 확률분포를 가지고 있는 랜덤변수(Random Variable)이기 때문에 먼저 확률분포의 개념을 알아야 합니다.
확률분포
확률 번수가 특정한 값을 가질 확률을 나타내는 함수를 의미합니다.
예를 들어 주사위를 던젔을때 나올 수 있는 확률 변수 X라고 합시다.
- 확률 변수 X는 1, 2, 3, 4, 5, 6의 값을 가질 수 있습니다.
- P(X=1)는 1/6 입니다.
- P(X=1)=P(X=2)=P(X=3)=P(X=4)=P(X=5)=P(X=6)
확률 분포는 크게 2가지로 구분할 수 있습니다.
- 이산 확률 분포 : 확률 변수 X의 개수를 정확히 셀 수 있는 함수
- 연속 확률 분포 : 확률 변수 X를 셀 수 없는 함수
이산 확률 분포
- 이산 확률 분포의 나타낼 때 사용되는 대표적인 그림인 주사위 확률 그림입니다.
- 확률 변수 X의 개수를 정확히 셀 수 있고 모든 확률변수 X를 다 더하면 1이 됩니다.
- 대표적인 이산 확률분포에는 이산 균등 분포, 포아송 분포, 이항 분포가 있습니다.
연속 확률 분포
- 연속 확률분포의 대표적인 그림인 정규분포(가우스 분포) 그림입니다.
- 연속적인 값의 예시는 : 키, 달리기, 성적 같은 값들을 말합니다.
- 대표적인 연속 확률 분포에는 정규 분포, 연속 균등 분포, 카이제곱 분포, 감마 분포가 있습니다.
Generative Adversarial Networks(NIPS 2016)
Unsupervised의 대표주자인 GAN은 분류 모델(판별자 Discriminator)과 회귀 모델(생성자 Generator)로 구성되어 있습니다. 두 모델은 GAN이란 이름에서 쉽게 알 수 있듯이, 생성자 Generator와 판별자 Discriminator가 서로의 성능을 개선해 적대적(Adversarial)으로 경쟁해 나가는 모델입니다. 위의 그림과 같이 경찰과 지폐 위조범의 대립과 같은 방식으로 이해할 수 있습니다.
Model architecture
모델의 구조는 다음과 같습니다.
- 생성자 Generator가 Latent space에서 Latent V를 뽑아서 Fake이미지를 생성합니다.
- 판별자 Discriminator가 Real Image와 Fake Image 이미지를 판별하는데 Real Image는 1 Fake는 0으로 판별합니다
- 생성자 Generator는 Fake Image 가 들킬 가능성을 최소화하고 판별자 Discriminator는 속을 가능성을 최소화합니다
Loss function
Gan의 Loss function는 일반적인 Loss function처럼 한 방향으로(최대화하거나 최소화하는) 진행되는 게 아니라 G는 V를 최소화하고 D는 V를 최대화하는 Loss function입니다.
Generator Loss function
Discriminator Loss function
Gan의 학습과정
- (a) z에서 x(이미지) 공간으로 맵핑이 원래 이미지의 분포와는 다른 분포를 생성합니다.
- (b) 생성자 G를 고정하고 판별자 D를 학습을 합니다
진짜 이미지에 대해서 안정적으로 1에 가까울 확률, 반에의 경우 0에 가까운 확률을 반환하게끔 학습을 합니다. - (c) 판별자 D를 고정하고 생성자 G를 학습을 합니다.
z에서 x로의 맵핑이 (a)의 상태와 다르게 좀 더 실제 이미지 분포에 가까운 분포를 형성합니다. - (d) z에서 x로 맵핑이 실제 이미지 분포와 거의 동일한 분포를 생성하면 D는 위의 그림과 같이 1/2로 판별 확률을 반환합니다.
Gan의 목표는 생성자의 분포가 원본 학습 데이터의 분포를 잘 따를 수 있는가입니다.
즉 학습이 다 이루어진 후에 판별자가 진짜 이미지와 가짜 이미지를 구분할 수 없는 가짜 이미지를 생성하는 것입니다.
Generative Adversarial Network Pytorch
https://github.com/eriklindernoren/PyTorch-GAN
- 위의 Github에 Pytorch으로 구현한 GAN이 정리가 잘 되어 잇습니다
참고 자료
'Artificial Intelligence > Deep learning' 카테고리의 다른 글
[Computer Vision]Pose Estimation (0) 2021.07.31 [Computer Vision]Object Detection (0) 2021.07.10 [RNN]Recurrent Neural Network(RNN) (0) 2021.06.21 [Deep Learning]Depthwise Separable Convolution (0) 2021.06.14 [Optimizer] Optimizer-경사하강법(Gradient Descent) (0) 2021.05.19