Передача знаний в нейронной сети
Статья рассказывает про передачу знаний (knowledge distillation) от уже наученной сети (обычно большой и сложной) к сети необученной (обычно существенно меньших размеров). Исходную большую и сложную сеть называют учителем, а маленькую - учеником.
На самом деле можно рассматривать эту задачу двояко. Во-первых, как способ натренировать малую сеть лучше, чем она сможет это сделать просто на исходном датасете. Здесь работает то, что большая сеть может уловить сходство разных классов и передать эти знания малой сети, которая такие тонкости не заметит. Действительно, если мы решили классифицировать котиков, собачек, птиц и самолеты, то возможно нам будет полезно, что котики могут быть похожи на собак, а вот на самолеты совсем нет, и если большая сеть это “знает” и сможет рассказать малой, то это важно. Потому что малая, возможно, сама из датасета этому научиться не сможет.
Во-вторых, случаются ситуации, когда большая сеть натренирована, но датасет на котором она тренировалась по каким-то причинам не доступен (очень большой, очень под копирайтом и т.п.), зато есть какой-то неразмеченный набор объектов тех же классов с помощью которого, можно передать знания от большой сети к малой.
Итак, предположим, что у нас есть уже натренированная сеть, для классификации на $N$ классов. Последним слоем у этой сети, является softmax, который преобразует оценки (logits) принадлежности объекта классам (обозначим их как $v_i,\, i=1,…,N$) в набор псевдовероятностей:
\[P_i = \frac {\exp (v_i)} {\sum_j \exp (v_j)},\, i=1,...,N\]Так же имеется сеть, которую собираемся натренировать. Обозначим logits сети “ученика” как $z_i$, а вероятности на выходе также будем получать при помощи softmax:
\[Q_i = \frac {\exp (z_i)} {\sum_j \exp (z_j)},\, i=1,...,N\]Предполагается, что есть датасет, возможно неразмеченный или размеченный не полностью, используя который мы будем “передавать знания” от сетки учителя - сетке ученику. Чтобы это сделать авторы статьи предлагают “ослабить” softmax вводя температуру $T$:
\[p_i = \frac {\exp (v_i / T)} {\sum_j \exp (v_j / T)},\, i=1,...,N\] \[q_i = \frac {\exp (z_i / T)} {\sum_j \exp (z_j / T)},\, i=1,...,N\]температура сглаживает вероятностное распределение, при этом сохраняет больше информации о том как сеть учитель относит объекты к классам и похожести разных классов.
Сразу следует оговориться, что после тренировки, в “боевом режиме” сеть-ученик используется только с температурой $T = 1$
В простейшем случае, когда разметка на датасете отсутствует совсем, предлагается в качестве функции потерь использовать кроссэнтропию между ослабленными softmax ученика и учителя:
\[\mathcal L_{KD} = -\sum_j p_j \ln q_j\]При этом $p_j$ в данной сумме фиксированы, в том смысле, что полученны от сети учителя, которая уже натренирована и параметры которой заморожены, а вот веса сети ученика ($q_j$) надо будет натренировать.
В случае, когда у датасета присутствует разметка, авторы предлагают добавить ещё одно слагаемое в функцию потерь, а именно обычную кроссэнтропию, между выходом сети $Q_i$ и разметкой. В их экспериментах это улучшало качество ученической сети.
Разберёмся с производными функции $\mathcal L_{KD}$ по $z_i$:
\[\frac{\partial }{\partial z_i} \mathcal L_{KD} = -\sum_j p_j \frac 1 {q_j} \frac{\partial q_j }{\partial z_i}\]посчитаем $\partial q_j / \partial z_i$.
Рассмотрим вначале случай $i = j$:
\[\frac{\partial q_i }{\partial z_i} = \frac{\partial }{\partial z_i} \left( \frac {\exp (z_i / T)} {\sum_k \exp (z_k / T)}\right) = \left[S = \sum_k \exp (z_k / T)\right] = \\ = \frac {\frac 1 T \exp (z_i / T) \cdot S - \exp (z_i / T) \frac 1 T \exp (z_i / T)} {S^2} =\\ = \frac 1 T \frac {\exp (z_i / T)} {S} - \left(\frac {\exp (z_i / T)} {S} \right)^2 = = \frac 1 T \left(q_i - q_i^2\right)\]Теперь посчитаем тоже самое для $i \neq j$:
\[\frac{\partial q_j }{\partial z_i} = \frac{\partial }{\partial z_i} \left( \frac {\exp (z_j / T)} {\sum_k \exp (z_k / T)}\right) = \left[S = \sum_k \exp (z_k / T)\right] = \\ = -\frac {\exp (z_j / T) \frac 1 T \exp (z_i / T)} {S^2} = -\frac 1 T q_i q_j\]объединяем:
\[\frac{\partial }{\partial z_i} \mathcal L_{KD} =\\ = - \sum_{j\neq i} \frac {p_j} {q_j} \left(-\frac 1 T q_i q_j\right) - \frac {p_i} {q_i} \frac 1 T \left(q_i - q_i^2\right) = \frac 1 T \sum_{j\neq i} p_j q_i - \frac 1 T p_i + \frac 1 T p_i q_i =\\ = \frac 1 T q_i \sum_j p_j - \frac 1 T p_i\]и, наконец, так как $p_j$ - вероятностное распределение и значит $\sum_j p_j = 1$, получаем:
\[\frac{\partial }{\partial z_i} \mathcal L_{KD} = \frac 1 T (q_i - p_i)\]Таким образом:
\[\frac{\partial }{\partial z_i} \mathcal L_{KD} = \frac 1 T \left(\frac {\exp (z_i / T)} {\sum_j \exp (z_j / T)} - \frac {\exp (v_i / T)} {\sum_j \exp (v_j / T)}\right)\]если теперь взять достаточно большое $T$ и разложить экспоненты в ряд Тейлора отбросив все члены старше первой степени, получим:
\[\frac{\partial }{\partial z_i} \mathcal L_{KD} \approx \frac 1 T \left(\frac {1 + z_i / T} {N + \sum_j z_j / T} - \frac {1 + v_i / T} {N + \sum_j v_j / T}\right)\]Далее, авторы предполагают, что на каждом объекте среднее logit-ов и сети учителя и сети ученика будет нулевым, т.е. $\sum_i z_i = \sum_i v_i = 0$, тогда
\[\frac{\partial }{\partial z_i} \mathcal L_{KD} \approx \frac 1 {NT^2} \left(z_i - v_i\right)\]и таким образом при больших $T$ и при условии на нулевое среднее logit-ов на каждом обучающем примере, штрафная функция передачи знаний эквивалентна квадрату нормы разности logit векторов $(z_i - v_i)^2$.
Эксперименты на MNIST
Авторы приводят результаты ряда экспериментов, я опишу только эксперименты с MNIST как мне наиболее близкие.
Большая сеть-учитель представляет из себя полносвязную сеть с двумя скрытыми слоями по 1’200 нейронов в каждом. Тренируется эта сеть на всех 60’000 объектах тренировочного датасета. Во время тренировки используется $L^2$ регуляризация и dropout. Датасет аугментируется - каждая картинка случайным образом сдвигается до 2 пикселей в любом из направлений. В результате получилась сеть-учитель, которая на тестовом наборе выдавала только 67 ошибок.
Далее была натренирована меньшая сеть, с всего 800 нейронами на двух скрытых слоях, при её тренировке регуляризация не использовалась и в результате такая сеть допускала на тестовом наборе 146 ошибок. Если эту сеть малого размера перетренировать новым способом передавая её знания от большой сети, с температурой $T=20$, то количество допускаемых ей ошибок сократится до 74. Это показывает, что таким образом можно действительно передавать знания от одной сети к другой причем в том числе знания как хорошо обобщаться в том числе на сдвинутых примерах картинок, которые сеть - ученик не видела, потому что датасет не аугментировался при тренировке малой сети.
В следующем эксперименте авторы исключили из датасета на котором передавались знания в малую сеть, все изображения цифры 3. Таким образом, малая сеть с такой цифрой не встречалась в принципе, однако, при тестах допустила только 206 ошибок из которых 133 на цифре 3 (всего троек в тестовом наборе 1010). Это связано с тем, что для цифры 3 при обучении ученика учителем на выходе softmax слоя всегда будет оценка смещенная к нулю (действительно, раз мы картинки с цифрой 3 не поместили в датасет на котором передаются знания, то у всех объектов в датасете, значение вероятности на классе 3 классификатора учителя будет существенно близко к нулю, даже если мы сильно разгладим softmax при помощи темепературы $T$). Авторы увеличивают смещение для класса 3 при обучении, и количество ошибок, полученной таким образом сети, сокращается до 109, при этом количество ошибок на картинках класса 3 опускается до 14. Т.е. качество ученика достигает значения 98.6% на классе представителей которого вовсе нет в датасете на котором передаются знания, правда это получается только при правильном смещении для данного класса при обучении.
Последний эксперимент, заключается в том, что авторы оставляют в датасете только 7 и 8 и передают знания на таком наборе. В базовом варианте качество на тестовом наборе получается всего 47.3%, но уменьшив смещение для 7 и 8 (что тоже самое увеличив для отсутствующих в тестовом наборе классов) авторы добиваются, чтобы сеть ученик добралась до 86.8% качества. Это, вне сомнений, очень хороший результат, учитывая, что восемь из десяти классов ученик не видел вообще, а только воспринял от учителя.