학습 데이터와 생성자 네트워크, 판별자 네트워크는 서로 어떻게 상호작용하면서 학습이 이뤄질까요? 생성자는 판별자의 정확도를 최소화하는 것이 목적이고, 판별자는 판별의 정확도를 최대화하는 것이 목적이라는 점은 알겠는데, 도데체 어떻게 학습이 이뤄지는지 감이 오질 않습니다. 이제부터 적대적 학습의 과정을 차근차근 살펴보겠습니다.
번갈아 학습하기
기본적인 적대적 학습에서는 판별자와 생성자를 번갈아가며 학습합니다. 학습을 처음 시작할 때는 판별자와 생성자가 모두 엉망인 상태죠. 판별자는 가짜와 진짜를 전혀 구별하지 못하고, 생성자는 전혀 엉뚱한 데이터를 생성해냅니다. 이렇게 양쪽이 모두 엉망인 상태에서는 학습이 어렵습니다. 생성자가 풀어야할 문제의 힌트를 주기위해서는 판별자를 먼저 학습 시킨 후, 생성자가 판별자의 지도를 받아 그 제약 안에서 학습하도록 해야 합니다. 이를 정리해보면, 전체적인 적대적 학습의 얼개는 다음과 같습니다.
- 판별자 네트워크 학습
- 랜덤 노이즈 m 개를 생성하여, 생성자 네트워크에 전달하고 변환된 데이터 m 개를 얻습니다.
- 학습 데이터셋에서 진짜 데이터 m 개를 선택합니다.
- 2m 개의 데이터(진짜 m개 + 가짜 m개)를 이용해 판별자 네트워크의 정확도를 최대화하는 방향으로 학습합니다.
- 생성자 네트워크 학습
- 랜덤 노이즈 m 개를 다시 생성합니다.
- 랜덤 노이즈 m 개를 이용해 생성자가 판별자의 정확도를 최소화하도록 학습합니다.
학습은 1,2 단계를 반복하면서 진행됩니다. 1 단계는 위에서 말한대로 판별자를 먼저 학습 시키는 과정이고, 2 단계는 어느 정도 학습된 판별자를 바탕으로 생성자를 학습시키는 과정입니다. 실제로는 판별자의 학습 속도를 높이기 위해 1번 단계를 여러번 반복한 후, 2번 단계로 넘어갈 수도 있습니다. 그만큼 판별자의 능력이 중요하다는 합니다. 앞에서 설명한 Minimax Game이 성립하려면 판별자의 능력이 최대화돼야하기 때문입니다. 이제 판별자 학습 과정과 생성자 학습 과정을 자세히 살펴보겠습니다.
판별자 네트워크 학습
판별자는 진짜 혹은 가짜 데이터를 입력으로 받고, 그 입력이 진짜일 확률을 출력합니다. 결국 판별자의 정확도가 높다는 말은 다음과 같이 두 가지 경우로 나누어 생각할 수 있습니다. (확률은 0이상 1이하의 실수입니다.)
- 입력 데이터가 진짜인 경우 : 1에 가까운 큰 확률값을 출력한다.
- 입력 데이터가 가짜인 경우 : 0에 가까운 작은 확률값을 출력한다.
이제 이런 목적을 달성할 수 있는 목적함수(손실함수)를 설계해야 합니다. 이해를 쉽게하기 위해 Ian J. Goodfellow 논문의 수식을 판별자 입장에서 단순히 표현하면 다음과 같습니다. 입력 데이터 x에 대하여, D(x)는 판별자가 출력한 확률을 말합니다.
즉, 입력 데이터 x가 학습 데이터에 포함되는 진짜라면 판별자가 출력하는 확률 D(x)를 최대화한다는 말입니다. 반대로 입력 데이터 x가 학습 데이터에 없는 가짜라면 판별자는 입력 x가 가짜일 확률을 크게 출력해야 하므로 1-D(x)를 최대화해야 합니다.
이제 이 손실함수를 최대화하는 방향으로 판별자를 학습시킵니다. 학습과정은 역전파 알고리즘을 이용한 일반적인 SGD와 동일합니다. 단지, 판별자의 목표는 손실을 최대화하는 것이므로 목적함수의 기울기가 하강하는 방향이 아니라 상승하는 방향으로 파라미터를 조정합니다.
원래 논문에서 표현된 GAN 전체의 목적함수는 위와는 크게 다른 모습입니다. 하지만, 판별자의 입장에서 판별자 네트워크를 학습시키는 부분만 분리해서 보면 위와 같이 이해해도 무방하다고 생각됩니다. 또한 원래 논문에서는 D(x)와 1-D(x)에 log를 적용하여 사용하는데 이는 좀 더 안정적인 학습을 위한 것으로, 손실을 최대화한다는 점에서는 동일합니다. 결국 실제로 GAN에서 판별자 네트워크를 학습시킬 때 사용하는 목적함수는 다음과 같습니다.
생성자 네트워크 학습
자, 이제 판별자가 어느 정도 능력을 갖추게 됐으니 그 경쟁자인 생성자에게 좋은 적수가 될 수 있겠습니다. 생성자 네트워크의 학습은 말그대로 판별자를 속이는 방법을 학습시키는 과정입니다. 생성자의 입력으로 주어지는 랜덤 노이즈 데이터를 z 라하고, 주어진 랜덤 노이즈를 생성자가 변환시킨 결과를 G(z)라고 합시다. 그렇다면 판별자가 변환된 데이터를 입력으로 받아 출력하는 확률은 D(G(z))로 표기할 수 있습니다.
우리는 생성자가 판별자를 속이도록 학습시켜야하므로, 판별자가 G(z)를 가짜라고 판별할 확률 1-D(G(z))를 최소화해야 합니다. 정리하면, 생성자의 손실함수는 아래와 같습니다.
일반적인 기울기 하강법과 SGD를 이용하여 생성자 네트워크의 파라미터를 업데이트 하면 됩니다. 물론 이 과정에서 판별자 네트워크의 파라미터는 업데이트하면 안됩니다. 판별자의 목적은 손실을 최소화하는 것이 아니라 최대화하는 것이기 때문입니다. 즉, 판별자는 손실함수에서 발생하는 기울기의 흐름을 생성자 네트워크에 전달하는 통로 역할을 할뿐, 판별자 네트워크 자체는 변하지 않습니다.
하지만 위에서 설명한 손실함수를 이용하면 생성자를 학습시키기가 어렵습니다. 그 이유는 학습 초기에 생성자와 판별자의 능력차가 크기 때문입니다. '가짜와 진짜를 판별하는 일'에 비해서 '진짜와 똑같은 가짜를 만드는 일'이 훨씬 더 어렵기 때문에, 생성자에 비해 판별자가 더 빠르게 학습됩니다. 결국 학습 초기에 생성자가 만들어낸 데이터는 진짜와 전혀 비슷하지 않은 '엉터리 가짜'이므로 판별자는 아주 쉽게 이를 판별해냅니다.
즉, 주어지는 모든 z에 대해 1-D(G(z))는 항상 1에 가까운 값을 가지게 됩니다. 생성자 입장에서는 어떻게 해도 판별자를 속일 힌트를 얻을 수 없는 것입니다. 수학적으로 보면 1-D(G(z))가 최대값인 1에 가까울 수록 log{1-D(G(z))}의 기울기는 작아지기 때문에, 생성자 네트워크의 파라미터를 조정하기에 충분한 기울기를 얻기 어렵습니다.
이런 문제를 해결하기 위해 약간의 트릭을 사용합니다. 판별자가 G(z)를 가짜라고 판별할 확률 1-D(G(z))을 최소화하는 대신에, 판별자가 G(z)를 진짜라고 할 확률 D(G(z))를 최대화하는 것입니다. 논리적으로는 완벽히 동일한 얘기지만, log를 적용했을 때의 기울기라는 측면에서 보면 아래 그림과 같이 전혀 다른 얘기가 됩니다. 즉, 학습초기에 D(G(z))가 0에 가까울 때 log{D(G(z))}의 기울기는 거의 수직에 가깝기 때문에 빠른 학습이 가능한 것입니다.
결론적으로 아래 목적 함수를 최대화하는 방향으로 생성자 네크워크를 학습시키게 됩니다.
이제 적대적 학습의 전체적인 흐름을 설명했습니다. 그런데 GAN을 다루는 논문들을 보면 이 모든 것을 '데이터의 확률분포'로 설명합니다. 우리는 흔히 확률이라고하면 '어떤 일이 벌어질 가능성'이라고 생각합니다. 그런데 패턴을 흉내내는 일과 확률분포 사이에 어떤 관계가 있을까요? 다음 포스트에서는 GAN에서 확률분포가 어떤 의미를 갖는지 살펴보겠습니다.
수학적인 부분 때문에 이해하기 힘든 부분이 있었는데 쉽게 설명해주신 덕분에 많은 도움이 되었습니다. 감사합니다 :)
답글삭제감사합니다 좋은 설명이에요 ㅠㅜ
답글삭제