Записная книжка

Компьютерное зрение, машинное обучение, нейронные сети и т.п.

Генеративно-состязательная сеть

Генеративно-состязательная сеть (Generative adversarial network) это методика обучения пары моделей:

  1. генеративной $G$, которая по случайному шуму генерирует данные, подчинённые некоторому вероятностному распределению, и

  2. дискриминативной $D$, которая приписывает, поступающим на её вход данным, вероятность того, что они получены из тренировочного датасета или сгенерированы моделью $G$.

Т.е. модель $G$ обучается таким образом, чтобы максимизировать вероятность того, что $D$ ошибётся. А модель $D$ тренируется отличать подмену реальных данных от тех, которые сгенерировала сеть $G$. Таким образом мы получаем как бы игру для двух игроков.

Методика тренировки описана в статье [1] и схематично выглядит так:

Генеративно-состязательная сеть

В качестве примера того, почему этот подход позволяет получить интересные результы, рассмотрим известный набор данных MNIST (см., например, MNIST for ML beginners). Этот набор представляет из себя картинки 28 х 28 пикселей в градациях серого, на которых изображены рукописные цифры от 0 до 9. Воспользовавшись схемой GAN мы можем взять две сети, одна будет по случайному шуму (например, вещественному вектору размерности 100) генерировать картинку 28 х 28 в градациях серого, а вторая пытаться отличить эти сгенерированные картинки от реальных данных взятых из MNIST датасета. Проведя достаточно циклов тренировок мы в результате получим генеративную сеть $G$, которая может по случайному входному вектору выдавать картинку с рукописной цифрой. Пример кода тренировки можно посмотреть вот здесь.

В качестве датасета можно использовать не обязательно MNIST, а, например, базу данных лиц, тогда в результате мы обучим сеть, которая будет уметь создавать лица людей. В принципе, мы можем обучать компьютер генерировать всё на что хватит нашей фантазии и для чего у нас собран более менее приличный датасет. Мне кажется это довольно интересно, поэтому начнём разбираться. И начнём с того, что поймем что такое генеративная и дискриминативная модели и чем они отличаются.

Генеративная и дискриминативная модели

Предположим, у нас есть два множества:

  • $O$ - множество наблюдаемых (observed) величин,

  • $H$ - множество скрытых (hidden) величин.

Тогда наличие генеративной (generative) модели означает, что мы знаем вероятность совместного распределения $P(o, h), o \in O, h \in H$. Таким образом мы можем генерировать события при помощи данной модели.

Дискриминативная (discriminative) же модель означает, что мы знаем условную вероятность $P(h | o)$ и соответственно, можем, например, классифицировать явление (найти наиболее вероятное значение скрытой переменной $h$), наблюдая нечто (зная значение наблюдаемой переменно $o \in O$)

Если генеративная модель известна, то используя формулу Байеса, можно получить дискриминативную модель.

Обычно по теме генеративных vs дискриминативных моделей ссылаются на статью [2].

Примеры

  1. Если множества наблюдаемых и скрытых переменных конечны, то генеративную модель можно описать в виде таблицы:

      $h_1$ $h_2$ $h_N$
    $o_1$ $P(o_1, h_1) = p_{11}$ $P(o_1, h_2) = p_{12}$ $P(o_1, h_N) = p_{1N}$
    $o_2$ $P(o_2, h_1) = p_{21}$ $P(o_2, h_2) = p_{22}$ $P(o_2, h_N) = p_{2N}$
    $o_M$ $P(o_M, h_1) = p_{M1}$ $P(o_M, h_2) = p_{M2}$ $P(o_M, h_N) = p_{MN}$

    Дискриминативная модель получается из данной таблицы, применением формулы Байеса. Обозначим

    $Z_i = \sum_{j=1}^N p_{ij}$

    и получим:

      $h_1$ $h_2$ $h_N$
    $o_1$ $P(h_1 \vert o_1) = p_{11} / Z_1$ $P(h_2 \vert o_1) = p_{12} / Z_1$ $P(h_N \vert o_1) = p_{1N} / Z_1$
    $o_2$ $P(h_1 \vert o_2) = p_{21} / Z_2$ $P(h_2 \vert o_2) = p_{22} / Z_2$ $P(h_N \vert o_2) = p_{2N} / Z_2$
    $o_M$ $P(h_1 \vert o_M) = p_{M1} / Z_M$ $P(h_2 \vert o_M) = p_{M2} / Z_M$ $P(h_N \vert o_M) = p_{MN} / Z_M$
  2. Еще один пример описан раньше - это картинки рукописных цифр. Наблюдаемая величина это собственно картинка: 28 х 28, а скрытая величина - это нарисована ли на картинке цифра.

  3. Хороший пример, это тексты на разных языках. Допустим, что в качестве наблюдаемой величины у нас выступает текст на каком-то языке, а в качестве скрытой - собственно язык на котором текст написан.

    Генеративным подходом в данном случае будет: выучить все возможные языки и тогда мы будем иметь для каждого языка и текста вероятность их совместимости. Дискриминативным же подходом будет только понять как языки отличаются, и приписывать текст к языку на базе этого знания. Можно, например, ограничиться только текстами на двух языках: английском и русском и тогда для генеративного подхода надо всё равно изучить оба языка, а вот для дискриминативного достаточно будет научиться отличать латиницу от кирилицы.

