논문링크 : CoOp: Learning to Prompt for Vision-Language Models

Implementation : https://github.com/KaiyangZhou/CoOp

    오늘 소개드릴 논문은 CoOp: Learning to Prompt for Vision-Language Models입니다. CoOp은 기존 Vision-Language Model의 prompt engineering을 학습 기반으로 효과적으로 풀어보고자 제안된 방법론입니다. IJCV 2022 accepted paper입니다.



1. Abstract

    기존 vision-language model의 major한 challenge는 prompt engineering이었습니다. 어떠한 prompt를 사용하느냐에 따라 모델 performance에 큰 영향이 있을 수 있으며, 적절한 prompt를 고르는 일은 많은 시간과 사전지식을 필요로 합니다. CoOp은, 이러한 문제점을 해결하기 위해 NLP 분야의 prompt engineering 연구에서 영감을 받아 시작하였으며, image recognition task를 위해 총 2가지 버전의 context를 구현하였습니다. 또한, CoOp은 총 11개의 dataset에 대해, 1~2 shot만으로도 기존 hand-craft prompt의 성능을 크게 앞질렀다고 소개하고 있습니다.



2. Introduction

    기존 CV분야의 SOTA model들은 discrete label을 사용하여 supervised laerning을 진행하였습니다. 하지만 이러한 학습 패러다임은, 모델의 퍼포먼스를 제한하기 때문에 zero-shot prediction이 불가능하다는 한계가 존재합니다. 이러한 문제점을 해결하고자 등장한 대안이 바로 CLIP으로 대표되는 vision-language model입니다. 이러한 vision-language모델이 기존 학습 패러다임을 대체할 promising한 대안인 것은 맞으나, prompt engineering에 의해 모델의 performance가 크게 변한다는 문제점을 가지고 있습니다. 이렇듯 non-trivial한 task인 prompt engineering을, 학습 기반으로 automatic하게 만들고자 하여 시작된 연구가 바로 CoOp입니다.


    CoOp을 쉽게 표현하자면, CLIP의 prompt 부분을 fine-tuning하는 것이라고 할 수 있습니다. 전체적인 CoOp의 학습 과정은 위 figure에 잘 표현되어 있습니다. CLIP을 그대로 사용하되, 모든 파라미터를 freeze한 채 prompt 부분의 context만 learnable한 vector로 두고, GT와 cross-entropy loss를 걸어 학습을 진행합니다. 여기서 주목해야할 부분은 context 부분인데, CoOp은 다양한 recognition task를 다루기 위해 2가지의 context를 구현하였습니다. 그 중 하나는 unified context로, 모든 class가 동일한 context를 가지며, 나머지 하나는 class-specific context로, class마다 context를 다르게 가져갑니다.


    CoOp은 총 11개의 dataset에 대해 실험을 진행하였으며, 실험 결과 단지 1~2 shot만으로도 hand-crafted prompt의 성능을 높은 마진으로 능가하는 것을 확인할 수 있었다고 합니다. 또한, 더 많은 16 shot에 대해서는 최대 45%까지도 높은 퍼포먼스를 보여주었다고 하며, learning-base approach임에도 불구하고 domain shift에 더 robust한 모습을 확인할 수 있었다고 합니다.


    해당 연구의 contribution은 다음과 같습니다.

  • vision-language model의 downstream application에 대한 시기적절한 연구 제안 & prompt engineering의 중요성 식별

  • prompt engineering을 자동화하기 위해 continuous prompt learning에 기반한 간단한 방법론 제시 & 다양한 recognition task를 다룰 수 있는 2개의 context 구현

  • prompt engineering 분야에서, hand-crafted prompt와 linear probe model을 transfer learning performance와 robustness 측면에서 능가한 최초의 learning-base approach



3. Method

3-1. Unified Context

    앞서 잠시 설명드렸듯이, CoOp은 2가지의 context를 구현하였습니다. 그 중, unified context에 대해 먼저 소개드리도록 하겠습니다. unified context는 모든 class가 동일한 context를 갖습니다. 이 때, prompt $t$는 다음과 같이 표현될 수 있습니다.


\[t = [V]_1 [V]_2 ... [V]_M [CLASS]\]

    이 때, $[V]_m$ $(m \in {1, …, M})$은 word embedding과 같은 dimension을 갖는 벡터이며(CLIP의 경우 512), $M$은 context token의 개수를 결정하는 하이퍼 파라미터입니다. 이러한 prompt $t$를 text encoder $g(\cdot)$의 input으로 넣어줌으로써, classificaiton weight vector를 얻을 수 있으며, 이를 통해 다음과 같이 prediction probability를 계산할 수 있습니다.


