Notice
Recent Posts
Recent Comments
Link
관리 메뉴

one by one ◼◻◼◻

[NLP 논문리뷰]STraTA: Self-Training with Task Augmentation for Better Few-shot Learning 본문

논문리뷰

[NLP 논문리뷰]STraTA: Self-Training with Task Augmentation for Better Few-shot Learning

JihyunLee 2021. 9. 23. 02:20

제목 : HowMuchKnowledge Can You Pack Into the Parameters of a Language Model?

저자 : Tu Vu, Minh-Thang Luong, Quoc V. Le, Grady Simon, Mohit Iyyer, Noam Shazeer

발행년도 : 2021

paper : https://arxiv.org/abs/2109.06270

code : https://github.com/google-research/google-research/tree/master/TA-ST

 

Review

 

Abstract

최근의 발전이 NLP task에서 많은 발전을 이뤘지만, Large scale의 pretrained 언어 모델이 few shot 세팅에서는 그리 좋은 성능을 보이지 못하고 있다. 이 문제를 해결하기 위해 저자들은 Self-Trainning방법과 Task Augmentation방법을 활용한 논문을 냈다.

 

BERT는 few shot learning에서 성능이 떨어지는것이 밝혀져 있고, GPT-3 는 few example만 가지고 fine tunning없이 문제를 풀 수 있는 능력을 보여주었지만, state of art의 성능보다 훨씬 미치지 못하고 있다. 이 문제는 새롭게 데이터를 만들면 해결되겠지만, 데이터의 수를 manually 늘리는 방법은 돈과 시간이 많이 드는 일이기 때문에, machine learning을 이용한 text augmentation 방법은 데이터가 부족한 상황에서 도움이 될 수 있다. 그래서 논문에서는 Task Augmentation을 self trainning앞의 auxilary(중간) task로 두었다.  NLI data generator 모델은 MNLI데이터를 통해 학습시켰다. 

모델의 전체적인 구조

그뒤, STraTA는 NLI data generator로 생성된 데이터로부터 학습된 pre-trained language model을 self-trainning 의 base model으로 사용한다. 각각의 iteration에서 base model은 사용가능한 labeled데이터를 통해 학습된다. base model에서 생성된 prediction의 결과는 기존에 있던 labeled data에 추가되어 student model의 학습데이터로 사용된다. student model은 다시 teacher model이 되어 prediction을 진행한다. 이 과정이 반복되어 모델의 성능이 어느정도 안정화 되었다면, 학습을 멈추게 된다.

이 논문에서는 base model이 좋은 성능을 가지고 있는것이 self trainning에 중요한 영향을 미쳤다고 한다.

 

논문의 Main contribution은 

1. fine tunning 모델을 이용한 task augmentation방법 제시

2. 간단하지만 효과적인 self trainning방법을 제시하고, 중요한 요소들을 분석

3. STraTA를 통해 task augmentation과 self trainning의 효과성을 다양한 NLP benchmarks를 통해 증명

이다.

 

2. Task Augmentation

Task augmentation방식이 NLP에서 최근 많이 사용되고 있었으나, 대부분은 MNLI나 SQuAD를 사용해서 중간단계의 fine tunning을 거쳤다. 그러나 이런 방식은 target task와 auxiliary task가 일치하지 않는다는 명확한 문제점을 가지고 있었다. 이 논문에서는 이러한 문제를 해결하기 위해 먼저 general language model(T5-3B)을 MNLI로 fine tunning시킨 뒤, fine tunning 된 모델로 in-domain(target data와 동일한 형식)의 데이터를 생성했다. 생성된 In-domain 데이터를 general language model에 학습시켜 Auxiliary task model을 생성하는 방식을 택해, In-domain data로 학습된 auxiliary task model을 만들었다.

 

MNLI 데이터를 학습할때는 (sent1, sent2)-> label의 형식이 아니라, (label, sent1)->sent2와 같이 text to text format으로 변형시켜 사용하였다.

데이터 변형 예시

Self-training

Slef trainning 알고리즘

STraTA가 사용한 self trainning의 방식은 1. teacher 모델이 unlabeled data에 대해서 predict를 하게 하고, 2. predict 한 데이터의 라벨과(pseudo-label), 기존의 라벨 된 데이터를 합치고, 3. student 모델이 teacher model이 labeling한 데이터에 대해 학습한 다음 그 모델이 다시 teacher가 되어 predict를 진행하게 되는 과정을 거치는 학습과정을 진행한다.

여기서 STraTA는 기존 self trainning방식과는 다르게 가장 처음의 모델의 모델이 auxilary data로 학습되어 있기 때문에더 좋은 성능을 가지고 있다.

 

이러한 self trainning방식에서 중요한 것은 pseudo-label중 적절한 것을 고르는 문제이다. 전통적인 방식에서는 특정한 확률값을 정해서 teacher모델이 자신있게 선택한 psudo-label만을 이용해서 student model을 학습시켰지만, 이러한 방식은 BERT와 같은 state of art의 laguage 모델이 정답이 틀렸음에도 over confident하게 정답을 찍는다는 문제가 있어 적절하지 않았다. 따라서 논문에서는 다양한 pseudo-label 선택 방법들을 실험 해 본 결과 모든 psudo-label결과를 사용하는것이 더 낫다는 결론을 얻었다.

왼쪽은 pseudo data를 전부 사용하지 않았을때, 오른쪽은 전부 사용한 경우(논문에서 제안)

Experiment

TA 또는 STraTA가 논문에서 제안한 방식인데 대부분의 실험에서 가장 좋은 성능을 거둔것을 볼 수있다.

이 말고도 정말 다양한 실험을 했는데 자세한 것은 논문을 참고하면 좋을 듯!

 

 

Comments