A Probabilistic U-Net for Segmentation of Ambiguous Images
32nd Conference on Neural Information Processing Systems (NeurIPS 2018), Montréal, Canada
0. Abstract
- 다수의 real-world vision 문제는 본질적인 모호성(ambiguities)을 가진다.
- 예를 들어 clinical applications에서는, 오직 CT scan만으로 특정 영역이 cancer tissue인지 명확하지 않을 수 있다.
- 그래서 일반적으로 여러 명의 graders가 서로 다르지만 모두 그럴듯한(plausible) segmentations의 집합을 만든다.
- 우리는 입력이 주어졌을 때 segmentations에 대한 distribution을 학습하는 과제를 다룬다.
- 이를 위해, 무한히 많은 plausible hypotheses를 효율적으로 생성할 수 있는 U-Net과 conditional variational autoencoder를 결합한 generative segmentation model을 제안한다.
- 우리는 lung abnormalities segmentation task와 Cityscapes segmentation task에서,
제안 모델이 가능한 segmentation variants와 그 발생 빈도(frequencies)까지 재현하며, 공개된 기존 접근법들보다 유의미하게 더 나은 성능을 보임을 보였다. - 이러한 모델은 real-world 응용에서 높은 임팩트를 가질 수 있는데,
예를 들어 clinical decision-making algorithms로 사용되어, 다중의 plausible semantic segmentation hypotheses를 고려한 가능한 진단을 제공하고, 현재의 모호성을 해소하기 위한 후속 조치를 권고하는 데 활용될 수 있다.