\[p(y=i | x) = \frac{exp(cos(g(t_i), f) / \tau)} {\sum^{K}_{j=1} exp( cos(g(t_j), f) / \tau ) }\] \[\begin{matrix} K &:& \mathrm{the \; number \; of \; classes} \\ f &:& \mathrm{the \; image \; feature \; obtained \; by \; image \; encoder} \\ cos(,) &:& \mathrm{cosine \; similarity} \\ \tau &:& \mathrm{temperature \; parameter} \end{matrix}\]

    probability를 계산하는 수식은 CLIP의 것과 동일하며, 간단합니다. 각각 $K$는 class 수, $f$는 image feature, $cos$는 cosine sim 계산 함수를 의미하며 $\tau$는 학습 가능한 temperature params를 의미합니다. 여기서 temperature parameter란, 분포의 skewness를 조절하는 파라미터라고 이해하시면 됩니다. 수식 자체는 간단하기 때문에 쉽게 이해하실 수 있을거라 생각됩니다.

    또는 위 식에 대한 변형으로 class를 다음과 같이 중간에 위치시킬 수도 있습니다.


\[t = [V]_1 ... [V]_{M\over2} [CLASS] [V]_{\frac{M}{2} + 1} ... [V]_M\]

    이 경우 학습에 대한 flexibility를 증가시킬 수 있으며, 경우에 따라 class 이후 cell에 보충 설명을 적을 수도, 마침표를 찍어 문장을 끝낼 수도 있다고 합니다.



3-2. Class-Specific Context

    CoOp의 또다른 옵션으로는 Class-Specific Context(CSC)가 있습니다. CSC는 unified context와는 다르게 각 class마다 independent한 context를 사용하는데, 이를 수식으로 나타내면 다음과 같습니다.


\[[V]^i_1 [V]^i_2 ... [V]^i_M \; \ne \; [V]^j_1 [V]^j_2 ... [V]^j_M, \\ (i \ne j \; \; and \; \; i,j \in \{1, ..., K\})\]

    이러한 CSC는 Fine-grained classification task에서 유용하게 사용될 수 있다고 합니다. (e.g. 개 품종 구분 문제)



4. Experiments

    CLIP model에 CoOp을 사용하여 prompt engineering을 하였을 때, 위 figure와 같이 총 11개의 데이터셋에 대해서 뛰어난 성능을 보이는 것을 확인할 수 있었습니다. 오직 2 shot만으로도 zero-shot CLIP의 성능을 높은 margin으로 능가할 수 있었으며, 16 shot에 대해서는 zero-shot CLIP 대비 평균 15% 정도의 성능 향상이 있었다고 합니다.


    위 figure는 16 shot을 사용하여 학습한 CoOp이 hand-craft prompt에 비해 어느 정도의 성능 향상이 있었는지를 백분율로 나타낸 그래프입니다. EuroSAT이나 Flowers102와 같이 specialize된 task를 갖는 데이터셋에 대해서는 각각 45%, 28%의 성능 향상이 있었으며, 대부분의 fine-grained dataset에서 모두 괄목할 만한 성능 향상이 있는 것을 확인할 수 있습니다. ImageNet의 경우, 수치적으로 큰 improvement는 없는 것처럼 보입니다. 하지만, 이는 ImageNet의 class가 매우 많기 때문이며(1,000개), 이를 감안한다면 4.77%의 성능 향상 역시 주목할 만한 개선입니다.

    다만, OxfordPets와 Food101 데이터셋에 대해서는 zero-shot CLIP과 비슷하거나 더 좋지 않은 결과를 보이고 있는데, 저자는 이러한 현상의 원인이 over fitting에 있을 것이라고 추측하고 있습니다. 이러한 현상을 해결하기 위해 차후 weight decay를 더 증가시키는 등의 regularization을 적용할 예정이라고 합니다.


    distribution shift에도 CLIP+CoOp 조합이 기존 연구대비 더 robust하다고 합니다.


    context의 길이를 길게 가져가는 것이 더 나은 performance를 보여주긴 하지만, 짧은 context length를 사용했을 경우보다 robustness가 더 떨어진다고 합니다. 즉, 모델의 performance와 robustness간에는 trade-off 관계가 존재한다고 하며, 둘 사이의 적절한 밸런스를 찾는 것이 중요하다고 말하고 있습니다.

댓글남기기