Состязательная сеть

Итак формально мы имеем некоторое множество $X$ примеров, и подмножество $Y \subset X$ позитивных примеров (например, $X$ множество всех картинок размера 28 x 28, а подмножество $Y$ - набор MNIST картинок с рукописными цифрами), таким образом на множестве $X$ определено вероятностное распределение $p_{data}(x)$. Мы хотим отыскать две функции (обе функции мы аппроксимируем при помощи многослойной нейронной сети):

  1. $G(z, \theta_g)$. $z$ - случайный шум с каким-то наперёд заданным распределением $p_z(z)$. $\theta_g$ - параметры, которые будем тренировать. На выходе функция $G$ будет выдавать элемент из множества $X$. Таким образом функция $G$ задаёт распределение $p_g(x)$.

  2. $D(x, \theta_d)$. $x \in X$, а $\theta_d$ - параметры дискриминативной сети. $D(x)$ - вещественное число - вероятность того, что $x$ взято из подмножество $Y$, а не сгенерировано сетью $G$.

Мы одновременно тренируем сети представляющие функции $D$ и $G$. При этом мы ищем такие параметры $\theta_d$, чтобы максимизировать вероятность правильного разделения функцией $D$ позитивных примеров и примеров сгенерированных функцией $G$. И такие параметры $\theta_g$, которые минимизируют функцию $\log(1 - D(G(z)))$.

Можно переформулировать задачу как минимаксную игру для двух игроков с целевой функцией $V(G, D)$:

\[\min_G \max_D V(D, G) = \mathbb{E}_{data}\left(\log D(x))\right) + \mathbb{E}_z\left(\log(1-D(G(z)))\right)\]

В работе [1] теоретически обосновывается, что данная задача имеет глобальный оптимум $-\log(4)$ при $p_g(x) = p_{data}(x)$. Т.е. мы можем натренировать такую функцию $G$, которая будет выдавать исходное распределение, и функция $D$ уже не сможет отличить примеры из тестового набора от примеров генерируемых функцией $G$, а значит будет выдавать вероятность $0.5$ для всех примеров, что собственно и даёт нам в результате $-\log(4)$.

Алгоритм решающий задачу оптимизации, попеременно оптимизирует функцию $G$ при фиксированной $D$, а затем фиксируя $D$ улучшает функцию $G$. При этом, так как на начальном этапе функция $G$ еще не достаточно хорошо умеет генерировать примеры похожие на тестовые, и $D$ назначает им очень маленькую вероятность, то $\log(1-D(G(z)))$ близок к нулю, соответственно обучение параметров функции $G$ будет происходить медленно, поэтому авторы [1] предлагают вместо того, чтобы минимизировать $\log(1-D(G(z)))$ искать максимум $\log(D(G(z)))$. Это эквивалентная задача, но при этом мы будем иметь, на начальном этапе обучения, большие по величине градиенты.

Итак в конечном итоге приходим к следующему алгоритму обучения:

Алгоритм тренировки генеративно-состязательных сетей

  • Для каждого шага тренировки:

    • Делаем $k$ итераций, оптимизируя дискриминатор $D$.

      • Генерируем $m$ примеров $\{z^{(1)}, z^{(2)}, …, z^{(m)}\}$, при помощи текущего варианта функции $G$.

      • Набираем $m$ позитивных примеров $\{x^{(1)}, x^{(2)}, …, x^{(m)}\}$ из тренировочного набора.

      • Обновляем параметры дискриминатора используя градиент:

      \[\nabla_{\theta_d}\frac 1 m \sum_{i=1}^m\left(\log\left(D(x^{(i)})\right) + \log\left(1-D(G(z^{(i)}))\right)\right)\]
    • Генерируем $m$ примеров $\{z^{(1)}, z^{(2)}, …, z^{(m)}\}$, при помощи текущего варианта функции $G$.

    • Обновляем параметры генератора используя градиент:

    \[\nabla_{\theta_g}\frac 1 m \sum_{i=1}^m\log\left(D(G(z^{(i)}))\right)\]

Гиперпараметр $k$ - количество шагов оптимизации дискриминатора на каждый шаг оптимизации генератора. Авторы [1] предлагают использовать $k=1$

Результаты выдаваемые генеративной сетью, натренированной по описанному выше алгоритму на наборе MNIST код для тренировки (не мой):

Результат работы генеративной сети

Здесь и дискриминативная и генеративная сети содержат два полносвязных слоя. А шум, подаваемый на вход генеративной сети, представляет из себя случайный вектор размерности $100$, элементы которого соответствуют равномерному распределению на отрезке $[-1, 1]$.


Литература

  1. Ian J. Goodfellow, Jean Pouget-Abadie, Mehdi Mirza, Bing Xu, David Warde-Farley, Sherjil Ozair, Aaron Courville, Yoshua Bengio, “Generative Adversarial Nets” arXiv:1406.2661 2014

  2. Andrew Y. Ng, Michael I. Jordan, “On Discriminative vs. Generative classifiers: A comparison of logistic regression and naive Bayes” NIPS