데이터 분석/딥러닝

[SeNet] 논문 리뷰 & 구현 (Pytorch)

개발자 소신 2020. 12. 22. 14:45
반응형

안녕하세요 ! 소신입니다.

 

 

2017년 마지막 ILSVRC대회에서 1등을 한 SeNet 모델입니다.

사실 모델이라고 하기는 조금? 애매하고 새로운 Block이라고 할 수 있습니다.

 

최근 연구에서 네트워크 내부 피쳐간의 공간적 상관성을 학습하는 메커니즘으로 CNN이 좀 더 강력하게 될 수 있다는 것을 확인했습니다.

따라서 이러한 연구에 맞춰 CNN 내부 Feature들의 중요도를 학습하는 SeNet이 탄생하게 된 것이죠

SeNet을 활용하면, 적은 Complexity의 증가로 Performance의 증가를 가져올 수 있습니다.

 

 


# SeNet의 원리

SeNet은 위와같이, WxH의 Feature Map을 Squeeze (압축)하는 첫 번째 단계가 있습니다.

논문상에선 편리상 Global Average Pooling을 사용했지만, 압축하는 방법은 정해져있지 않습니다.

HxWxC의 Feature Map을 GAP를 거쳐 1x1xC의 형태로 압축합니다.

 

 

그 뒤, Excitation을 거쳐, Channel 사이의 Feature들에 대한 중요도를 Recalibaration(재조정)하게 됩니다.

(색깔은 각 Feature의 0~1까지의 중요도를 나타냄)

 

이 재조정된 중요도를 담은 Feature를 원래 Feature Map에 곱해주어 중요도가 학습된 새로운 Feature Map이 탄생하게 되는 것이죠.

 

이렇게 Feature Map을 압축 - 재조정하는 과정이 SeNet의 핵심입니다.

 


# SeNet 구조

Squeeze는 pooling이 전부이지만, Excitation은 다릅니다.

추출된 1x1xC의 피처맵을 FC Layer를 거쳐 C/r (r = reduction ratio) 만큼 1차적으로 차원을 축소해줍니다.

ReLU를 거쳐 FC Layer에서 C만큼 다시 차원을 늘립니다.

이렇게 추출된 중요도 Map을 원래 Feature Map과 곱해 Feature Map을 재조정 해줍니다.

 

Excitation은 C에서 C/r로 축소된다음, 다시 C가 되는 과정이라고 할 수 있습니다.

 


# 성능

출처 : https://bskyvision.com/425

성능은 위에서 보시면 알 수 있듯이, 2017년 이미지분류대회에서 2.3%의 분류오차를 내면서 1등을 가져가게 됩니다.

이후로는 대회가 열리지 않았습니다. 이정도면 할만큼 했다는 걸까요


# Pytorch 구현

from torch import nn

class SEBlock(nn.Module):
    def __init__(self, c, r=16):
        super(SEBlock, self).__init__()
        self.squeeze = nn.AdaptiveAvgPool2d(1)
        self.excitation = nn.Sequential(
            nn.Linear(c, c // r, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(c // r, c, bias=False),
            nn.Sigmoid()
        )

    def forward(self, input_filter):
        batch, channel, _, _ = input_filter.size()
        se = self.squeeze(input_filter).view(batch, channel)
        se = self.excitation(se).view(batch, channel, 1, 1)
        return input_filter * se.expand_as(input_filter)

SEBlock 코드입니다. 

c는 filter의 channel 크기이고, r은 channel을 얼마나 축소시켰다가 다시 늘릴지를 나타내는 reduction ratio입니다.

 

AdaptiveAvgPool을 사용해서 이미지 사이즈가 어떻게 들어오든 1x1의 정방형 필터로 바꿔주고,

excitation에서 FC Layer를 거쳐 중요도 map을 추출합니다.

forward에서 input filter에 계산된 se map을 곱해줘 가중치가 조정된 filter가 최종적으로 나오게 됩니다.

 

이 블록을 layer에 넣으면 끝입니다.


# 결론

SE Block을 MobileNet에 추가해서 학습을 진행했습니다.

총 소요시간은 65분으로 기존 MobileNet이 61분소요된것에 비해 4분정도가 더 소요되었습니다

(Train 진행하면서 validation도 동시 진행)

그럼에도 불구하고 76.74%로 기존 MobileNet의 70.13%보다 정확도 6.61%의 향상을 보였습니다.

 

논문을 보다보면, Multi-Crop과 Multi-Scaling에 대한 내용이 많이 나왔습니다.

SeNet의 성능도 물론 궁금했지만, 다양한 앙상블 전략을 직접 구현 해보고싶었습니다.

 

Input이 하나일 때, 여러 Augmentation을 적용해 여러개의 Input으로 나눈 뒤,

예측된 결과들을 통합해 예측하는 것을 Test-Time-Augmentation(TTA) 이라고합니다.

Pytorch에도 구현이 되어있어, 직접 한번 시도해봤습니다.

기존 76.74%

 

Five Crop(5) + Horizontal Flip(2) = 10

앙상블 정확도 : 80.25%

앙상블만으로 정확도를 3.51% 향상시킬 수 있었습니다.

물론 validating 시간이 기존에 비해 10배 정도 늘어난다는 단점이 있지만,

정확도가 중요하다면, 앙상블은 필수불가결한 선택인 것 같습니다.

 


Ref.

SeNet Article

SeNet 설명

SeNet Pytorch

TTA github

반응형