樹(shù)莓派上運(yùn)行 Stable Diffusion,260MB 的 RAM「hold」住 10 億參數(shù)大模型
Stable Diffusion 能在樹(shù)莓派上運(yùn)行了!
11 個(gè)月前 Stable Diffusion 誕生,它能夠在消費(fèi)級(jí) GPU 上運(yùn)行的消息讓不少研究者備受鼓舞。不僅如此,蘋(píng)果官方很快下場(chǎng),將 Stable Diffusion「塞進(jìn)」iPhone、iPad 和 Mac 中運(yùn)行。這大大降低了 Stable Diffusion 對(duì)硬件設(shè)備的要求,讓其逐漸成為人人都能使用的「黑科技」。
現(xiàn)在,它甚至已經(jīng)可以在 Raspberry Pi Zero 2 上運(yùn)行了。
Raspberry Pi Zero 2 「Just as small. Five times as fast.」
這是怎樣一個(gè)概念?運(yùn)行 Stable Diffusion 并不是一件容易的事,它包含一個(gè) 10 億參數(shù)的大型 Transformer 模型,建議使用的最低 RAM/VRAM 通常為 8GB。而 RPI Zero 2 只是內(nèi)存為 512MB 的微型計(jì)算機(jī)。
這意味著在 RPI Zero 2 上運(yùn)行 Stable Diffusion 是一個(gè)巨大的挑戰(zhàn)。而且,在運(yùn)行過(guò)程中,作者沒(méi)有增加存儲(chǔ)空間,也沒(méi)有將中間結(jié)果卸載到磁盤(pán)上。
一般而言,主要的機(jī)器學(xué)習(xí)框架和庫(kù)都專(zhuān)注于最小化推理延遲和 / 或最大化吞吐量,但以上這些都以?xún)?nèi)存使用為代價(jià)。因此,作者決定寫(xiě)一個(gè)超小的、可破解的推理庫(kù),致力于將內(nèi)存消耗最小化。
OnnxStream 做到了。
項(xiàng)目地址 https://github.com/vitoplantamura/OnnxStream
OnnxStream 基于將推理引擎與負(fù)責(zé)提供模型權(quán)重的組件解耦的思路,后者是派生自 WeightsProvider 的一個(gè)類(lèi)。一個(gè) WeightsProvider 的專(zhuān)門(mén)化可以實(shí)現(xiàn)任何類(lèi)型的模型參數(shù)加載、緩存和預(yù)取。例如,一個(gè)自定義的 WeightsProvider 可以決定直接從 HTTP 服務(wù)器下載數(shù)據(jù),而不加載或?qū)懭肴魏蝺?nèi)容到磁盤(pán)(這也是 OnnxStream 命名中有 Stream 的原因)。有兩個(gè)默認(rèn)的 WeightsProviders 可用:DiskNoCache 和 DiskPrefetch。
與微軟的推理框架 OnnxRuntime 相比,OnnxStream 只需要消耗 1/55 的內(nèi)存就可以達(dá)到同樣的效果,但(在 CPU 上的)速度只比前者慢 0.5-2 倍。
接下來(lái)你將看到 Stable Diffusion 在 RPI Zero 2 上運(yùn)行的效果,以及背后的方法。需要注意的是,雖然運(yùn)行速度較慢,但是它是大模型在更小、更有限的設(shè)備上運(yùn)行的嶄新嘗試。
網(wǎng)友們認(rèn)為這個(gè)項(xiàng)目很酷
將 Stable Diffusion 在 Raspberry Pi Zero 2 上運(yùn)行
VAE ****是 Stable Diffusion 中唯一無(wú)法以單精度或半精度放入 RPI Zero 2 RAM 的模型。這是因?yàn)槟P椭写嬖跉埐钸B接、非常大的張量和卷積。唯一的解決辦法就是靜態(tài)量化(8 bit)。
以下這些圖像是由作者 repo 中包含的 Stable Diffusion 示例實(shí)現(xiàn)在不同精度的 VAE ****下使用 OnnxStream 生成的。
第一張圖像是在作者的 PC 上生成的,使用了由 RPI Zero 2 生成的相同的 latent。
精度為 W16A16 的 VAE ****的生成效果
精度為 W8A32 的 VAE ****的生成效果
第三張圖由 RPI Zero 2 在大約 3 小時(shí)內(nèi)生成。圖注:精度為 W8A8 的 VAE ****的生成效果
OnnxStream 的特點(diǎn)
- 推理引擎與 WeightsProvider 解耦
- WeightsProvider 可以是 DiskNoCache、DiskPrefetch 或自定義
- 注意力切片
- 動(dòng)態(tài)量化(8 bit 無(wú)符號(hào)、非對(duì)稱(chēng)、百分位數(shù))
- 靜態(tài)量化(W8A8 無(wú)符號(hào)、非對(duì)稱(chēng)、百分位數(shù))
- 輕松校準(zhǔn)量化模型
- 支持 FP16(使用或不使用 FP16 運(yùn)算)
- 實(shí)現(xiàn)了 24 個(gè) ONNX 算子(最常用的算子)
- 運(yùn)算按順序執(zhí)行,但所有算子都是多線(xiàn)程的
- 單一實(shí)現(xiàn)文件 + header 文件
- XNNPACK 調(diào)用被封裝在 XnnPack 類(lèi)中 (用于將來(lái)的替換)
并且需要注意的是,OnnxStream 依賴(lài) XNNPACK 來(lái)加速某些原語(yǔ):MatMul、Convolution、element-wise Add/Sub/Mul/Div、Sigmoid 和 Softmax。
性能對(duì)比
Stable Diffusion 由三個(gè)模型組成:文本編碼器(672 次運(yùn)算和 1.23 億個(gè)參數(shù))、UNET 模型(2050 次運(yùn)算和 8.54 億個(gè)參數(shù))和 VAE ****(276 次運(yùn)算和 4900 萬(wàn)個(gè)參數(shù)。
假設(shè)批大小等于 1,生成完整圖像則需要 10 步,這需要運(yùn)行 2 次文本編碼器、運(yùn)行 20 次(即 2*10)UNET 模型和運(yùn)行 1 次 VAE ****,才能獲得良好效果(使用 Euler Ancestral 調(diào)度器)。
該表顯示了 Stable Diffusion 的三個(gè)模型不同的推理時(shí)間,以及內(nèi)存消耗(即 Windows 中的 Peak Working Set Size 或 Linux 中的 Maximum Resident Set Size)。
可以發(fā)現(xiàn),在 UNET 模型中(以 FP16 精度運(yùn)行時(shí),OnnxStream 中啟用了 FP16 算術(shù)),OnnxStream 的內(nèi)存消耗量?jī)H為 OnnxRuntime 的 1/55,但速度只慢 0.5-2 倍。
這次測(cè)試需要注明的幾點(diǎn)是:
- OnnxRuntime 的第一次運(yùn)行是預(yù)熱推理,因?yàn)樗?nbsp;InferenceSession 是在第一次運(yùn)行前創(chuàng)建的,并在隨后的所有運(yùn)行中重復(fù)使用。而 OnnxStream 沒(méi)有預(yù)熱推理,因?yàn)樗脑O(shè)計(jì)是純粹「eager」的(不過(guò),后續(xù)運(yùn)行可以受益于操作系統(tǒng)對(duì)權(quán)重文件的緩存)。
- 目前 OnnxStream 不支持 batch size ! = 1 的輸入,這與 OnnxRuntime 不同,后者在運(yùn)行 UNET 模型時(shí)使用 batch size = 2 可以大大加快整個(gè)擴(kuò)散過(guò)程。
- 在測(cè)試中,改變 OnnxRuntime 的 SessionOptions(如 EnableCpuMemArena 和 ExecutionMode)對(duì)結(jié)果沒(méi)有產(chǎn)生明顯影響。
- 在內(nèi)存消耗和推理時(shí)間方面,OnnxRuntime 的性能與 NCNN(另一個(gè)框架)非常相似。
- 測(cè)試的運(yùn)行條件:Windows Server 2019、16GB 內(nèi)存、8750H CPU (AVX2)、970 EVO Plus SSD, VMWare 上的 8 個(gè)虛擬內(nèi)核。
注意力切片與量化
在運(yùn)行 UNET 模型時(shí),采用「注意力切片」技術(shù),并對(duì) VAE ****使用 W8A8 量化,這對(duì)于將模型內(nèi)存消耗降低到適合在 RPI Zero 2 上運(yùn)行的水平至關(guān)重要。
雖然互聯(lián)網(wǎng)上有很多關(guān)于量化神經(jīng)網(wǎng)絡(luò)的信息,但關(guān)于「注意力切片」的卻很少。
這里的想法很簡(jiǎn)單:目標(biāo)是在計(jì)算 UNET 模型中各種多頭注意力的縮放點(diǎn)積注意力時(shí),避免生成完整的 Q @ K^T 矩陣。在 UNET 模型中,注意力頭數(shù)為 8 時(shí),Q 的形狀為 (8,4096,40),同時(shí) K^T 為 (8,40,4096)。因此,第一個(gè) MatMul 的最終形狀為 (8,4096,4096),這是一個(gè) 512MB 的張量(FP32 精度)。
解決方案是垂直分割 Q,然后在每個(gè) Q 塊上正常進(jìn)行注意力操作。Q_sliced 形狀為 (1,x,40),其中 x 為 4096(在本例中),除以 onnxstream::Model::m_attention_fused_ops_parts(默認(rèn)值為 2,但可以自定義。
這個(gè)簡(jiǎn)單的技巧可以將 UNET 模型以 FP32 精度運(yùn)行時(shí)的整體內(nèi)存消耗從 1.1GB 降低到 300MB。一個(gè)更高效的替代方案是使用 FlashAttention,但是 FlashAttention 需要為每個(gè)支持的架構(gòu)(AVX, NEON)等編寫(xiě)自定義內(nèi)核,在作者給出的例子中繞過(guò) XnnPack。
更多信息參見(jiàn)該項(xiàng)目的 GitHub 界面。
參考鏈接:https://www.reddit.com/r/MachineLearning/comments/152ago3/p_onnxstream_running_stable_diffusion_in_260mb_of/https://github.com/vitoplantamura/OnnxStream
*博客內(nèi)容為網(wǎng)友個(gè)人發(fā)布,僅代表博主個(gè)人觀(guān)點(diǎn),如有侵權(quán)請(qǐng)聯(lián)系工作人員刪除。