這篇是筆者目前看到在 UDA, semantic segmentation 領域表現最好的一篇,分數在加了知識蒸餾(knowledge distillation)後整整甩其他人一大截,用的方法主要是找各個類別在 feature space 的中心點(prototype),使同類別的 data point 接近這個中心點,讓 model 更好分類。

如果只是想快速知道這篇論文大概在做什麼,可以直接看 Overview 就好;
如果是想看詳細方法介紹,筆者會在後面說明此論文要解決的問題是什麼、介紹作者提出的方法,最後展示實驗結果。

Domain Knowledge

本篇會出現的一些常見的專有名詞會在這裡介紹,如果是熟悉這領域的大大可以直接跳過這段~

如果是對 UDAsemantic segmentation 不了解的讀者,可以看看這篇

Pseudo Label

這是近年來在半監督學習(semi-supervised learning)非常常用的技術,主要用在缺少標記資料的情況,把模型預測的高可信度標籤當作一種「偽標籤」,並再度使用這些偽標籤讓模型學習。Pseudo label 也常應用在 UDA 的模型,因為具有高可信度的預測通常是準確的,這樣模型就能更好的學習到 target domain 的 feature。
現在應用 pseudo label 的方式有非常多種,這裡舉一個最常見的步驟來介紹:

  1. 使用有標記的資料訓練模型。
  2. 讓模型對無標記的資料做預測,並選擇預測機率高過某個值的樣本當作 pseudo label。
  3. 使用有標記的資料和 pseudo label 訓練新的模型。
  4. 重複 2 跟 3 直到模型表現不再提升。

Knowledge Distillation

Knowledge Distillation(知識蒸餾)在論文 Distilling the Knowledge in a Neural Network 被正式提出和應用,其概念就是透過一個已訓練好的網路(teacher network)來教導一個較小的網路來學習(student network),而學習的方法常常是讓 student network 去學習 teacher network 的 soft predictions,會在最後一層使用加上參數 temperature TT 的 softmax function,當 TT 越高,會使概率分佈越平均,選擇適當的 TT,就能讓 student network 學習到除了 hard label 以外的重要資訊。

qi=exp(zi/T)jexp(zj/T)q_i=\frac{exp(z_i/T)}{\sum_j exp(z_j/T)}

這樣把大網路的知識傳給小網路的做法,能讓小網路達到比 train from scratch 還要好的表現,而且小網路本身也更容易應用、不需要太多的硬體資源跟運算資源,在 computer vision, object detection, natural language processing 都能看到知識蒸餾被納入訓練方法之中。

Overview

標題:Prototypical Pseudo Label Denoising and Target Structure Learning for Domain Adaptive Semantic Segmentation
作者:Pan Zhang, Bo Zhang, Ting Zhang, Dong Chen, Yong Wang1, Fang Wen (University of Science and Technology of China, Microsoft Research Asia)
年份:2021
榮譽:CVPR 2021

此篇論文主要的貢獻有四點:

  1. 利用 prototype,即時更新 soft pseudo labels,解決 pseudo label noisy 的問題。
  2. 提出一種 structure learning 的方法讓 model 能學習到緊湊的 target feature space,解決 model 通常都學習到分散的 target feature 的問題。
  3. 把已經 train 好的 UDA model distilled 到 student model 上可以讓分數表現得更好。
  4. 目前是 GTA5 \rarr Cityscapes 和 SYNTHIA \rarr Cityscapes 的任務上表現最好的 model。(2021/08/31)

Model Performance:

Task Class Number mIoU
GTA \rarr Cityscapes 19 classes 57.5
Synthia \rarr Cityscapes 16 classes 55.5
13 classes 62.0

Methods

在 UDA 領域中,早期很流行使用一些 adversarial learning 的方法讓 source domain feature space 和 target domain feature space 能盡量接近,其中最有名的就是 AdaptSegNet。近年則是 self-training 盛行,讓 model 自己產生 pseudo labels 來自我訓練,但這方法目前存在兩個問題:

  1. 只選擇 confidence 高於某個嚴格閾值的預測來當作 pseudo label,結果不一定正確,導致 model 在 target domain 訓練失敗。
    可以看到下圖 (a) 的左邊那張圖就呈現了這問題,正確來講藍色虛線圈裡的應該都要是被標記成 ++ 的 pseudo label,卻因為 decisoin boundary 的錯誤而被標記成 -
  2. 由於 source / target domain 差距很大,model 學習到的 target features 很常是分散的。
    下圖的 (b) 可以看到即使 target features 已經是分開的了,但因為彼此太過分散,以至於有近 1/4 的 data 是分類錯誤的。

Figure 1. The existing issues of self-training by visualizing the feature space.

