Implementation : https://github.com/KaiyangZhou/CoOp
오늘 소개드릴 논문은 CoOp: Learning to Prompt for Vision-Language Models
입니다. CoOp은 기존 Vision-Language Model의 prompt engineering을 학습 기반으로 효과적으로 풀어보고자 제안된 방법론입니다. IJCV 2022 accepted paper입니다.
1. AbstractPermalink
기존 vision-language model의 major한 challenge는 prompt engineering이었습니다. 어떠한 prompt를 사용하느냐에 따라 모델 performance에 큰 영향이 있을 수 있으며, 적절한 prompt를 고르는 일은 많은 시간과 사전지식을 필요로 합니다. CoOp은, 이러한 문제점을 해결하기 위해 NLP 분야의 prompt engineering 연구에서 영감을 받아 시작하였으며, image recognition task를 위해 총 2가지 버전의 context를 구현하였습니다. 또한, CoOp은 총 11개의 dataset에 대해, 1~2 shot만으로도 기존 hand-craft prompt의 성능을 크게 앞질렀다고 소개하고 있습니다.
2. IntroductionPermalink
기존 CV분야의 SOTA model들은 discrete label을 사용하여 supervised laerning을 진행하였습니다. 하지만 이러한 학습 패러다임은, 모델의 퍼포먼스를 제한하기 때문에 zero-shot prediction이 불가능하다는 한계가 존재합니다. 이러한 문제점을 해결하고자 등장한 대안이 바로 CLIP으로 대표되는 vision-language model입니다. 이러한 vision-language모델이 기존 학습 패러다임을 대체할 promising한 대안인 것은 맞으나, prompt engineering에 의해 모델의 performance가 크게 변한다는 문제점을 가지고 있습니다. 이렇듯 non-trivial한 task인 prompt engineering을, 학습 기반으로 automatic하게 만들고자 하여 시작된 연구가 바로 CoOp입니다.
CoOp을 쉽게 표현하자면, CLIP의 prompt 부분을 fine-tuning하는 것이라고 할 수 있습니다. 전체적인 CoOp의 학습 과정은 위 figure에 잘 표현되어 있습니다. CLIP을 그대로 사용하되, 모든 파라미터를 freeze한 채 prompt 부분의 context만 learnable한 vector로 두고, GT와 cross-entropy loss를 걸어 학습을 진행합니다. 여기서 주목해야할 부분은 context 부분인데, CoOp은 다양한 recognition task를 다루기 위해 2가지의 context를 구현하였습니다. 그 중 하나는 unified context로, 모든 class가 동일한 context를 가지며, 나머지 하나는 class-specific context로, class마다 context를 다르게 가져갑니다.
CoOp은 총 11개의 dataset에 대해 실험을 진행하였으며, 실험 결과 단지 1~2 shot만으로도 hand-crafted prompt의 성능을 높은 마진으로 능가하는 것을 확인할 수 있었다고 합니다. 또한, 더 많은 16 shot에 대해서는 최대 45%까지도 높은 퍼포먼스를 보여주었다고 하며, learning-base approach임에도 불구하고 domain shift에 더 robust한 모습을 확인할 수 있었다고 합니다.
해당 연구의 contribution은 다음과 같습니다.
-
vision-language model의 downstream application에 대한 시기적절한 연구 제안 & prompt engineering의 중요성 식별
-
prompt engineering을 자동화하기 위해 continuous prompt learning에 기반한 간단한 방법론 제시 & 다양한 recognition task를 다룰 수 있는 2개의 context 구현
-
prompt engineering 분야에서, hand-crafted prompt와 linear probe model을 transfer learning performance와 robustness 측면에서 능가한 최초의 learning-base approach
3. MethodPermalink
3-1. Unified ContextPermalink
앞서 잠시 설명드렸듯이, CoOp은 2가지의 context를 구현하였습니다. 그 중, unified context에 대해 먼저 소개드리도록 하겠습니다. unified context는 모든 class가 동일한 context를 갖습니다. 이 때, prompt t는 다음과 같이 표현될 수 있습니다.
이 때, [V]m (m∈1,…,M)은 word embedding과 같은 dimension을 갖는 벡터이며(CLIP의 경우 512), M은 context token의 개수를 결정하는 하이퍼 파라미터입니다. 이러한 prompt t를 text encoder g(⋅)의 input으로 넣어줌으로써, classificaiton weight vector를 얻을 수 있으며, 이를 통해 다음과 같이 prediction probability를 계산할 수 있습니다.
probability를 계산하는 수식은 CLIP의 것과 동일하며, 간단합니다. 각각 K는 class 수, f는 image feature, cos는 cosine sim 계산 함수를 의미하며 τ는 학습 가능한 temperature params를 의미합니다. 여기서 temperature parameter란, 분포의 skewness를 조절하는 파라미터라고 이해하시면 됩니다. 수식 자체는 간단하기 때문에 쉽게 이해하실 수 있을거라 생각됩니다.
또는 위 식에 대한 변형으로 class를 다음과 같이 중간에 위치시킬 수도 있습니다.
이 경우 학습에 대한 flexibility를 증가시킬 수 있으며, 경우에 따라 class 이후 cell에 보충 설명을 적을 수도, 마침표를 찍어 문장을 끝낼 수도 있다고 합니다.
3-2. Class-Specific ContextPermalink
CoOp의 또다른 옵션으로는 Class-Specific Context(CSC)가 있습니다. CSC는 unified context와는 다르게 각 class마다 independent한 context를 사용하는데, 이를 수식으로 나타내면 다음과 같습니다.
이러한 CSC는 Fine-grained classification task에서 유용하게 사용될 수 있다고 합니다. (e.g. 개 품종 구분 문제)
4. ExperimentsPermalink
CLIP model에 CoOp을 사용하여 prompt engineering을 하였을 때, 위 figure와 같이 총 11개의 데이터셋에 대해서 뛰어난 성능을 보이는 것을 확인할 수 있었습니다. 오직 2 shot만으로도 zero-shot CLIP의 성능을 높은 margin으로 능가할 수 있었으며, 16 shot에 대해서는 zero-shot CLIP 대비 평균 15% 정도의 성능 향상이 있었다고 합니다.
위 figure는 16 shot을 사용하여 학습한 CoOp이 hand-craft prompt에 비해 어느 정도의 성능 향상이 있었는지를 백분율로 나타낸 그래프입니다. EuroSAT이나 Flowers102와 같이 specialize된 task를 갖는 데이터셋에 대해서는 각각 45%, 28%의 성능 향상이 있었으며, 대부분의 fine-grained dataset에서 모두 괄목할 만한 성능 향상이 있는 것을 확인할 수 있습니다. ImageNet의 경우, 수치적으로 큰 improvement는 없는 것처럼 보입니다. 하지만, 이는 ImageNet의 class가 매우 많기 때문이며(1,000개), 이를 감안한다면 4.77%의 성능 향상 역시 주목할 만한 개선입니다.
다만, OxfordPets와 Food101 데이터셋에 대해서는 zero-shot CLIP과 비슷하거나 더 좋지 않은 결과를 보이고 있는데, 저자는 이러한 현상의 원인이 over fitting에 있을 것이라고 추측하고 있습니다. 이러한 현상을 해결하기 위해 차후 weight decay를 더 증가시키는 등의 regularization을 적용할 예정이라고 합니다.
distribution shift에도 CLIP+CoOp 조합이 기존 연구대비 더 robust하다고 합니다.
context의 길이를 길게 가져가는 것이 더 나은 performance를 보여주긴 하지만, 짧은 context length를 사용했을 경우보다 robustness가 더 떨어진다고 합니다. 즉, 모델의 performance와 robustness간에는 trade-off 관계가 존재한다고 하며, 둘 사이의 적절한 밸런스를 찾는 것이 중요하다고 말하고 있습니다.
댓글남기기