코딩 및 기타/이미지

deeplab v2

정지홍 2025. 2. 1. 00:08

deeplab v2

  • 딥러닝 기반의 semantic segmentation 모델이다.
  • image내의 각 픽셀을 특정 클래스에 할당하는 작업에서 뛰어남.
  • 주요 구성 요소
  • 구조
    • 1. backbone network : ResNet-101(Residual Network-101)
      • ResNet-101의 마지막 몇 층에 atrous convolution을 추가해서 수용 영역을 확장함.
        • 이 과정에서 stride를 줄이고 atrous convolution으로 이미지 해상도를 유지함
    • 2. ASPP 모듈
      • network의 마지막 부분에 병렬로 atrous convolution 블록을 추가하여 다양한 스케일의 특징을 추출
    • 3. CRF 후처리
      • segmentation 결과를 개선하기 위해서 CRF를 사용함. 이로 인해 픽셀간의 경계를 더 세밀하게 만듬
  • 한계점
    • CRF의 계산비용으로 인해서 속도가 저하
    • 복잡한 객체의 경계 처리나 고밀도 객체를 분리하는데는 어렵다.

 

 

deeplab v2의 동작 과정

  • 1. 입력을 받고 backbone network에 적합한 형식으로 정규화 한다.
  • 2. ResNet-101을 사용하여 특징을 수출
    • 여기에서 ResNet-101의 마지막 몇개의 layer에 atrous convolution을 적용해서 receptive filed를 확장
      • 이렇게 하면 다운 샘플링을 최소화하며, 더 많은 문맥 정보를 가진 feature map을 얻을수있다.
  • 3. ASPP를 사용한다.
    • 여러 dilation rate로 처리한 결과를 병합하여 하나의 feature map을 얻는다.
  • 4. 최종 feature map 처리
    • ASPP로 생성된 feature map은 1x1 convolution으로 채널 수 를 축소한다.
      • 이렇게 하여, 각 픽셀에 대한 클래스 점수를 계산하기 위한 준비를 함
  • 5. upsampling을 사용하여 feature map을 입력 image와 동일한 크기로 만들어준다. ( 즉 ,복원 )
    • 주로 Bilinear upsampling을 사용
  • 6. CRF를 사용하여 후처리 

 

import tensorflow as tf
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Input, Conv2DTranspose, Softmax, Concatenate, BatchNormalization, Add, ReLU
from tensorflow.keras.models import Model
from tensorflow.keras.utils import plot_model

def build_resnet101_backbone(input_shape):
    """
    ResNet-101 백본을 직접 구현하는 함수
    입력: 이미지 입력 크기 (input_shape)
    출력: ResNet-101의 마지막 convolution feature map
    """
    inputs = Input(shape=input_shape)
    
    # 초기 컨볼루션 블록
    x = Conv2D(64, (7, 7), strides=2, padding='same', activation='relu')(inputs)
    x = MaxPooling2D((3, 3), strides=2, padding='same')(x)

    def residual_block(x, filters, dilation_rate=1):
        shortcut = x
        if x.shape[-1] != filters * 4:
            shortcut = Conv2D(filters * 4, (1, 1), padding='same', activation='relu')(shortcut)

        x = Conv2D(filters, (1, 1), padding='same', activation='relu')(x)
        x = Conv2D(filters, (3, 3), padding='same', dilation_rate=dilation_rate, activation='relu')(x)
        x = Conv2D(filters * 4, (1, 1), padding='same', activation='relu')(x)

        x = Add()([shortcut, x])  # 채널 수가 맞아야 Add 가능
        x = ReLU()(x)
        return x

    
    for _ in range(23):  # ResNet-101의 Conv4 블록 반복 횟수
        x = residual_block(x, 256)
    
    return Model(inputs=inputs, outputs=x, name="ResNet101_Backbone")

def atrous_spatial_pyramid_pooling(x):
    """
    ASPP (Atrous Spatial Pyramid Pooling) 모듈 구현
    """
    atrous_rates = [6, 12, 18, 24]
    atrous_layers = [Conv2D(256, (3, 3), padding='same', dilation_rate=rate, activation='relu')(x) for rate in atrous_rates]
    
    # 1x1 convolution 추가
    conv_1x1 = Conv2D(256, (1, 1), activation='relu', padding='same')(x)
    
    # Global Average Pooling
    global_avg = tf.reduce_mean(x, axis=[1, 2], keepdims=True)
    global_avg = Conv2D(256, (1, 1), activation='relu', padding='same')(global_avg)
    global_avg = tf.image.resize(global_avg, (x.shape[1], x.shape[2]))
    
    # ASPP 결과 합치기
    x = Concatenate()([conv_1x1] + atrous_layers + [global_avg])
    x = Conv2D(256, (1, 1), activation='relu', padding='same')(x)
    return x

def build_deeplabv2_resnet101(input_shape=(512, 512, 3), num_classes=21):
    """
    DeepLab v2 모델을 ResNet-101 백본을 기반으로 구축하는 함수
    입력: 이미지 크기 (input_shape), 클래스 수 (num_classes)
    출력: DeepLab v2 모델
    """
    base_model = build_resnet101_backbone(input_shape)
    
    # ResNet-101의 마지막 convolution feature map 가져오기
    x = base_model.output
    
    # ASPP 적용
    x = atrous_spatial_pyramid_pooling(x)
    
    # Upsampling (8배 upsampling)
    x = Conv2DTranspose(num_classes, kernel_size=8, strides=8, padding='same')(x)
    x = Softmax()(x)
    
    model = Model(inputs=base_model.input, outputs=x)
    
    # 모델 구조 시각화 (Graphviz 사용)
    plot_model(model, to_file='deeplabv2_resnet101.png', show_shapes=True, show_layer_names=True)
    
    return model

# 모델 생성
model = build_deeplabv2_resnet101()

# 모델 구조 출력
model.summary()