針對第一個問題,作者改用每個 pixel 跟 prototype 的相對特徵距離作為依據,並且在訓練過程中即時調整 prototype 跟 pseudo label,想法上更為直觀,pseudo label 也不會有過時的問題,但這樣相對計算量就增加許多,後面會介紹作者用了什麼方法來減輕計算量。
圖(a)的右邊那張圖就把原本比較接近類別 A 的標記為 ++,decision boundary 改變,隨後再根據新一輪的 features 來更新 prototype。

至於第二個問題,作者把一張圖經過兩種不同的轉換方式分別餵進兩個同架構的 model,並要求 model 計算出相似的 prototype 位置,在 structure-learning-by-enforcing-consistency 會更詳細解釋他的做法。

公式符號表
  • Xs={xs}j=1ns\mathcal{X}_s = \{x_s\}^{n_s}_{j=1} : source dataset
  • Ys={ys}j=1ns\mathcal{Y}_s = \{y_s\}^{n_s}_{j=1} : source segmentation labels
  • Xt={xt}j=1nt\mathcal{X}_t = \{x_t\}^{n_t}_{j=1} : target dataset
  • KK 個類別
  • h=fgh = f \circ g : 整體網路架構(segmentation network)hh = 特徵提取網路(feature extractor)ff + 分類器(classifier)gg
  • f(xt)(i)f(x_t)^{(i)}: 第 i 個 target data feature
  • η(k)\eta^{(k)}: prototype, class k 的特徵中心點
  • ξ()\xi(\cdot): 把 soft prediction 轉成 hard label 的函數表示

Prototypical pseudo label denoising

