-
R과 Few shot learning : 샴 네트워크R 이모저모 2020. 2. 15. 20:49
이번에 다뤄볼 주제는 Few-shot learning으로, 지금까지 블로그에 올려운 주제에 비해 최신 기법입니다. Few shot learning의 사전적 의미는 클래스별로 아주 적은 데이터만으로도 학습을 하는 모델로, 주로 이미지에서 아주 많은 클래스를 가진 문제여서 풍부한 데이터를 구하기 어려울 때 사용하는 학습 방법입니다. 불행히도 현재까지 나온 기술들로는 아주 제한적인 부분에서 성과를 보여왔으나, 최근 open AI의 Reptile(2018) 등 여러 실용적인 알고리즘이 나오면서 활용해볼만한 수준의 주제로 올라왔습니다. 그래서 이번 글에서는 Few shot learning에 대한 간단한 소개와, 이 학습의 대표적의 예시로 주로 나오는 샴 네트워크(Siamese Network, 2015)에 대해 다뤄볼까 합니다.
1. Few shot learning? Meta learning!
글의 제목이나 개요에선 few shot learning이라고 소개하였는데 갑자기 Meta learning이라는 단어가 나오니 무슨 얘기인가 싶을겁니다. Meta learning은 흔히 '학습을 위한 학습'으로 불리우는 방식으로, 메타러닝이라는 개념이 few shot learning을 포함하는 큰 개념으로 사용되기에 사실상 Meta learning을 알아야 few shot learning을 이해할 수 있습니다.
그렇다면 '학습을 위한 학습'이란건 도대체 뭘까요? 이는 메타 러닝이 메타 데이터, 즉 다른 데이터를 설명해줄수 있는 데이터를 기반으로 하는 학습이기 때문입니다. 그러므로 보통 메타 러닝의 목적은 주어진 데이터를 잘 학습하는 것이 아닌, 얼마나 일반화된 항목을 추출하고 다른 데이터가 추가될 때 배워나갈 수 있는지를 목표로 하게 됩니다. 이러한 메타러닝은 크게 Metric based, Optimization based, Model based 3가지로 분류하게 되는데, 여기서 few shot learning은 Metric based learning에 속하게 됩니다. 이 Metric 기반은 매타 데이터를 대표하는 특정 커널을 생성하고, 새로운 데이터가 들어왔을 때 이 커널을 통해 추정하는 방법으로 통계적 기법 중 하나인 최근접이웃(Nearest-Neighbor)과 밀도추정(Density estimation)을 합쳐놓은듯한 모습을 지니게 됩니다. 이번에 다룰 샴 네트워크가 가장 대표적인 Metric 기반 학습의 예시인 이유는 이미지(메타 데이터)를 커널(CNN)에 통과시켜 대표값을 추출한 후, 두 이미지를 비교하는 거리(유사도)를 계산하여 이미지가 다른지 같은지를 판단하기 때문입니다.
2. Siamese Network
위에 설명한대로, 샴 네트워크는 Metric based meta learning / few shot learning의 대표적인 예로 볼 수 있습니다. 이 네트워크의 자세한 구조는 이미지와 함께 설명하겠습니다.
샴 네트워크는 기존의 분류 네트워크들과는 사뭇 다른 구조를 사용합니다. 먼저 입력값을 받는 레이어가 두 개로, 두 레이어는 각각 다른 이미지를 받게 됩니다. 이 레이어를 Convolutional layer들로 학습시켜 이미지 속성을 나타내는 벡터로 변환하고, 두 레이어에서 나온 벡터들의 l1 norm, 즉 절대거리차로 두 입력값의 차이를 계산합니다. 그리고 이 거리값을 가지고 일반적인 이진 분류모델에서 하듯 sigmoid를 통과해서 0/1로 분류를 하게되며, 이때 1은 입력된 두 이미지가 같은 클래스에 들어가있음을, 0은 다른 클래스에 들어감을 의미합니다. 이렇게 학습된 컨볼루션 레이어들, 즉 feature extraction model들은 클래스에 대한 특징을 배우는 것보단 두 이미지가 얼마나 '다른'지, 즉 이미지를 비교할 수 있는 능력을 가지게 됩니다. 그리고 새로운 데이터가 왔을때, 학습에 사용된 데이터셋(Support Set)들과 얼마나 다른지를 비교하여 가장 비슷한 이미지와 같은 클래스로 배분하게 됩니다. 이렇게 모델을 설계함으로서 보다 적은 데이터로도 학습을 할 수 있게 됩니다.
3. R에서의 구현
저는 R에서의 딥러닝 구현을 케라스에 의존하고 있으며, 이번 tensorflow 2.0 이후로 케라스가 공식 api가 되었기에 케라스는 분석가에게 더욱 중요한 api로 자리잡았습니다. 그래서 이번 예제도 케라스를 통해 구현했으며, 혹시 케라스 설치 및 R 내부에서의 연동을 모르시는 분은 구글링을 해서 확인한 후 봐주시면 감사하겠습니다.
3.1 훈련용 데이터 생성
데이터는 mnist 데이터를 사용할 것이며, 당연히 few shot learning인데 모든 데이터를 사용하면 안되므로 일부 데이터만 뽑아서 학습하겠습니다. 이를 위해 먼저 keras 패키지의 dataet_mnist() 함수를 통해 mnist 데이터를 불러옵니다.
이제 few shot을 위해 클래스마다 샘플링을 해줍니다. 물론 완벽하고 숫자를 제일 잘 설명해주는 데이터만 뽑아서 학습해주는 것이 제일 성과가 좋겠지만 저는 그냥 무작위로 3개씩 뽑겠습니다.
(6만,28,28) 크기의 학습 이미지 중 총 30개를 뽑았습니다. 일반적인 분류 모델이라면 그냥 이대로 이미지를 넣으면 되지만, 샴 네트워크는 이미지 입력 레이어가 2개이고 서로를 비교해야하기 때문에, expand.grid함수를 활용하여 이미지끼리 짝을 지어줍니다.
여기서 x1은 첫번째 인풋 레이어에 들어갈 데이터, x2는 두번째에 들어갈 데이터, class_x1,2는 각각의 클래스를 나타내며 y는 둘의 클래스가 같으면1, 다르면 0을 가집니다. 그러나 이렇게 만든 데이터는 (a,b), (b,a)라는 순서가 다르지만 중복되는 쌍을 가지므로, 아래와 같이 중복 제거를 통해 제거해줍니다.
이렇게 완성된 그리드를 이용해서 각각 해당하는 이미지를 불러와서 쌍을 이뤄줘야 합니다. 데이터를 순서대로 정렬해서 만들 수도 있지만 너무 많은 데이터를 중복해서 쌓아야 하기 때문에, 배치에 쓰일 데이터를 생성하는 함수를 만들어 keras의 fit_generator를 활용합니다.
위의 batch_generator함수의 결과물은 리스트로, X에 집어넣을 두 개의 인풋 레이어에 들어갈 데이터들을 묶은 리스트와, y에 들어갈 매트릭스(0/1) 을 만듭니다.
3.2 샴 네트워크 케라스 모델 생성
보통 케라스에서 간단한 모델을 생성할 경우 keras의 순차적 모델 생성방법인 keras_model_sequential을 dplyr chain(%>%)을 통해 쌓아서 생성하지만, 두 개의 인풋 레이어를 가지고 하나의 feature model을 공유하는 방식이기 때문에 비효율적이고 만들기도 어렵습니다. 따라서 예제에서는 인풋 레이어를 두 개 따로 생성하고, 공동으로 사용할 feature model(CNN)을 만들어서 입력값 레이어를 인코딩 한 다음 절대값 거리를 계산하는 모듈과 합쳐서 하나의 모델로 생성할 것입니다.
입력 레이어 정의입니다. input_dim은 이미지를 2d CNN에 통과시킬 것이기 때문에 (x픽셀수,y픽셀수,채널수)를 입력합니다. 여기서는 흑백처리된 28*28 데이터를 사용하기에 위와 같이 (,28,28,1)이 나오게 됩니다.
이미지를 학습하기 위한 모델은 다음과 같이 구성되어 있습니다. 2개의 2차원 컨볼루션 레이어를 통과하며, 마지막에는 896개의 벡터값으로 변환된 이미지 feature를 생성하게 됩니다.
이제 x1, x2에서 나오는 결과값들을 서로 빼주고 절대값을 계산해야 합니다. 즉 입력값 x1, x2가 feature model을 통과한 결과 feature_model(x1), feature_model(x2)를 서로 빼주고 절대값을 계산하는 레이어 l1_distance를 만들게 되며, 이는 keras 백엔드를 R에서 부르는 k_abs 함수로 처리합니다.
이제 keras_model 함수를 이용해 input, output을 지정해줍니다. 현재 처음 나왔단 샴 네트워크 구조 중 마지막 부분인 이진분류(이미지가 같은 클래스인지 다른 클래스인지) 부분이 없는데, 이는 layer_dense()함수를 이용해서 keras_model의 outputs 부분에 구현해줍니다.
이제 model, 혹은 summary(model)을 하면 모델 구조 요약값을 볼 수 있는데, 이 때 feature_model 부분은 Sequential이라고 처리됩니다. 이를 자세히 보고 싶다면 model$layers로 세부 레이어를 불러서 그 중 3번째 부분을 확인할 수 있습니다. 이진분류 문제이므로 loss는 binary_crossentropy를 사용하였으며, 정확도로 metric을 잡았습니다.
*optimizer는 adam을 하는것이 일반적이나, 저는 nadam이 좀 더 잘 나오는 것 같아서 선택했습니다.
3.3 학습
모델도 만들었고, 학습하기 위한 데이터 생성기도 만들었으니 이제 학습할 시간입니다. 일반적으론 학습에 fit 함수를 사용하나 이번엔 데이터 생성기를 만들었으므로 fit_generator 함수를 이용합니다.(*fit_generator는 훈련 데이터가 커서 큰 메모리를 차지하게 될 때 유용합니다) R의 경우 keras 패키지가 자동으로 학습 경과를 시각화하여, 아래와 두번째 그림과 같은 훈련과정 그래프를 얻을 수 있습니다.
그래프에서 보다시피 결과가 매우 좋게 나오는 것을 볼 수 있습니다. 물론 이는 mnist 자체가 워낙 정형화된 데이터이고, 적은 수의 샘플만 사용해서 입니다. 하지만 모든 학습의 목표, 특히 이러한 few shot learning은 training간 결과가 잘 나오는게 중요한 것이 아닙니다. 테스트 데이터를 통해서 얼마나 meta data를 잘 흡수 하였나를 봐야겠죠.
3.4 test 데이터로 평가
테스트에는 mnist의 테스트 데이터 중 각 클래스마다 500개를 뽑아서 진행하였습니다. 샘플링에 따라 다소 차이가 있으나, 보통 60 ~ 70%의 정확도를 보입니다.
이렇게 나오는 경우는 보통 train 데이터가 meta 특성을 잘 반영하지 못하는, 즉 글씨체가 특이한 경우가 많습니다. 이러한 부분을 보완하기 위한 다양한 시도들이 있으나, 아직 완벽하게 작동하는 알고리즘은 찾지 못했습니다. 물론, 논문에선 상대적으로 고화질이고 다양한 글씨체/문자를 포함한 Omniglot 데이터를 사용해서 꽤 괜찮은 성과를 냈었으니 튜닝을 하면 가능할지도 모르겠습니다.
이후 예제코드에 있는 5,6,7은 feature model의 레이어 시각화, 모델 시각화, 결과 저장인데 한번 따라해보시길 바랍니다.
4. 마치며
이번 주제는 구현 난이도도 있고, 메타러닝 및 few shot learning의 역사가 길어 글을 쓰기가 어려웠던 것 같습니다. 그래도 만들고 쓰다보니 재미도 있어서, 다음엔 좀 더 challenging하고 재밌는 글을 써봐야겠단 생각도 드네요.
예제코드 : https://github.com/JunmoNam/applebox_blog/blob/master/R/18)%20siamese%20net.R
'R 이모저모' 카테고리의 다른 글
R과 Leaflet (1) (1) 2020.11.14 R과 Google Map (0) 2020.06.13 sf : R과 지도(2) (0) 2020.01.03 sf : R과 지도 (1) (0) 2019.12.06 R과 Data Wrangling (0) 2019.08.19