Записная книжка

Компьютерное зрение, машинное обучение, нейронные сети и т.п.

Улучшение передачи знаний с использованием ассистента учителя

arXiv:1902.03393

В прошлый раз разбирался с передачей знаний от большой нейронной сети (учителя) к маленькой нейронной сети (ученику). В статье, разбираемой сегодня, авторы выдвигают идею, которую можно описать примерно следующим образом: слишком умный учитель не может хорошо обучить слишком слабого ученика, поэтому такому учителю нужен ассистент. Т.е. передача знаний от большой нейронной сети к очень маленькой, будет идти лучше, если мы вначале натренируем промежуточную сеть: ассистента учителя (teacher assistent) в которую передадим знания от учителя, а уже эта средняя сеть будет использоваться для тренировки совсем малой нейронной сети-ученика.

Учитель, ассистент и ученик, изображение из статьи

Авторы начинают с того, что берут датасеты CIFAR-10 и CIFAR-100 и обучают на них несколько свёрточных нейронных сетей - учителей с разным количеством свёрточных слоёв: 4, 6, 8, 10 (используется архитектура типа VGG, т.е. пара свёрточных слоёв, потом max-pooling, снова пара свёрточных слоёв, и т.д., наконец, полносвязный слой, на выходе которого logits). Очевидно, в данном случае с ростом количества слоёв растёт и точность сети-учителя на тестовых данных. Взяв теперь в качестве ученика сеть из 2 свёрточных слоёв, авторы пытаются передать ему знания от каждого из учителей.

Учитель, ученик качество сетей, изображение из статьи

Из графиков видно, что качество ученика растёт с ростом качества учителя, но лишь до определённого момента, после чего начинает падать.

Авторы предлагают для объяснения полученных эмпирических данных рассмотреть три фактора:

  1. Если качество сети-учителя растёт, то эта сеть должна лучше обучать ученика.

  2. Если сеть-учитель усложняется, то сеть-ученик может быть не достаточно приёмистой, чтобы суметь воспринять передаваемые знания.

  3. С ростом внутренней сложности сеть-учитель начинает чётче разделять данные на классы, а значит её оценки (logits) становятся менее сглажены, а собственно, сглаженность оценок есть фактор, который и работает в knowledge distillation подходе.

Первый фактор приводит к росту, а второй и третий наоборот к уменьшению, качества сети ученика при обучении от более сложной сети учителя.

Я тут от себя замечу, что фактор номер один крайне противоречив, т.е. наилучшее качество классификации - это просто брать GT данные из датасета, однако, именно их размытие и позволяет лучше тренировать ученика, а вот растёт ли с ростом качества классификации, качество поиска схожести классов - вопрос сложный и как раз фактор номер три, говорит, что скорее всего, если и растёт, то только до определенного значения сложности сети.

Далее авторы проводят обратный эксперимент: берут в качестве учителя сеть размера 10 (размер - кол-во свёрточных слоёв) и обучают с её помощью учеников размера 2, 4, 6 и 8. На графиках отображён прирост качества сети-ученика, натренированной с передачей знаний от сети-учителя размера 10, относительно просто тренировки такой сети-ученика на соотвествующем датасете.

Прирост качества сетей-учеников, изображение из статьи

Таким образом, передача знаний от очень большой сети к очень маленькой оказывается не вполне эффективной. Поэтому, авторы статьи предлагают использовать промежуточную сеть - помошник учителя (teacher assistant), натренируем вначале эту промежуточную сеть, используя сеть-учитель, а затем натренируем ученика, используя в качестве источника знаний помошника учителя (teacher assistant knowledge distillation (TAKD)).

В качестве доказательства, что такой подход работает, авторы проверяют его на CIFAR-10 и CIFAR-100 на двух типах свёрточных сетей: VGGlike и ResNet, и дополнительно на ImageNet датасете на ResNet моделе.

Модель   | Dataset   |   NoKD  |   BLKD  |   TAKD  |
----------------------------------------------------
VGG      | CIFAR-10  |  70.16% |  72.57% |  73.51% |
         | CIFAR-100 |  41.09% |  44.57% |  44.92% |
