大一統(tǒng)視角理解擴(kuò)散模型Understanding Diffusion Models: A Unified Perspective(1)
這篇文章是近期筆者閱讀擴(kuò)散模型的一些技術(shù)博客和概覽的一篇梳理。主要參考的內(nèi)容來自Calvin luo的論文,針對(duì)的對(duì)象主要是對(duì)擴(kuò)散模型已經(jīng)有一些基礎(chǔ)了解的讀者。Calvin luo 的這篇論文為理解擴(kuò)散模型提供了一個(gè)統(tǒng)一的視角,尤其是其中的數(shù)理公式推導(dǎo)非常詳盡,本文將試圖盡量簡要地概括一遍大一統(tǒng)視角下的擴(kuò)散模型的推導(dǎo)過程。在結(jié)尾處,筆者附上了一些推導(dǎo)過程中的強(qiáng)假設(shè)的思考和疑惑,并簡要討論了下擴(kuò)散模型應(yīng)用在自然語言處理時(shí)的一些思考。
本篇閱讀筆記一共參考了以下技術(shù)博客。其中如果不了解擴(kuò)散模型的讀者可以考慮先閱讀lilian-weng的科普博客。Calvin-Luo的這篇介紹性論文在書寫的時(shí)候經(jīng)過了包括Jonathan Ho(DDPM作者), SongYang博士 和一系列相關(guān)擴(kuò)散模型論文的發(fā)表者的審核,非常值得一讀。
1. What are Diffusion Models? by Lilian Weng
2. Generative Modeling by Estimating Gradients of the Data Distribution by Song Yang
3. Understanding Diffusion Models: A Unified Perspective by Calvin Luo
生成模型希望可以生成符合真實(shí)分布(或給定數(shù)據(jù)集)的數(shù)據(jù)。我們常見的幾種生成模型有GANs,F(xiàn)low-based Models, VAEs, Energy-Based Models 以及我們今天希望討論的擴(kuò)散模型Diffusion Models. 其中擴(kuò)散模型和變分自編碼器VAEs, 和基于能量的模型EBMs有一些聯(lián)系和區(qū)別,筆者會(huì)在接下來的章節(jié)闡述。
常見的幾種生成模型
在介紹擴(kuò)散模型前,我們先來回顧一下變分自編碼器VAE。我們知道VAE最大的特點(diǎn)是引入了一個(gè)潛在向量的分布來輔助建模真實(shí)的數(shù)據(jù)分布。那么為什么我們要引入潛在向量?有兩個(gè)直觀的原因,一個(gè)是直接建模高維表征十分困難,常常需要引入很強(qiáng)的先驗(yàn)假設(shè)并且有維度詛咒的問題存在。另外一個(gè)是直接學(xué)習(xí)低維的潛在向量,一方面起到了維度壓縮的作用,一方面也希望能夠在低維空間上探索具有語義化的結(jié)構(gòu)信息(例如圖像領(lǐng)域里的GAN往往可以通過操控具體的某個(gè)維度影響輸出圖像的某個(gè)具體特征)。
引入了潛在向量后,我們可以將我們的目標(biāo)分布的對(duì)數(shù)似然logP(x),也稱為“證據(jù)evidence“寫成下列形式:
ELBO的推理過程
其中,我們重點(diǎn)關(guān)注式15. 等式的左邊是生成模型想要接近的真實(shí)數(shù)據(jù)分布(evidence),等式右邊由兩項(xiàng)組成,其中第二項(xiàng)的KL散度因?yàn)楹愦笥诹悖圆坏仁胶愠闪?。如果在等式右邊減去該KL散度,則我們得到了真實(shí)數(shù)據(jù)分布的下界,即證據(jù)下界ELBO。對(duì)ELBO進(jìn)行進(jìn)一步的展開,我們就可以得到VAE的優(yōu)化目標(biāo)
ELBO等式的展開
對(duì)該證據(jù)下界的變形的形式,我們可以直觀地這么理解:證據(jù)下界等價(jià)于這么一個(gè)過程,我們用編碼器將輸入x編碼為一個(gè)后驗(yàn)的潛在向量分布q(z|x)。我們希望這個(gè)向量分布盡可能地和真實(shí)的潛在向量分布p(z)相似,所以用KL散度約束,這也可以避免學(xué)習(xí)到的后驗(yàn)分布q(z|x)坍塌成一個(gè)狄拉克delta函數(shù)(式19的右側(cè))。而得到的潛在向量我們用一個(gè)****重構(gòu)出原數(shù)據(jù),對(duì)應(yīng)的是式19的左邊P(x|z)。
VAE為什么叫變分自編碼器。變分的部分來自于尋找最優(yōu)的潛在向量分布q(z|x)的這個(gè)過程。自編碼器的部分是上面提到的對(duì)輸入數(shù)據(jù)的編碼,再解碼為原數(shù)據(jù)的行為。
那么提煉一下為什么VAE可以比較好地貼合原數(shù)據(jù)的分布?因?yàn)楦鶕?jù)上述的公式推導(dǎo)我們發(fā)現(xiàn):原數(shù)據(jù)分布的對(duì)數(shù)似然(稱為證據(jù)evidence)可以寫成證據(jù)下界加上我們希望近似的后驗(yàn)潛在向量分布和真實(shí)的潛在向量分布間的KL散度(即式15)。如果把該式寫為A = B+C的形式。因?yàn)閑vidence(即A)是個(gè)常數(shù)(與我們要學(xué)習(xí)的參數(shù)無關(guān)),所以最大化B,也就是我們的證據(jù)下界,等價(jià)于最小化C,也即是我們希望擬合的分布和真實(shí)分布間的差別。而因?yàn)樽C據(jù)下界,我們可以重新寫成式19那樣一個(gè)自編碼器的形式,我們也就得到了自編碼器的訓(xùn)練目標(biāo)。優(yōu)化該目標(biāo),等價(jià)于近似真實(shí)數(shù)據(jù)分布,也等價(jià)于用變分手法來優(yōu)化后驗(yàn)潛在向量分布q(z|x)的過程。
但VAE自身依然有很多問題。一個(gè)最明顯的就是我們?nèi)绾芜x定后驗(yàn)分布q_phi(z|x)。絕大多數(shù)的VAE實(shí)現(xiàn)里,這個(gè)后驗(yàn)分布被選定為了一個(gè)多維高斯分布。但這個(gè)選擇更多的是為了計(jì)算和優(yōu)化的方便而選擇。這樣的簡單形式極大地限制了模型逼近真實(shí)后驗(yàn)分布的能力。VAE的原作者kingma曾經(jīng)有篇非常經(jīng)典的工作就是通過引入normalization flow[1]在改進(jìn)后驗(yàn)分布的表達(dá)能力。而擴(kuò)散模型同樣可以看做是對(duì)后驗(yàn)分布q_phi(z|x)的改進(jìn)。
Hierarchical VAE下圖展示了一個(gè)變分自編碼器里,潛在向量和輸入間的閉環(huán)關(guān)系。即從輸入中提取低維的潛在向量后,我們可以通過這個(gè)潛在向量重構(gòu)出輸入。
VAE里潛在向量與輸入的關(guān)系
很明顯,我們認(rèn)為這個(gè)低維的潛在向量里一定是高效地編碼了原數(shù)據(jù)分布的一些重要特性,才使得我們的****可以成功重構(gòu)出原數(shù)據(jù)分布里的各式數(shù)據(jù)。那么如果我們遞歸式地對(duì)這個(gè)潛在向量再次計(jì)算“潛在向量的潛在向量”,我們就得到了一個(gè)多層的HVAE,其中每一層的潛在向量條件于所有前序的潛在向量。但是在這篇文章里,我們主要關(guān)注具有馬爾可夫性質(zhì)的層級(jí)變分自編碼器MHVAE,即每一層的潛在向量僅條件于前一層的潛在向量。
MHVAE里的潛在向量只條件于上一層
對(duì)于該MHVAE,我們可以通過馬爾可夫假設(shè)得到以下二式
23和24式是用鏈?zhǔn)椒▌t對(duì)依賴圖里的關(guān)系的拆解
對(duì)于該MHVAE,我們可以用以下步驟推導(dǎo)其證據(jù)下界
MHVAE的變分下界推導(dǎo)
我們之所以在談?wù)摂U(kuò)散模型之前,要花如此大的篇幅介紹VAE,并引出MHVAE的證據(jù)下界推導(dǎo)是因?yàn)槲覀兛梢苑浅W匀坏貙U(kuò)散模型視為一種特殊的MHVAE,該MHVAE滿足以下三點(diǎn)限制(注意以下三點(diǎn)限制也是整個(gè)擴(kuò)散模型推斷的基礎(chǔ)):
- 潛在向量Z的維度和輸入X的維度保持一致。
- 每一個(gè)時(shí)間步的潛在向量都被編碼為一個(gè)僅依賴于上一個(gè)時(shí)間步的潛在向量的高斯分布。
- 每一個(gè)時(shí)間步的潛在向量的高斯分布的參數(shù),隨時(shí)間步變化,且滿足最終時(shí)間步的高斯分布滿足標(biāo)準(zhǔn)高斯分布的限制。
因?yàn)榈谝稽c(diǎn)維度一致的原因,在不影響理解的基礎(chǔ)上,我們將MHVAE里的Zt表示為Xt(其中x0為原始輸入),則我們可以將MHVAE的層級(jí)潛在向量依賴圖,重新畫為以下形式(即將擴(kuò)散模型的中間擴(kuò)散過程當(dāng)做潛在向量的層級(jí)建模過程):
擴(kuò)散過程的直觀解釋:在數(shù)據(jù)x0上不斷加高斯噪聲直至退化為純?cè)肼晥D像Xt
直至這里,我們終于見到了我們熟悉的擴(kuò)散模型的形式。
而在將上面的公式25-28里的Zt與Xt替換后,我們可以得到VDM里證據(jù)下界的推導(dǎo)公式里的前四行,即公式34-37。并且在此基礎(chǔ)上,我們可以繼續(xù)往下推導(dǎo)。37至38行的變換是鏈?zhǔn)椒▌t的等價(jià)替換(或上述公式23和24的變換),38至39行是連乘過程的重組,39至40行是對(duì)齊連乘符號(hào)的區(qū)間,40至41行應(yīng)用了Log乘法的性質(zhì),41至42繼續(xù)運(yùn)用該性質(zhì)進(jìn)一步拆分,42至43行是因?yàn)楹偷钠谕扔谄谕暮停?3至44是因?yàn)槠谕繕?biāo)與部分時(shí)間步的概率無關(guān)可以直接省去,44至45步是應(yīng)用了KL散度的定義進(jìn)行了重組。
VDM的證據(jù)下界推導(dǎo)
至此,我們又一次將原數(shù)據(jù)分布的對(duì)數(shù)似然,轉(zhuǎn)化為了證據(jù)下界(公式37),并將其轉(zhuǎn)化為了幾項(xiàng)非常直觀的損失函數(shù)的加和形式(公式45),他們分別為:
- 重構(gòu)項(xiàng),即從潛在向量x1到原數(shù)據(jù)x0的變化。在VAE里該重構(gòu)項(xiàng)寫為logP(x|z),而在這里我們寫做logP(x0|x1)
- 先驗(yàn)匹配項(xiàng)?;貞浳覀兩鲜鎏岬降腗HVAE里最終時(shí)間步的高斯分布應(yīng)建立為標(biāo)準(zhǔn)高斯分布
- 一致項(xiàng)。該項(xiàng)損失是為了使得前向加噪過程和后向去噪的過程中,Xt的分布保持一致。直觀上講,對(duì)一個(gè)更混亂圖像的去噪應(yīng)一致于對(duì)一個(gè)更清晰的圖像的加噪。而因?yàn)橐恢马?xiàng)的損失是定義于所有時(shí)間步上的,這也是三項(xiàng)損失里最耗時(shí)計(jì)算的一項(xiàng)。
雖然以上的公式推導(dǎo)給了我們一個(gè)非常直觀的證據(jù)下界,并且由于每一項(xiàng)都是以期望來計(jì)算,所以天然適用蒙特卡洛方法來近似,但如果優(yōu)化該證據(jù)下界依然存在幾個(gè)問題:
- 我們的一致項(xiàng)損失是一項(xiàng)建立在兩個(gè)隨機(jī)變量(Xt-1, Xt+1)上的期望。他們的蒙特卡洛估計(jì)的方差大概率比建立在單個(gè)獨(dú)立變量上的蒙特卡洛估計(jì)的方差大。
- 我們的一致項(xiàng)是定義于所有時(shí)間步上的KL散度的期望和。對(duì)于T取值較高的情況(通常擴(kuò)散模型T取2000左右),該期望的方差也會(huì)很大。
所以我們需要重新推導(dǎo)一個(gè)證據(jù)下界。而這個(gè)推導(dǎo)的關(guān)鍵將著眼于以下這個(gè)觀察:我們可以將擴(kuò)散過程的正向加噪過程q(xt|xt-1)重寫為q(xt|xt-1, x0)。之所以這樣重寫的原因是基于馬爾可夫假設(shè),這兩個(gè)式子完全等價(jià)。于是對(duì)這個(gè)式子使用貝葉斯法則,我們可以得到式46.
對(duì)前向加噪過程使用馬爾可夫假設(shè)和貝葉斯法則后的公式
基于公式46,我們可以重寫上面的證據(jù)下界(式37)為以下形式:其中式47,48和式37,38一致。式49開始,分母的連乘拆解由從T開始改為從1開始。式50基于上文提及的馬爾可夫假設(shè)對(duì)分母添加了x0的依賴。式51用log的性質(zhì)拆分了對(duì)數(shù)的目標(biāo)。式52代入了式46做了替換。式53將劃掉的分母部分連乘單獨(dú)提取出來后發(fā)現(xiàn)各項(xiàng)可約剩下式54部分的log(q(x1|x0)/q(xT|x0))。式54用log的性質(zhì)消去了q(x1|x0)得到了式55。式56用log的性質(zhì)拆分重組了公式,式57如同前述式43-44的變換,省去了無關(guān)的時(shí)間步。式58則用了KL散度的性質(zhì)。
應(yīng)用了馬爾可夫假設(shè)的擴(kuò)散模型證據(jù)下界推導(dǎo)1
應(yīng)用了馬爾可夫假設(shè)的擴(kuò)散模型證據(jù)下界推導(dǎo)2
至此,我們應(yīng)用了馬爾可夫假設(shè)得到了一個(gè)更優(yōu)的證據(jù)下界推導(dǎo)。該證據(jù)下界同樣包含幾項(xiàng)直觀的損失函數(shù):
- 重構(gòu)項(xiàng)。該重構(gòu)項(xiàng)與上面提及的重構(gòu)項(xiàng)一致。
- 先驗(yàn)匹配項(xiàng)。與上面提及的形式略有差別,但同樣是基于最終時(shí)間步應(yīng)為標(biāo)準(zhǔn)高斯的先驗(yàn)假設(shè)
- 去噪匹配項(xiàng)。與上面提及的一致項(xiàng)的最大區(qū)別在于不再是對(duì)兩個(gè)隨機(jī)變量的期望。并且直觀上理解p(xt-1|xt)代表的是后向的去噪過程,而q(xt-1|xt, x0)代表的是已知原始圖像和目標(biāo)噪聲圖像的前向加噪過程。該加噪過程作為目標(biāo)信號(hào),來監(jiān)督后向的去噪過程。該項(xiàng)解決了期望建立于兩個(gè)隨機(jī)變量上的問題。
注意,以上的推導(dǎo)完全基于馬爾可夫的性質(zhì)所以適用于所有MHVAE,所以當(dāng)T=1的時(shí)候,以上的證據(jù)下界和VAE所推導(dǎo)出的證據(jù)下界完全一致!并且本文之所以稱為大一統(tǒng)視角,是因?yàn)閷?duì)于該證據(jù)下界里的去噪匹配項(xiàng),不同的論文有不同的優(yōu)化方式。但歸根結(jié)底,他們的本質(zhì)互相等價(jià),且皆由該式展開推導(dǎo)得到。下面我們會(huì)從擴(kuò)散模型的角度做公式推導(dǎo),來展開計(jì)算去噪匹配項(xiàng)。(注意第一版的推導(dǎo)里的一致項(xiàng),也完全可以通過下一節(jié)的方式得到q和p的表達(dá)式,再通過KL來計(jì)算解析式)
*博客內(nèi)容為網(wǎng)友個(gè)人發(fā)布,僅代表博主個(gè)人觀點(diǎn),如有侵權(quán)請(qǐng)聯(lián)系工作人員刪除。
物聯(lián)網(wǎng)相關(guān)文章:物聯(lián)網(wǎng)是什么