- (a) sampling process
- 1. 입력 이미지는 X
- 화살표는 연산 흐름을 나타냄.
- 파란색 블록은 feature map을 나타낸다.
- 사전 분포 : prior ==> P( z | X )이다.
- 이미지 X만 보고 prior net이 예측한 latent z의 분포( 평균 , 분산 )
---> 학습이 끝난뒤 추론할때는 정답 Y가 없으니, z를 여기서 샘플링
- 이미지 X만 보고 prior net이 예측한 latent z의 분포( 평균 , 분산 )
- 사후분포 : posterior ==> Q( z | X , Y )
- 학습시에는 X와 정답 Y를 사용해서 "정답을 만드는 z"의 분포를 추정
---> CE손실로 예측을 맞추고, KL( Q || P )로 prior가 posterior와 가깝게 되도록 배운다.
- 학습시에는 X와 정답 Y를 사용해서 "정답을 만드는 z"의 분포를 추정
- 2. prior net ⇒ ( μ_prior , σ_prior )
- 왼쪽 위의 작은 파란 블록 스택이 prior net이다...
"입력 X "만을 보고, "latent 분포 P( z | X )"의 평균과 분산을 예측한다.- 오른쪽의 타원형 동심원 히트맵은 "P( z | X )"의 확률 밀도를 시각화( 축정렬 gaussian ).
- 왼쪽 위의 작은 파란 블록 스택이 prior net이다...
- 3. 회색톤의 heatmap은 "저차원 latent space R^N"에서의 확률분포를 표현한다. ( 실험에서 N=6 )
- N은 "latent space의 차원수"다. 그리고 이는 설정값(하이퍼 파라미터)이다.
- 한번 학습하면, 추론 시 N은 바뀌지 않는다. 추론시 바꿀수 있는건 "샘플 개수 m"이다.
- 4. latent space에서의 샘플링. ( z ∼ N ( μ_prior , diag( σ_prior ) ) )
- 네트워크를 한번 실행할 때마다, " z ∈ R^N "을 하나 sample하여, 하나의 segmentation mask를 예측한다.
- 회색 박스 안의 " z1 , z2 , z3, ... "가 각각 한번의 실행에서 뽑힌 latent 벡터다.
- 한 개의 z가 하나의 일관된 segmentation 가설을 대표한다.
- ( z = [ z1 , z2 , z3 , z4 , .... , z_N ]이다. 즉, z는 N개의 연속값을 가진 하나의 벡터다. )
- 5. U-Net feature추출
- 중앙의 큰 파란색 블록 스택이 U-Net이다.
- 이미지 X를 encoding/decoding하여, 마지막 activation map(고해상도 feature)을 만든다.
- 6. broadcasting z → N-channel feature map ( 초록색 블록 )
- 초록색 블록(sample z1 , z2 , z3 .... )은 sample z를 broadcasting하면서 만든 N-channel feature map이다.
- " z ∈ R^N "을 H*W의 공간 차원으로 복제(broadcast)해서, H * W * N 텐서를 만든다.
- 가독성을 위해서 그림 속 feature map블록의 개수는 실제보다 줄였다.
- 7. Mask 생성 ( 오른쪽 세로 이미지 1 2 3 .. )
- z를 바꿔가며 실행하면, 가설 1,2,3....처럼 다양한 segmetation mask가 생성된다.
- 같은 X라도 z가 달라지면, 병변의 경계,존재/부재 같은 전역적으로 일관된 차이가 나타난다.
- 1. 입력 이미지는 X
- (b) training process illustrated for one training example.
- 하나의 학습 샘플에 대한 training 과정을 도식화했다.
- 초록색 화살표는 손실함수들을 의미한다.
- 1. "입력 X"와 "Ground Truth Y"
- 왼쪽 아래에 입력 이미지 X , 오른쪽 끝에는 "predicted segmentation"과 "ground Truth"가 나란히 존재.
- 2. prior net ( 왼쪽 상단 )
- 패널 (a)와 동일하게 "p( z | X )"의 μ_prior , σ_prior를 출력한다.
- 이 분포는 추론 시 사용할 사전분포로 학습되어야 한다.
- 3. posterior Net ( 오른쪽 상단 )
- 상단의 파란 블록 스택이 posterior net이다.
- 여기서는 X와 Y를 모두 입력으로 받아서, 사후분포 Q( z | X , Y )의 μ_post , σ_post를 출력한다.
- 이 박스 옆의 "양방향 초록책 화살표(KL)"가 "D_KL( Q || P )"항을 의미.
- 4. z ~ Q( z | X , Y )샘플링 ---> U-Net으로 예측
- 학습 때는 Posterior에서 샘플링한 z를 사용한다. ( 정답 Y를 알고 있으니 "정답에 맞는 z"를 찾도록 유도)
- 이후의 과정은 a와 동일.
( U-Net feature ⊕ broadcast( z ) → 1×1 conv head → Predicted Segmentation.)
- 이후의 과정은 a와 동일.
- 학습 때는 Posterior에서 샘플링한 z를 사용한다. ( 정답 Y를 알고 있으니 "정답에 맞는 z"를 찾도록 유도)
- 5. 손실 ( 초록색 화살표 두 종류 )
- Cross-Entropy : 예측 S와 ST Y의 pixel-wise CE. 오른쪽 "cross-entropy"상자.
- KL Divergence : D_KL( Q( z | X , Y ) || P( z | X ) ). posterior가 prior에서 너무 멀어지지 않도록 당겨서, 평균적으로 가 “해당 이미지에서 나타날 수 있는 segmentation 가설들”을 커버하도록 만든다.
- 실제 최족 loss는 CE + (β-가중치).
1 Introduction
- Semantic segmentation 과제는 이미지의 각 pixel에 class label을 할당하는 것이다.
- 많은 경우 이미지의 context만으로도 이 매핑의 모호성(ambiguities)을 해소할 수 있지만,
전체 이미지 context를 모두 보더라도, 모든 모호성이 해소되지 않는 중요한 부류의 이미지들이 존재한다. - 이러한 모호성은 medical imaging 응용(예: CT 이미지에서 lung abnormalities segmentation)에서 흔하다.
- 병변(lesion)은 분명히 보일 수 있지만, 그것이 cancer tissue인지 여부는 이 이미지 하나만으로는 알 수 없을 수 있다.
- 유사한 모호성은 사진에서도 나타난다.
- 예를 들어 소파 아래로 보이는 fur의 일부가 고양이인지 강아지인지 이미지 자체만으로는 구분이 불가능할 수 있다.
- 대부분의 기존 segmentation 알고리즘은 하나의 그럴듯한 일관된 가설(예: “모든 pixel이 고양이에 속한다”)만 제공하거나, pixel-wise probability(예: “각 pixel이 50% 고양이, 50% 강아지”)만 제공한다.
- 예를 들어 소파 아래로 보이는 fur의 일부가 고양이인지 강아지인지 이미지 자체만으로는 구분이 불가능할 수 있다.
- 특히 segmentation map에 근거해 후속 diagnosis나 treatment가 결정되는 medical applications에서는,
가장 그럴듯한 단일 가설만 제공하는 알고리즘은 misdiagnoses와 sub-optimal treatment로 이어질 수 있다.
- 많은 경우 이미지의 context만으로도 이 매핑의 모호성(ambiguities)을 해소할 수 있지만,
- 오직 pixel-wise probabilities만 제공하는 접근법은 pixel들 사이의 모든 co-variances를 무시하므로,
그 이후의 분석을 매우 어렵거나 불가능하게 만든다.- 만약 여러 개의 일관된 가설들을 제공한다면, 이것들을 diagnosis pipeline의 다음 단계로 직접 전달할 수 있으며,
모호성을 해소하기 위한 추가 diagnostic tests를 제안하는 데 사용할 수 있고, 추가 정보에 접근 가능한 전문가가 이후 단계에 적합한 가설을 선택할 수도 있다.
- 만약 여러 개의 일관된 가설들을 제공한다면, 이것들을 diagnosis pipeline의 다음 단계로 직접 전달할 수 있으며,
- 본 논문에서는 모호한 이미지에 대해 다수의 segmentation hypotheses를 제공하는 segmentation framework를 제시한다(Fig. 1a).
- '저차원 latent space'가 '가능한 segmentation variants'를 인코딩하며,
이 공간에서 '무작위로 sampling한 값'을 U-Net에 주입하여 해당 'segmentation map'을 생성한다.- 이 아키텍처의 핵심 특징 중 하나는 "segmentation map의 모든 pixel에 대한 joint probability를 모델링할 수 있다"는 점이다.
- ( joint probability : 여러 사건( or 변수)이 동시에 특정한 값들을 가질 확률 )
- 즉, pixel-wise가 픽셀들 사이의 의존성을 반영 못한다면, joint probability는 전체 라벨 조합의 확률을 다룬다. 결과적으로 pixels 사이의 의존성을 반영할 수 있다.
- ( joint probability : 여러 사건( or 변수)이 동시에 특정한 값들을 가질 확률 )
- 이 아키텍처의 핵심 특징 중 하나는 "segmentation map의 모든 pixel에 대한 joint probability를 모델링할 수 있다"는 점이다.
- 그 결과, 각각이 전체 이미지를 일관되게 해석하는 다중 segmentation maps를 산출할 수 있다.
- 더 나아가 본 framework는 발생 확률이 낮은 가설들까지 학습하고, 그에 상응하는 빈도로 예측할 수 있다.
- 우리는 네 명의 전문가가 각 lesion을 독립적으로 segment한 lung abnormalities segmentation task와, 학습 중 일정 빈도로 라벨을 인위적으로 뒤집은 Cityscapes dataset에서 이러한 특성을 보인다.
- 확률적(probabilistic)이고 multi-modal한 segmentation을 위한 다양한 접근들이 존재한다.
- 가장 보편적인 접근은 pixel-wise probabilities를 독립적으로 제공한다 [7, 8].
- 이들 모델은 spatial features 위에 dropout을 사용하여 확률 분포를 유도한다.
- 이 전략은 pixel-wise uncertainty를 정량화한다는 해당 연구 흐름의 목적은 달성하지만, 일관적이지 않은(inconsistent) 출력을 낳는다.
- 그럴듯한 가설들을 생성하는 단순한 방법은 (deep) models의 ensemble을 학습하는 것이다 [9].
- Ensemble의 출력은 일관적일 수 있으나, 반드시 diverse하지는 않으며, 각 구성원이 독립적으로 학습되기 때문에 드문 variants를 학습하지 못하는 경우가 많다.
- 이를 극복하기 위해, 여러 접근은 oracle set loss [10], 즉 ground truth에 가장 가까운 예측만을 고려하는 loss로 모델들을 공동 학습한다.
- 이는 deep networks의 ensemble을 사용하는 [11], [1]과, 하나의 공통 deep network에 M heads를 두는 [12], [13]에서 탐구되었다.
- Multi-head 접근은 다양한 variants를 담을 capacity는 있을 수 있으나, 개별 variants의 발생 빈도(occurrence frequencies)를 학습하도록 설계되어 있지는 않다.
- Ensemble과 M heads 모델의 두 가지 공통 단점은 가설의 수가 커질 때 우아하게 확장되지 않는다는 점과, 학습 시 허용할 가설의 개수를 고정해야 한다는 점이다.
- 가장 보편적인 접근은 pixel-wise probabilities를 독립적으로 제공한다 [7, 8].
- 다양한 해를 생성하는 또 다른 접근 집합은 junction chains [14]와 보다 일반적으로 Markov Random Fields(MRF) [15, 16, 17, 18] 같은 graphical models에 의존한다.
- 많은 기존 접근은 가장 좋은 diverse solutions를 찾는 것을 보장하지만, tractable한 graphical model로 의존성(dependencies)을 기술할 수 있는 structured problems로 영역이 제한된다.
- Image-to-image translation [19] 과제는 매우 유사한 문제를 다룬다.
- 즉, under-constrained한 이미지 도메인 전이를 학습해야 한다.
- 최근 접근의 다수는 generative adversarial networks(GANs)를 사용하며, 이는 ‘mode-collapse’ [20] 같은 문제를 겪는 것으로 알려져 있다.
- Mode-collapse를 해결하려는 시도로 ‘bicycleGAN’ [21]은 우리의 것과 유사한 아키텍처 구성요소를 포함한다.
- 그러나 우리의 아키텍처와 달리, 그들의 모델은 고정 prior distribution을 사용하며, 학습 시 posterior distribution은 output image에만 condition된다.
- 매우 최근의, shape encoding이 주어졌을 때 appearances를 생성하는 연구 [22] 역시 U-Net과 VAE를 결합하며, 우리의 연구와 동시에 개발되었다.
- 다만 그들의 학습은 reconstruction loss로 사전학습된 VGG-net을 추가로 요구한다.
- 마지막으로 [23]에서는 ground truth와 예측 분포 간의 dissimilarity coefficient [24]를 최적화하는 것에 기반한 structured outputs용 probabilistic model이 제안되었다.
- 결과 접근은 14개 관절 위치를 예측하는 hand pose estimation 과제에 대해 평가되었는데, 이는 우리가 여기서 다루는 segmentation 공간과 비교하면 상대적으로 단순한 공간이라고 볼 수 있다.
- 아래에서 제시하는 접근과 유사하게, 그들 역시 네트워크 아키텍처의 후반부 단계에서 latent variables를 주입한다.
- 본 연구의 주요 공헌은 다음과 같다.
- (1) 본 framework는 pixel-wise probabilities 대신, '일관된 segmentation maps'를 제공하며, 따라서 modes의 joint likelihood를 산출할 수 있다.
- (2) 본 모델은 매우 희귀한 rare modes의 발생까지 포함하는 임의로 복잡한 output distributions를 유도할 수 있으며, segmentation modes에 대한 calibrated probabilities를 학습할 수 있다.
- (3) 본 모델에서 sampling은 계산 비용이 낮다.
- (4) 정성적 평가에만 머무는 경우가 많은 많은 기존 deep generative models 응용과 달리, 본 응용과 datasets는 누락된 modes에 대한 페널티를 포함한 정량적 성능 평가가 가능하다.
2 Network Architecture and Training Procedure
2-1. Network Architecture (Sampling)
- 제안하는 아키텍쳐는 CVAE와 U-Net의 결합이다.
- 이는 이미지에 "조건부 확률 모형"을 학습하는 것 목적이다.
- 핵심은 "저차원 latent space R^N"이다. 이는 공간의 한 점이 "하나의 segmentation vairant(가설)"을 나타낸다.
- prior net이 입력이미지 X에 대해서, 해당 variant들의 사전분포 P( z | X )를 추정한다.
P는 평균과 분산을 내는 축정렬 Gaussian으로 모델링된다.