----------------------------------------------------
ResNet   | CIFAR-10  |  88.52% |  88.65% |  88.98% |
         | CIFAR-100 |  61.37% |  61.41% |  61.82% |
----------------------------------------------------
ResNet   | ImageNet  |  65.20% |  66.60% |  67.36% |
----------------------------------------------------
  • NoKD - тренировка ученика, без учителя, только по набору данных
  • BLKD - передача знаний от учителя к ученику непосредственно (baseline knowledge distillation)
  • TAKD - передача знаний от учителя к ученику, через помошника

Для CIFAR используем следующие размеры сетей:

  • VGG - ученик размера S = 2, помошник учителя TA = 4 и учитель T = 10
  • ResNet - S = 8, TA = 20, T = 110

Для ImageNet, ResNet - S = 14, TA = 20, T = 50

Из таблицы видно, что вариант с помошником (TAKD) во всех случаях даёт наибольшую точность.

Следуюший вопрос, а какого размера должна быть промежуточная сеть, чтобы передать знания максимально эффективно? Чтобы это выяснять, авторы проводят эксперименты с различными размерами (не всеми возможными, но достаточно, для получения оценок). Первый набор экспериментов это VGGlike сеть на датасетах CIFAR, знания передаются от сети размера 10 к сети размера 2, через промежуточную сеть размера TA:

Модель   | Dataset   |  TA = 8 |  TA = 6 |  TA = 4 |
----------------------------------------------------
VGG      | CIFAR-10  |  72.75% |  73.15% |  73.51% |
         | CIFAR-100 |  44.28% |  44.57% |  44.92% |
----------------------------------------------------

Второй набор экспериментов для ResNet варианта сети на тех же датасетах, с учителем размера T = 110 и учеником S = 8

Модель   | Dataset   | TA = 56 | TA = 32 | TA = 20 | TA = 14 |
--------------------------------------------------------------
ResNet   | CIFAR-10  |  88.70% |  88.73% |  88.90% |  88.98% |
         | CIFAR-100 |  61.47% |  61.55% |  61.82% |  61.50% |
--------------------------------------------------------------

Дальше авторы делают интересное, но как мне кажется, слабо обоснованное предположение. Рассмотрим следующий график, для VGGlike сетей на CIFAR-10 и CIFAR-100:

VGGlike сети на CIFAR-10 и CIFAR-100, изображение из статьи

синим здесь показано качество сетей разных размеров, натренированных только на датасете без KD, а красная черта это среднее арифметическое значений качества сети размера 10 и сети размера 2. Авторы пишут: “смотрите, среднее находится на уровне сети размера 4, и сеть этого размера, использованная в качестве помошника даёт наилучшее качество при тренировки ученика, похоже именно так и надо выбирать размер помошника”. Этот вывод подтверждает и график для варианта с ResNet сетью:

ResNet сети на CIFAR-10 и CIFAR-100, изображение из статьи

Здесь для CIFAR-10 средняя проходит через сеть размера TA = 14, а для CIFAR-100 - через сеть размера TA = 20 и действительно, при таких размерах помошника мы получаем максимальное качество при передаче знаний от учителя ученику.

Однако, во-первых, не понятно как это обосновать теоретически. Во-вторых, чтобы получить таким образом оценку размера помошника надо натренировать достаточно большое число сетей промежуточного размера, можно сразу решать задачу передачи знаний с разными размерами помошника и выбрать лучший вариант ученика.

Наконец, последний вопрос: а что если вместо одного помошника, использовать серию, последовательно уменьшая размер? Авторы статьи поставили эксперимент для ответа и на этот вопрос. Результаты этого эксперимента сведены в следующую таблицу:

Несколько шагов передачи знаний, изображение из статьи

Надо отметить, что из таблицы видно, что лучший результат получается при последовательной тренировки помошника всё меньшего размера, а резкие переходы снижают результирующую точность.