이 글은 제가 공부한 내용을 정리하는 글입니다. 따라서 잘못된 내용이 있을 수도 있습니다. 잘못된 내용을 발견하신다면 리플로 알려주시길 부탁드립니다. 감사합니다.
Knowledge distillation 이란?
Knowledge distillation 은 NIPS 2014 에서 제프리 힌튼, 오리올 비니알스, 제프 딘 세 사람의 이름으로 제출된 "Distilling the Knowledge in a Neural Network" 라는 논문에서 제시된 개념입니다.
Knowledge distillation 의 목적은 "미리 잘 학습된 큰 네트워크(Teacher network) 의 지식을 실제로 사용하고자 하는 작은 네트워크(Student network) 에게 전달하는 것" 입니다.
이 목적을 풀어서 설명하면 다음과 같습니다. 딥러닝 모델은 보편적으로 넓고 깊어서 파라미터 수가 많고 연산량이 많으면 feature extraction 이 더 잘 되고, 그에 따라서 모델의 목적인 Classification 이나 Object detection 등의 성능 또한 좋아질 것입니다.
그러나 딥러닝은 단순히 "목적 성능이 좋은 모델이 좋은 모델이다." 라고 말할 수 있는 기술 수준을 넘어섰습니다. 작은 모델로 더 큰 모델만큼의 성능을 얻을 수 있다면, Computing resource(GPU 와 CPU), Energy(배터리 등), Memory 측면에서 더 효율적이라고 말할 수 있겠죠. 예를 들어서 핸드폰에서 딥러닝을 활용한 어플리케이션을 사용하고 싶은데, 몇 GB 의 메모리를 필요로하는 모델을 사용하려면 온라인 클라우드 서버등에 접속해서 GPU 등의 자원을 사용해야 하지만, 이 모델을 충분히 작게 만들어서 핸드폰의 CPU 만으로도 계산이 가능하다면 여러가지 비용 측면에서 더 좋을 것입니다.
이렇게 Knowledge distillation 은 작은 네트워크도 큰 네트워크와 비슷한 성능을 낼 수 있도록, 학습과정에서 큰 네트워크의 지식을 작은 네트워크에게 전달하여 작은 네트워크의 성능을 높이겠다는 목적을 가지고 있습니다.
Knowledge distillation 방법론
Loss fucntion 과 네트워크 구조를 함께 보면 Knowledge distillation 을 쉽게 이해할 수 있습니다.
이미지 출처: https://nervanasystems.github.io/distiller/knowledge_distillation.html
네트워크 구조를 가장 직관적으로 잘 표현한 그림인 것 같아 가져왔습니다.
Loss function 수식은 "Distilling the Knowledge in a Neural Network" 를 인용한 논문인 "Similiarity preserving knowledge distillation" 에서 가져왔습니다.(SPKD 에 써있는 수식이 더 간결한 것 같아서요.)
α, T 에 대한 설명은 Loss function 의 설명에서 가장 마지막에 하겠습니다.
Loss function 의 왼쪽항은 아래 그림의 파란색 영역에 해당합니다.
이는 Student network 의 분류 성능에 대한 Loss 로, Ground truth 와 Student 의 분류 결과와의 차이를 Cross entropy loss 로 계산하고 있습니다.
Loss function 의 오른쪽 항은 아래 그림의 빨간색 영역에 해당합니다.
오른쪽항은 Teacher network 와 Student network 의 분류결과의 차이를 Loss 에 포함시키는 것입니다. Teacher 와 Student 의 Output logit 을 Softmax 로 변환한 값의 차이를 Cross entropy loss 로 계산하고 있습니다. Teacher 와 Student 의 분류 결과가 같다면 작은 값을 취합니다.
여기서 두 네트워크의 분류 결과를 비교하기 위해서 Hard label 이 아닌 Soft label 을 사용하고 있는데,
곰, 고양이, 개 3가지 클래스를 구분하는 모델이 있을 때, 분류 결과가 왼쪽과 같다면 Hard label, 오른쪽과 같다면 Soft label 이라고 부릅니다. 위 예시의 결과값을 보면 입력 이미지가 고양이었다고 유추할 수 있습니다. Soft label 을 곰곰이 생각해보면 입력 이미지에서 고양이와 개가 함께 가지고 있는 특징들이 어느정도 있었기 때문에 Dog class score 가 0.2 만큼 나왔다고 생각할 수 있습니다. 결과값을 Hard label 로 표현하면 이런 정보가 사라지게 됩니다.
이런 정보의 손실 없이, Teacher network 의 분류 결과를 Student network 의 분류 결과와 비교시켜서, Student network 가 Teacher network 를 모방하도록 학습시킵니다.
Hyperparameter 인 α, T 중 α 는 왼쪽항과 오른쪽항에 대한 가중치입니다. α 가 크면 오른쪽항 Loss 를 더 중요하게 보고 학습하겠다는 의미죠. T 는 Temperature 라고 부르는데 Softmax 함수가 입력값이 큰 것은 아주 크게, 작은 것은 아주 작게 만드는 성질을 완화해줍니다. 기존 Softmax 함수와 Temperature 를 사용한 Softmax 함수는 다음과 같이 표현되는데,
간단하게 예시를 하나 계산해보면
Temperature 를 사용한 경우가 낮은 입력값의 출력을 더 크게 만들어주고 큰 입력값의 출력은 작게 만들어주는 것을 알 수 있습니다. Temperature 를 사용하여, Soft label 을 사용하는 이점을 최대화합니다.
요약 혹은 정리
Knowledge distillation 은 미리 학습시킨 Teacher network 의 출력을 내가 실제로 사용하고자 하는 작은 모델인 Student network 가 모방하여 학습함으로써, 상대적으로 적은 Parameter 를 가지고 있더라도 모델의 성능을 높이는 방법론입니다.