- 같은 입력 이미지로 개의 segmentation을 예측하려면,
동일한 에 대해, 회 샘플링만 반복하면 된다(각 반복마다 네트워크의 일부분만 재계산하면 됨). - 샘플 z_i ∈ R^N은 broadcast되어 -채널의 feature map(세그멘테이션 맵과 동일한 공간 해상도)을 만든 뒤,
이를 U-Net의 마지막 activation map( f_(U-Net)(X;θ) )과 concatenate한다. - 이어서 1×1 convolution을 3층 쌓은 결합 모듈 f_comb(⋅;ψ)가 정보를 섞어 원하는 class 수의 출력으로 사상한다.
- 최종 출력 S_i는 latent 공간의 점 에 대응하는 세그멘테이션 맵은 아래 식2와 같다.

==> 주의: 동일 이미지에서 m개의 샘플을 뽑을 때 prior net의 출력과 U-Net의 activation은 재사용 가능하고, f_comb만 번 재평가하면 된다(계산 효율↑).
2-2. Network Architecture (Training)
- conditional VAE의 표준 절차로 end-to-end 학습한다(변분 하한을 최소화).
- 학습 시에는 GT 마스크 YY도 조건으로 쓰는 posterior net(파라미터 ν\nu)이 사후분포 Q(z∣X,Y)Q(z\mid X,Y)를 추정한다. 이것도 Gaussian:

- 를 샘플링해 U-Net activation과 결합하면 예측 SS가 나오며, 이는 학습 샘플의 YY에 가깝도록 해야 한다.
- 손실은 두 항의 가중합:
- Cross-Entropy loss — SS가 파라미터화하는 pixel-wise categorical distribution PcP_c에 대한 음의 로그우도(softmax CE)
- KL divergence — posterior QQ와 prior PP 사이의 DKL(Q∥P)D_{\mathrm{KL}}(Q\|P), 가중치 β\beta로 결합:


무작위 초기화에서 from scratch로 학습하며, KL 항은 posterior(개별 variant를 인코딩)와 prior(이미지 조건부 분포)를 서로 끌어당긴다. 많은 샘플에 대해 평균적으로 prior는 “주어진 이미지 XX”에서 나타날 수 있는 segmentation variants의 공간을 cover하도록 학습된다.
'논문' 카테고리의 다른 글
| [Trajectory] Resonance: Learning to Predict Social-Aware Pedestrian Trajectories as Co-Vibrations (0) | 2025.12.08 |
|---|---|
| [RRT] A Method of Enhancing Rapidly-Exploring Random Tree Robot Path Planning Using Midpoint Interpolation (0) | 2025.10.22 |
| [Trajectory] Y-net (0) | 2025.09.17 |
| [Trajectory] MUSE-VAE (0) | 2025.09.15 |
| [ Trajectory Prediction ]Socail LSTM (1) | 2025.09.10 |