앞서 살펴봤듯이 GAN은 패턴을 흉내내는 문제를 Minimax Game으로 접근합니다. 즉, 패턴을 흉내내는 생성자(generator) 네트워크와 그 패턴의 진위 여부를 판별하는 판별자(discriminator) 네트워크로 구성됩니다. 이 두 네트워크가 서로의 목적을 달성하도록 학습을 반복하는 것이죠. 물론 학습을 하려면 학습 데이터가 필요합니다. GAN의 학습 데이터는 생성자가 흉내내야 할 진짜 데이터를 말합니다. 예를 들어 사람의 얼굴을 그려내는 모델을 학습시킨다면 사람의 얼굴 사진이 필요하겠죠. 반면 사람이 손으로 쓴 숫자를 흉내낸다면 수기 숫자 이미지가 필요할 것입니다. 간단히 말하면 GAN 학습을 위해서 아래와 같은 세가지 구성요소가 필요합니다.
- 학습 데이터 : 흉내내고자 하는 진짜 데이터
- 생성자 네트워크 : 진짜 데이터와 유사한 패턴을 생성하는 네트워크
- 판별자 네트워크 : '학습(진짜) 데이터'와 생성자가 만들어낸 '가짜 데이터'의 진위 여부를 판별하는 네트워크
이제 이 세가지 구성요소를 자세히 살펴보겠습니다.
학습 데이터
GAN의 학습 데이터는 간단합니다. 분류 문제처럼 각 데이터의 분류를 라벨링(labeling)하거나 회귀 문제처럼 정확한 답을 사람이 일일이 지정해줄 필요가 없습니다. 단지 흉내내고자 하는 데이터를 수집하면 됩니다. 이러한 이유로 GAN이 비지도 학습(unsupervised learning)의 성격을 띈다고 말하기도 합니다.
물론 데이터의 수는 일반적으로 많을 수록 좋습니다. 데이터의 수가 많을 수록 다양한 패턴이 포함될 확률이 크기 때문이죠. 데이터의 양에 더불어 데이터의 다양성도 중요합니다. 수기 숫자를 흉내낼 때, 학습 데이터에 0부터 9까지의 모든 숫자가 골고루 필요하다는 말입니다. 학습 데이터에 포함된 숫자가 '1'뿐이라면, 생성자도 '1'만 생성해낼 것입니다.
생성자 네트워크
생성자는 학습 데이터의 패턴을 흉내내어 판별자의 정확도를 최소화하는 것이 목적입니다. 기본적인 GAN 모델에서 생성자 네트워크에 주어지는 입력은 랜덤 노이즈(random noise)이며, 출력은 학습 데이터와 유사한 패턴을 지닌 데이터 입니다. 즉, 생성자 네트워크는 랜덤 노이즈를 학습 데이터와 유사한 패턴으로 변환하는 함수를 학습하게 됩니다.
네트워크의 형태는 목적에 따라 다양할 수 있습니다. 생성하고 싶은 패턴의 형태에 따라 MLP(multi-layer perceptron)와 CNN, auto-encoder 등 어떤 형태든 가능합니다. 기본적인 GAN에서 파생된 모델들에 대해서는 차차 살펴보도록 하겠습니다.
판별자 네트워크
판별자는 입력으로 주어진 데이터가 학습 데이터에 포함된 진짜인지, 생성자가 만들어낸 가짜인지를 판별하는 역할을 하며, 판별의 성공 확률을 최대화하는 것이 목적입니다. 즉, 진짜 혹은 가짜 데이터를 입력으로 받아, 입력받은 데이터가 학습 데이터에 포함된 진짜 데이터일 확률을 출력합니다. 판별자 네트워크의 형태도 생성자와 마찬가지로 제한이 없습니다. 판별코자 하는 데이터의 형태에 적합한 네트워크 형태를 선택하면 됩니다.
셋이 함께 모이면...
지금까지 설명한 내용을 처음부터 끝까지 정리해 보면 다음과 같습니다.
- 랜덤 노이즈를 생성하여 생성자의 입력으로 전달합니다.
- 생성자는 입력으로 주어진 랜덤 노이즈를 변환하여 가짜 데이터를 만듭니다.
- 생성자가 만들어낸 '가짜' 출력과 '진짜' 학습 데이터를 적절히 조합하여 판별자의 입력으로 제공합니다.
- 판별자는 주어진 입력이 학습 데이터에 포함된 '진짜'일 확률을 구합니다.
GAN의 실행 순서는 위와 같으며, 학습 과정은 위의 과정을 거꾸로 거슬러 올라가면서 역전파(back propogation) 알고리즘을 바탕으로 네트워크의 파라미터를 최적화하는 과정입니다. 다음 포스트에서는 적대적 학습의 과정을 자세히 살펴보겠습니다.
이해하기 쉽게 잘 써주셨네요!!
답글삭제글을 읽다가 한 가지 궁금한 점이 생겼는데요.
GAN이 비지도 학습이라고 하는데
discriminator가 input으로 들어온 데이터가 진짜인지 아닌지 판별하는 과정에서도 데이터 라벨(input으로 들어온 것이 진짜인지 가짜인지)이 필요가 없다는 뜻인가요?
머신러닝에서 분류기라고 하면 정답을 알고 있는 상태에서 패턴을 학습시킨 다음 새로운 데이터가 들어왔을 때 그게 1인지 0인지 판별하는 것이라고 알고 있는데, discriminator도 분류기라면 라벨링된 데이터가 필요할 것 같아서요...
Discriminator 측 모델을 학습시키는 것은 라벨이 있는 데이터를 통해서 판단하므로 지도학습으로 볼 수 있습니다.
답글삭제