[論文筆記]Prototypical Pseudo Label Denoising and Target Structure Learning for Domain Adaptive Semantic Segmentation
這篇是筆者目前看到在 UDA, semantic segmentation 領域表現最好的一篇,分數在加了知識蒸餾(knowledge distillation)後整整甩其他人一大截,用的方法主要是找各個類別在 feature space 的中心點(prototype),使同類別的 data point 接近這個中心點,讓 model 更好分類。
如果只是想快速知道這篇論文大概在做什麼,可以直接看 Overview 就好;
如果是想看詳細方法介紹,筆者會在後面說明此論文要解決的問題是什麼、介紹作者提出的方法,最後展示實驗結果。
Domain Knowledge
本篇會出現的一些常見的專有名詞會在這裡介紹,如果是熟悉這領域的大大可以直接跳過這段~
如果是對 UDA 和 semantic segmentation 不了解的讀者,可以看看這篇。
Overview
此篇論文主要的貢獻有四點:
- 利用 prototype,即時更新 soft pseudo labels,解決 pseudo label noisy 的問題。
- 提出一種 structure learning 的方法讓 model 能學習到緊湊的 target feature space,解決 model 通常都學習到分散的 target feature 的問題。
- 把已經 train 好的 UDA model distilled 到 student model 上可以讓分數表現得更好。
- 目前是 GTA5 Cityscapes 和 SYNTHIA Cityscapes 的任務上表現最好的 model。(2021/08/31)
Model Performance:
Task | Class Number | mIoU |
---|---|---|
GTA Cityscapes | 19 classes | 57.5 |
Synthia Cityscapes | 16 classes | 55.5 |
13 classes | 62.0 |
Methods
在 UDA 領域中,早期很流行使用一些 adversarial learning 的方法讓 source domain feature space 和 target domain feature space 能盡量接近,其中最有名的就是 AdaptSegNet。近年則是 self-training 盛行,讓 model 自己產生 pseudo labels 來自我訓練,但這方法目前存在兩個問題:
- 只選擇 confidence 高於某個嚴格閾值的預測來當作 pseudo label,結果不一定正確,導致 model 在 target domain 訓練失敗。
可以看到下圖 (a) 的左邊那張圖就呈現了這問題,正確來講藍色虛線圈裡的應該都要是被標記成 的 pseudo label,卻因為 decisoin boundary 的錯誤而被標記成 - 由於 source / target domain 差距很大,model 學習到的 target features 很常是分散的。
下圖的 (b) 可以看到即使 target features 已經是分開的了,但因為彼此太過分散,以至於有近 1/4 的 data 是分類錯誤的。
針對第一個問題,作者改用每個 pixel 跟 prototype 的相對特徵距離作為依據,並且在訓練過程中即時調整 prototype 跟 pseudo label,想法上更為直觀,pseudo label 也不會有過時的問題,但這樣相對計算量就增加許多,後面會介紹作者用了什麼方法來減輕計算量。
圖(a)的右邊那張圖就把原本比較接近類別 A 的標記為 ,decision boundary 改變,隨後再根據新一輪的 features 來更新 prototype。
至於第二個問題,作者把一張圖經過兩種不同的轉換方式分別餵進兩個同架構的 model,並要求 model 計算出相似的 prototype 位置,在 structure-learning-by-enforcing-consistency 會更詳細解釋他的做法。
公式符號表
- : source dataset
- : source segmentation labels
- : target dataset
- 個類別
- : 整體網路架構(segmentation network) = 特徵提取網路(feature extractor) + 分類器(classifier)
- : 第 i 個 target data feature
- : prototype, class k 的特徵中心點
- : 把 soft prediction 轉成 hard label 的函數表示
Prototypical pseudo label denoising
如果讓 pseudo labels 一個訓練階段更新一次,那麼 model 可能早就 overfit 在 noisy labels 上,但如果同時更新網路參數跟 pseudo labels 又會造成 trivial solution,因此 ProDA 採用固定 soft predictions(),根據與 prototype 的距離在訓練過程中生成給每個類別的權重(),並更新 hard pseudo label()。
權重公式如下:
- : 第 i 個 target data feature
- : momentum encoder, 的緩慢更新版本
- : prototype,類別 k 的特徵中心點
- : softmax temperature,這裡設為 1
假設現在某個 feature 離 prototype 很遠,表示它較不可能是屬於類別 k,它對類別 k 的權重就會比較小。相反的,如果距離 較近,權重就會比較大。
而權重 其實就是 feature 跟 prototype 的距離過 softmax function 的結果,分母為所有 features 跟所有類別的 prototype 的距離,分子為所有 features 跟代表類別 k 的 prototype 的距離。
prototype 公式如下:
- indicator function: 符合條件的輸出 1,不符合的為 0。在這條式子的條件是 pseudo label 。
從公式中可以看到 prototype 其實就是該類別的 features 平均,也就是中心點(centroid)的概念。但這樣每次計算都要跑過全部的 feature points,計算量龐大,因此 ProDA 是用 mini-batches 的中心點的 moving average(移動平均)取代計算。
- : 當下訓練批次類別 k 的中心點,來自 momentum encoder。
- = 0.9999
最後用 pseudo label 跟 model predictions 算出的 symmetric cross-entropy 作為 target domain 的 loss。
和 為平衡係數,。
Structure learning by enforcing consistency
為了使 target features 能夠更緊湊,作者對 target data 分別做了弱增強 和強增強 ,所謂增強其實是對圖片做一些轉化,而實際上 ProDA 的弱增強就是直接餵原圖給 model,強增強就會對圖片做一些旋轉、明暗度調整、彩度調整等等,詳情可以看他們的 github。
現在有了兩張圖片 和 後,把弱增強的圖片輸入 momentum encoder ,強增強的圖片輸入原始的 encoder ,分別計算 prototype 位置 和 (論文中稱他們為 soft prototypical assignment),並迫使 model 去降低這這兩個的 KL-divergence。
由於 是由弱增強的圖片計算得來,受到的干擾較小,計算出的 prototype 會較正確,用 去教導原本的 encoder 在吃到強增強的圖片後也能得出一樣的 prototype assignment,就表示 encoder 學習到更穩定、緊湊的 target features。而這種讓 model 學習如何得到跟另一個 model 一樣結果的訓練方式,被稱作 consistent learning。
為了防止在學習 target features 時出現 degeneration issue (有個類別的 cluster 是空的),需要再加上一個 regularization term,目的是鼓勵 model 輸出的類別能盡量平均,不要有一個機率總是特別高,或是特別低。
以下是最終的 loss function:
其中 。
Distillation to self-supervised model
在 收斂後,普通方法通常就到此收手,但 ProDA 更結合了知識蒸餾的概念,讓 student network 向 teacher network(同時也是第一階段使用的 model)學習。雖然在這裡student network 跟 teacher network 有完全相同的架構(一般知識蒸餾中 student network 會比 teacher network 還要小),唯一的差別在 student network 有先使用 Sim-CLRv2 pretrained weights 初始化。
以下用簡易的架構圖來說明 ProDA 的訓練模式:
ProDA 整個訓練總共有三個階段,階段一的目的在使 收斂,模型架構圖如 Figure 3 所示(這裡只畫了一個 loss 做為代表),而下方被灰色虛線匡著的 encoder 和 classifier 是最主要的 segmentation network ,也是在接下來的知識蒸餾過程中被視為 teacher network 的網路。
階段二和三則是知識蒸餾的過程,會拿階段一訓練好的模型當作 teacher model,把知識透過降低 knowledge distillation loss 來傳給 student model 。作法除了降低 source domain 的 cross entropy 外,還包含 student model predictions 和 teacher network hard labels 的 cross entropy 跟這兩個 model predictions 的 KL-divergence。
Konwledge distillation loss
Experiments
Main Results
ProDA 不論是在 GTA5 Cityscapes 還是 SYNTHIA Cityscapes 的任務上都是表現最好的。他們使用的是 ResNet-101 + Deeplab-v2 的架構,需要四片 Tesla V100 GPU 來訓練,雖然作者並未公布他們訓練的時間,不過就筆者自己訓練的經驗來看大概需要兩天多的時間,有興趣的讀者可以去 Github 把模型載下來跑跑看。對這些 datasets 不熟悉的可以看這裡
GTA5 - Cityscapes
STNTHIA - Cityscapes
Ablation Study
Effectiveness of the Methods
- 加上 warm-up : 41.6 mIoU (+5)
- 加上 offline pseudo labels : 45.2 mIoU (+8.6)
- 加上 symmetric cross-entropy : 45.6 mIoU (+9.0)
- 加上 prototypical denoising : 52.3 mIoU (+15.7)
Target structure learning 藉由學習緊湊的 target fature cluster 來協助偽標籤能不受雜訊干擾,並提升 1.4 mIoU。
- self-supervised > supervised initialization
- 使用此方法初始化可以避免模型最後收斂在 local optima
- stage 2 + 3 證明了知識蒸餾的有效性
The UMAP visualization of target features
為了證明 target features 真的有學得比較好,ProDA 還提供了視覺化的 target features。
(a)單純把訓練在 source domain 的模型拿到 target domain 上做訓練,target features 的分佈。
(b)傳統自訓練方法可以讓 target features 分開一些,但離可以 linear classfication 的程度還是有些遠。
(c)經過 prototypical denoising 後四個類別的 features 被分開,已經可以用兩條線大致區分出四個類別,但同類別的 features 還是分散的。
(d)可以看出,ProDA 比起左邊的三個版本,可以更好的分開不同類別的 features,同類別的 features 也比較聚集。
Conclusion
筆者在這篇文章中帶大家了解目前 UDA semantic segmentation 領域表現最好的論文是用了哪些方法,也詳細介紹了各個公式,希望大家對 ProDA 有更深的了解。
在 ProDA 提出的方法中,筆者認為知識蒸餾的概念最為有趣,已經訓練好的模型經過近一步的「蒸餾」後竟然能讓表現變得更好,不知道現階段其他的模型架構是否也能透過同樣的方式讓表現有所提升?就等大家來嘗試啦!
Reference:
- https://zhuanlan.zhihu.com/p/102038521
- Prototypical Pseudo Label Denoising and Target Structure Learning for Domain Adaptive Segmantic Segmentation
有任何問題都歡迎在下面提出,喜歡這篇文章的話可以幫我點一個讚,
祝福各位能在機器學習領域走出屬於自己的一條路。