如果讓 pseudo labels 一個訓練階段更新一次,那麼 model 可能早就 overfit 在 noisy labels 上,但如果同時更新網路參數跟 pseudo labels 又會造成 trivial solution,因此 ProDA 採用固定 soft predictions(ptp_t),根據與 prototype 的距離在訓練過程中生成給每個類別的權重(wtw_t),並更新 hard pseudo label(y^t\hat{y}_t

y^t(i,k)=ξ(ωt(i,k)pt,0(i,k))\hat{y}_t^{(i,k)}=\xi(\omega_t^{(i,k)}p_{t,0}^{(i,k)})

Soft Predictions & Hard Labels
Soft prediction 是 model 計算出 data 屬於某個類別的機率,hard label 代表 model 計算出 data 所屬的類別
例如,要預測一張圖片屬於貓(label: 0)還是狗(label: 1),
輸出形式為 soft predictions 會長這樣:[0.9,0.1][0.9, 0.1]
輸出形式為 hard predictions 會長這樣:00(貓的 label)。

權重公式如下:

ωt(i,k)=exp(f~(xt)(i)η(k)/τkexp(f~(xt)(i)η(k)/τ)\omega_t^{(i,k)}=\frac{exp(-||\tilde{f}(x_t)^{(i)}-\eta^{(k)}||/\tau}{\sum_{k'}exp(-||\tilde{f}(x_t)^{(i)}-\eta^{(k')}||/\tau)}

  • f(xt)(i)f(x_t)^{(i)}: 第 i 個 target data feature
  • f~\tilde{f}: momentum encoder,ff 的緩慢更新版本
  • η(k)\eta^{(k)}: prototype,類別 k 的特徵中心點
  • τ\tau: softmax temperature,這裡設為 1

假設現在某個 feature 離 prototype η(k)\eta^{(k)} 很遠,表示它較不可能是屬於類別 k,它對類別 k 的權重就會比較小。相反的,如果距離 η(k)\eta^{(k)}較近,權重就會比較大。

而權重 ωt(i,k)\omega_t^{(i,k)} 其實就是 feature f~(xt)(i)\tilde{f}(x_t)^{(i)} 跟 prototype η(k)\eta^{(k)} 的距離過 softmax function 的結果,分母為所有 features 跟所有類別的 prototype 的距離,分子為所有 features 跟代表類別 k 的 prototype 的距離。

prototype 公式如下:

  • indicator function: 符合條件的輸出 1,不符合的為 0。在這條式子的條件是 pseudo label y^t(i,k)==1\hat{y}_t^{(i,k)} == 1

從公式中可以看到 prototype 其實就是該類別的 features 平均,也就是中心點(centroid)的概念。但這樣每次計算都要跑過全部的 feature points,計算量龐大,因此 ProDA 是用 mini-batches 的中心點的 moving average(移動平均)取代計算。

η(k)λη(k)+(1λ)η(k)\eta^{(k)}\leftarrow \lambda\eta^{(k)}+(1-\lambda)\eta'^{(k)}

  • η(k)\eta'^{(k)} : 當下訓練批次類別 k 的中心點,來自 momentum encoder。
  • λ\lambda = 0.9999

最後用 pseudo label 跟 model predictions 算出的 symmetric cross-entropy 作為 target domain 的 loss。

lscet=αlce(pt,y^t)+βlce(y^t,pt)l^t_{sce}=\alpha l_{ce}(p_t, \hat{y}_t)+\beta l_{ce}(\hat{y}_t, p_t)

α\alphaβ\beta 為平衡係數,α=0.1,β=1\alpha = 0.1, \beta= 1

Symmetric cross-entropy (SCE)
出自 ICCV 2019 的 Symmetric Cross Entropy for Robust Learning with Noisy Labels,作者結合了傳統的 cross-entropy 和 reverse cross entropy (RCE) 得到 SCE,並透過實驗證明使用這樣的損失函數能讓模型對 noisy label 更 robust,也能收斂的更快更好,更詳細的說明請參考論文

Structure learning by enforcing consistency

為了使 target features 能夠更緊湊,作者對 target data xtx_t 分別做了弱增強 T(xt)\mathcal{T}(x_t) 和強增強 T(xt)\mathcal{T'}(x_t),所謂增強其實是對圖片做一些轉化,而實際上 ProDA 的弱增強就是直接餵原圖給 model,強增強就會對圖片做一些旋轉、明暗度調整、彩度調整等等,詳情可以看他們的 github

現在有了兩張圖片 T(xt)\mathcal{T}(x_t)T(xt)\mathcal{T'}(x_t) 後,把弱增強的圖片輸入 momentum encoder f~\tilde{f},強增強的圖片輸入原始的 encoder ff,分別計算 prototype 位置 zTz_\mathcal{T}zTz_\mathcal{T'}(論文中稱他們為 soft prototypical assignment),並迫使 model 去降低這這兩個的 KL-divergence。

zT(i,k)=exp(f~(T(xt))(i)η(k)/τkexp(f~(T(xt))(i)η(k)/τ)z_{\mathcal{T}}^{(i,k)}=\frac{exp(-||\tilde{f}(\mathcal{T}(x_t))^{(i)}-\eta^{(k)}||/\tau}{\sum_{k'}exp(-||\tilde{f}(\mathcal{T}(x_t))^{(i)}-\eta^{(k')}||/\tau)}

Soft prototypical assignment formula. (for zTz_\mathcal{T'}, just change f~(T(xt)))\tilde{f}(\mathcal{T}(x_t))) to f(T(xt))f(\mathcal{T'}(x_t)).)

lklt=KL(zTzT)l_{kl}^t=KL(z_{\mathcal{T}}||z_{\mathcal{T}'})

KL divergence between the prototypical assignments under two views.

由於 zTz_\mathcal{T} 是由弱增強的圖片計算得來,受到的干擾較小,計算出的 prototype 會較正確,用 zTz_\mathcal{T} 去教導原本的 encoder 在吃到強增強的圖片後也能得出一樣的 prototype assignment,就表示 encoder 學習到更穩定、緊湊的 target features。而這種讓 model 學習如何得到跟另一個 model 一樣結果的訓練方式,被稱作 consistent learning

Figure 2. Model overview for structure learning by enforcing consistency. (由於論文裡沒有給 ProDA 的模型架構圖,這裡筆者自己畫了一張。)

Momentum Encoder
這詞出自於 2020 Facebook AI 發表的 MoCo,可以看作是更新比較緩慢的 encoder,每次參數更新後只會靠近原本的 encoder 一點點,可以參考下列 momentum encoder 更新公式(ProDA momentum encoder 也採用同樣的更新方式):

θme=mθme+(1m)θe\theta_{me} = m\theta_{me} + (1-m)\theta_{e}

其中 θme\theta_{me} 為 momentum encoder 的參數,\theta_{e} 為 encoder 的參數,m 是可調整的超參數,通常會設為近似 1 的值。

為了防止在學習 target features 時出現 degeneration issue (有個類別的 cluster 是空的),需要再加上一個 regularization term,目的是鼓勵 model 輸出的類別能盡量平均,不要有一個機率總是特別高,或是特別低。

lregt=i=1H×Wj=1Klogpt(i,k)l_{reg}^t=-\sum_{i=1}^{H\times W}\sum_{j=1}^Klogp_t^{(i,k)}

以下是最終的 loss function:

ltotal=lces+lscet+γ1lklt+γ2lregtl_{total}=l^s_{ce}+l^t_{sce}+\gamma_1l^t_{kl}+\gamma_2l^t_{reg}

其中 γ1=10,γ2=0.1\gamma_1 = 10, \gamma_2 = 0.1

Distillation to self-supervised model

ltotall_{total} 收斂後,普通方法通常就到此收手,但 ProDA 更結合了知識蒸餾的概念,讓 student network 向 teacher network(同時也是第一階段使用的 model)學習。雖然在這裡student network 跟 teacher network 有完全相同的架構(一般知識蒸餾中 student network 會比 teacher network 還要小),唯一的差別在 student network 有先使用 Sim-CLRv2 pretrained weights 初始化。

以下用簡易的架構圖來說明 ProDA 的訓練模式:

Figure 3. ProDA model overview - stage 1 (註:此圖也是 lscetl^t_{sce} 是如何被計算出的模型架構圖,想更了解的讀者可以對照著 Prototypical pseudo label denoising 看。

ProDA 整個訓練總共有三個階段,階段一的目的在使 ltotall_{total} 收斂,模型架構圖如 Figure 3 所示(這裡只畫了一個 loss 做為代表),而下方被灰色虛線匡著的 encoder ff 和 classifier gg 是最主要的 segmentation network hh,也是在接下來的知識蒸餾過程中被視為 teacher network 的網路。

Figure 4. ProDA model overview - stage 2 + 3. 咖啡色的箭頭代表跟 source domain 有關,黑色的箭頭則跟 target domain 有關。

階段二和三則是知識蒸餾的過程,會拿階段一訓練好的模型當作 teacher model,把知識透過降低 knowledge distillation loss lKDl_{KD} 來傳給 student model hh^\dag。作法除了降低 source domain 的 cross entropy 外,還包含 student model predictions 和 teacher network hard labels 的 cross entropy 跟這兩個 model predictions 的 KL-divergence。

Konwledge distillation loss

lKD=lces(ps,ys)+lcet(pt,ξ(pt))+βKL(ptpt)l_{KD}=l^s_{ce}(p_s,y_s)+l^t_{ce}(p^\dag_t, \xi(p_t))+\beta KL(p_t||p_t^\dag)

Experiments

Main Results

ProDA 不論是在 GTA5 \rarr Cityscapes 還是 SYNTHIA \rarr Cityscapes 的任務上都是表現最好的。他們使用的是 ResNet-101 + Deeplab-v2 的架構,需要四片 Tesla V100 GPU 來訓練,雖然作者並未公布他們訓練的時間,不過就筆者自己訓練的經驗來看大概需要兩天多的時間,有興趣的讀者可以去 Github 把模型載下來跑跑看。對這些 datasets 不熟悉的可以看這裡

GTA5 - Cityscapes

STNTHIA - Cityscapes

Ablation Study

Effectiveness of the Methods

Figure 5. Ablation study of each proposed component on GTA5-Cityscapes. ST stands for self-training, PD for prototypical denoising, and SL for structure learning.

  • 加上 warm-up : 41.6 mIoU (+5)
  • 加上 offline pseudo labels : 45.2 mIoU (+8.6)
  • 加上 symmetric cross-entropy : 45.6 mIoU (+9.0)
  • 加上 prototypical denoising : 52.3 mIoU (+15.7)

Target structure learning 藉由學習緊湊的 target fature cluster 來協助偽標籤能不受雜訊干擾,並提升 1.4 mIoU。

  • self-supervised > supervised initialization
  • 使用此方法初始化可以避免模型最後收斂在 local optima
  • stage 2 + 3 證明了知識蒸餾的有效性

The UMAP visualization of target features

為了證明 target features 真的有學得比較好,ProDA 還提供了視覺化的 target features。
(a)單純把訓練在 source domain 的模型拿到 target domain 上做訓練,target features 的分佈。
(b)傳統自訓練方法可以讓 target features 分開一些,但離可以 linear classfication 的程度還是有些遠。
(c)經過 prototypical denoising 後四個類別的 features 被分開,已經可以用兩條線大致區分出四個類別,但同類別的 features 還是分散的。
(d)可以看出,ProDA 比起左邊的三個版本,可以更好的分開不同類別的 features,同類別的 features 也比較聚集。

Conclusion

筆者在這篇文章中帶大家了解目前 UDA semantic segmentation 領域表現最好的論文是用了哪些方法,也詳細介紹了各個公式,希望大家對 ProDA 有更深的了解。
在 ProDA 提出的方法中,筆者認為知識蒸餾的概念最為有趣,已經訓練好的模型經過近一步的「蒸餾」後竟然能讓表現變得更好,不知道現階段其他的模型架構是否也能透過同樣的方式讓表現有所提升?就等大家來嘗試啦!

Reference:

有任何問題都歡迎在下面提出,喜歡這篇文章的話可以幫我點一個讚,
祝福各位能在機器學習領域走出屬於自己的一條路。