突破百萬億參數規模:華人團隊開源首個異構並行推薦系統訓練框架

機器之心專欄

機器之心編輯部

Persia打破了前幾代的推薦訓練框架(同構的)設計思路,首次採用異構的設計思路,更合理地配置了CPU和GPU,實現了極致化的性價比。

個性化推薦是互聯網行業提升 DAU (Daily Active Users)和收入的核心技術手段。隨着深度學習的廣泛應用,現代的推薦系統通過神經網路變相地「記住」用戶的行為習慣,從而精準預測出用戶的喜好。在移動互聯網普及之後,用戶的行為數據呈現幾何級數增加,單位時間內產生和收集的用戶行為數據更是極其龐大,因此需要更大的模型來對用戶的興趣編碼。更大的數據規模意味着需要更大的模型容量,模型參數量從 5 年前的十億已經迅速增長達到前段時間 Facebook 公開的十萬億參數規模。在這樣的趨勢下,更大規模的訓練需求無疑將會成為下一個需要攻剋的里程碑。

最近,由兩個華人團隊聯合開源的訓練框架 Persia 通過設計混合架構並在 Google cloud 上成功地把模型規模又推向了一個新的量級 — 百萬億參數量(需占用數百 T 的存儲),並能同時兼顧效率和精度。目前該框架已經受邀集成進 Pytorch 生態圈 Pytorch Lightning。

图片alt

  • GitHub: https://github.com/PersiaML/PERSIA

  • 論文: https://arxiv.org/abs/2111.05897

隨着模型的參數量隨着指數級別的增長,對於高性能的訓練框架的需求也越來越迫切。傳統架構在應對越來越多參數量面前也顯得越來越力不從心。傳統架構採用 CPU 的同構並行機制,對應的參數分佈採用模型並行。其最大的優點是便於機器的水平擴展以支持相對更大的模型,因此至今仍然在很多公司廣泛使用(雖然在各個公司有不同的命名方式,本文統稱為 mio 架構)。當推薦模型從傳統的 Logistic regression 升級到基於 Deep Learning 的模型,且參數量急劇增大的時候,傳統的方案就顯得捉襟見肘 —- 效率低下且難以兼顧精度。後來的進化版通過引入 GPU 承擔了深度網路部分的計算(本文稱為 mio + 架構),由於採用的仍然是同構的設計思路(只是把CPU機器換成了CPU帶GPU的機器),雖然能取得一定效率提升,部分緩解了效率和精度的矛盾,但是當需要應對不同規模的網路結構的時候,往往出現昂貴的 GPU 資源大量空閑的情況,導致性價比受損嚴重。

图片alt

為瞭解決這兩個因模型規模不斷膨脹而帶來的難題,Persia 的核心設計思路如下:

  • 採用異構的架構設計解決 GPU 資源利用率的問題。當 CPU 和 GPU 的配比綁定的時候,任何框架都難以同時保證在任何模型結構下的資源利用率。因此 Persia 設計了一種靈活的異構架構來實現按模型需求分配資源,保證效率的前提下資源的充分利用,大幅提升了性價比。

  • 採用同步和非同步的混合訓練模式同時兼顧效率和精度。傳統的方案中或是採用純同步的訓練,或是採用純非同步的訓練。在模型越來越大、機器數量越來越多的情況下,同步的訓練會導致機器之間相互等待,訓練效率容易受損嚴重。而非同步的訓練方式雖然避免了機器之間相互等待,訓練的效率顯著提高,但是隨着機器數量增加,模型的準確率(Accuracy)會大幅下降。針對超大模型情景下這樣的挑戰,Persia 設計了一種同步和非同步 Hybrid 訓練架構,集二者之長而避其短。並從理論和實踐兩個維度都驗證了 Persia 能同時達到同步訓練的準確率和非同步的訓練的效率。

這里簡要列出幾點 Persia 的特點:

  • 原生支持 PyTorch 生態:鑒於 PyTorch 極大地降低了研究人員定義模型的門檻,趨勢上在整個深度學習領域的占比越來越大,有別於已有的推薦訓練框架(如 XDL,PaddlePaddle 等),Persia 決定基於 PyTorch 生態。用戶模型定義等操作可直接借助 PyTorch 實現,因而即便是在研究領域最新最前沿模型(如 Transformer 等)也可直接調用,達到最大限度的靈活性與易用性。

  • 高性能:在 Criteo 標準數據集上,相較其他流行的開源推薦模型訓練框架,同樣資源條件下 Persia 可達到一倍以上的性能提升。Persia 支持 CPU-GPU 異構訓練,支持 GPU 與 GPU 直接通訊,顯著降低訓練成本。

  • 可擴展性:Persia 在高達 100 萬億模型參數訓練的 scale 下保持高訓練效率。同時在多數場景能夠接近線性加速(投入 n 倍的資源量,訓練效率提升接近 n 倍)。

  • 工業級場景大規模驗證:Persia 為 Kubernetes 實現了定製化的 operator,支持雲原生部署。並實現了各種容錯機制,經過在線上生產環境穩定運行兩年以上的驗證。Persia 經過多個億級 DAU 核心業務場景的實踐檢驗,取得了顯著的性能和業務指標提升。

  • 安全、故障易排查:Persia 由註重內存安全、速度和並發性的 Rust 語言實現,在編譯期就排除了大量的內存安全問題。原生提供大量打點監控,與 Grafana 完美結合,可自定義各類報警條件。同時基於 tracing 實現了分模塊、分層級的 log 輸出,使得實際場景中故障排查更加輕松。

  • 靈活的特徵處理:支持交叉特徵等各種常見特徵處理方式,且用戶通過 Python 腳本即可定義各種自定義特徵處理模式。兼具靈活性與易用性。

  • 線上線下一致性:離線訓練和線上訓練代碼統一,解決工程師常常需要花費大量時間排查模型上線效果不一致等痛點問題。

Persia 設計思路

整體架構

在推薦模型中,模型往往由下圖中的幾部分構成:

图片alt

  • Embedding Layer: 用戶 id、item id 等 ID 類 feature 對應的 Embedding 構成的 Embedding 層。每個 id 對應一個預設大小的向量(稱為 Embedding),由於 id 數量往往十分巨大,這些向量常常會占據整個模型體積的 99% 以上。

  • Non-ID Type Features: 圖像信息、LDA 等實數向量特徵。這部分將會與 id 對應的 Embedding vector 組合在一起,輸入到 DNN 中預測點擊率等。

  • Dense Neural Network (以下簡稱 NN): 這部分是一個神經網路,接受 Embedding vector 和實數向量特徵,輸出點擊率等希望預測的量。

這種推薦模型中, Embedding Layer 參數往往占模型體積的絕大部分,但 Embedding Layer 的計算量卻不大。而 NN 的參數量只占模型體積的很小部分,卻占了絕大部分計算量。這正對應了:硬體上 CPU 的內存較大,但算力較低,而 GPU 的顯存較小,但算力較高。

現有的訓練框架雖然包含GPU算力,但是每個 GPU worker 都需要跟大量 PS 之間傳遞數據和模型,這常常會觸發通訊瓶頸,從而整個效率都被拖垮了。

因此,在 Persia 系統設計中,NN 被置於 GPU 顯存中,通過 GPU 進行梯度計算。對於 NN 部分直接通過 GPU 與 GPU 之間的高效集合通訊同步,完全不經過 PS。而 Embedding 則置於內存中,通過 CPU 進行計算。Persia 對於 PS 進行兩層架構設計 (Embedding PS, Embedding Worker,後文介紹),能夠在多數場景下進一步降低 GPU worker 帶寬消耗,提升整體訓練效率。

同步+非同步混合訓練

此外,現存系統往往採用全同步訓練或全非同步訓練方式。在全同步訓練中,所有 GPU worker 對一批數據進行訓練和模型更新,全部完成後再進入下一批數據。在模型越來越大、機器數量越來越多的情況下,會導致機器之間相互等待、同步的時間大幅增加,難以在有限時間內完成訓練。這種情況下系統的訓練過程如下圖中第一行 (Full Sync) 所示。在全非同步訓練中,每個 worker 獨立訓練並更新 PS 參數。雖然 worker 之間不需要相互等待,訓練的效率較高,但是隨着機器數量增加,每個 worker 上使用的模型的差異會變大,導致模型的訓練效果大幅下降。這種情況下系統的訓練過程如下圖中第二行 (Full Async) 所示。

針對這兩種方式的問題,Persia 設計了 Hybrid 訓練架構,能夠在保證訓練效果的同時,達到接近全非同步的訓練效率。推薦場景訓練中的一個核心觀察是,Embedding 的更新非常稀疏,兩次更新之間往往交集很小,因此即使對 Embedding 做非同步更新,對最終的訓練結果影響也不大。而 NN 部分的更新則反之,每一次都會更新全部參數,如果做非同步訓練,會導致訓練結果的巨大差異。

Persia 所提出的 Hybrid 訓練方式,能夠對 NN 部分同步訓練,Embedding 部分非同步訓練。最終訓練效率接近純非同步訓練的效率,同時模型效果保持和全同步訓練一致。兼得兩方面的優勢。這種情況下系統的訓練過程如下圖中第三行 (Naive Hybrid) 所示。Persia 在此之上還對能夠並行執行的通訊、計算操作進行重疊,進一步提升系統效率。

最終系統的訓練過程如下圖中第四行 (Persia) 所示:

图片alt

理論保證

有別於現存系統,Persia 對於 Hybrid 演算法的設計給出了嚴格的理論保證。對於 Expectation of Loss 的優化問題(比如推薦場景中最普遍的每個樣本對應一個 loss 的場景):

图片alt

其中 f(w) 代表整個數據集上的平均 loss,ξ 代表一個樣本,w 代表模型參數,F(w; ξ) 代表樣本 ξ 上的 loss。模型訓練的目標是最小化整個數據集上的平均 loss。使用 Persia Hybrid 的訓練方式,可以證明模型的收斂速度為:

图片alt

其中 σ 為數據集方差,T 為迭代次數,τ 為 GPU worker 數量,α 為 ID 類 feature 碰撞概率。其中前兩項為全同步訓練的收斂速度,最後一項為 Hybrid 訓練引入的誤差。在推薦場景中,因為 Embedding 的更新非常稀疏,碰撞概率 α 遠小於 1,因此 Hybrid 收斂速度與全同步訓練收斂速度幾乎完全一致,但因為同步開銷減少,每一步的訓練執行效率大幅提升。對於具體的理論證明,可以參考 Persia 的論文 [1]。

其他優化

在演算法創新的基礎上,為了發揮極致的性能。Persia 提供了大量的實現層優化。比如:

  • 所有 PS 服務通訊使用為訓練場景優化的 zero-copy Persia RPC 系統,在訓練場景下(特點是 payload 非常大,包括 Embedding 和梯度等等大量 tensor 數據)性能遠超傳統 RPC 框架(如 gRPC、bRPC);

  • GPU 之間通訊使用同為快手開源的 Bagua 訓練加速框架,對 GPU 之間的集合通訊性能有顯著提升,並能通過梯度壓縮等演算法進一步降低通訊開銷;

  • Embedding 在 PS 中通過特殊設計的數據結構 (Persia Embedding Array List) 存儲,大幅提升 PS 效率和模型存取效率。這包括運行過程中無需動態申請新內存,同時更好地利用 CPU cache 機制。支持 Embedding 逐出邏輯。模型保存和讀取過程簡化為對連續內存的直接 Dump/Load 過程;

图片alt

  • 引入 Persia Embedding Worker 組件,將 Embedding Sum Pooling、處理原始數據等操作執行後再發送給 GPU,大幅減少 GPU 帶寬占用;

图片alt

  • 原始數據處理為 Persia Compact Batch 格式,自帶 ID 去重和數據壓縮表徵等性質,相比一般表示方式數據體積降低至 1/4,提升系統數據處理效率。

驗證和比較

測試選用 Alibaba-Ad,Avazu-Ad,Criteo-Ad 等多種開源數據集,訓練效率整體有 8 倍以上提升:

图片alt

Persia 可支持高達 100 萬億模型訓練,並在模型規模變大時保持訓練效率:

图片alt

在資源量擴大時,Persia 可以接近線性擴展(投入 n 倍的資源量,訓練效率提升接近 n 倍):

图片alt

Persia 使用實例

使用 Persia 非常簡單,主要分為訓練部署、模型定義、自定義數據集部分。

  • 分散式部署:通過 Persia operator 可在 Kubernetes 集群上一鍵部署 PERSIA 任務

  • 模型定義:直接使用 PyTorch

  • 自定義數據集:自定義預處理邏輯,將結果通過 Persia 提供的 Python 工具包轉換成 Persia Compact Batch 即可

完整例子和更多場景,歡迎參考 Persia Tutorial 文檔(https://persiaml-tutorials.pages.dev/)。

Persia 模型上線推理

Persia 訓練的模型 Embedding 部分可通過線上部署 Embedding PS 和 Embedding Worker 直接提供服務。NN 部分為原生 PyTorch 模型,在 Persia Tutorial 中提供了通過 TorchServe 推理的簡單例子。用戶也可以通過原生 PyTorch 的各種工具,比如轉換成 TensorRT 模型,進一步提升推理性能。

參考文獻

Xiangru Lian, Binhang Yuan, Xuefeng Zhu, Yulong Wang, Yongjun He, Honghuan Wu, Lei Sun, Haodong Lyu, Chengjun Liu, Xing Dong, Yiqiao Liao, Mingnan Luo, Congfei Zhang, Jingru Xie, Haonan Li, Lei Chen, Renjie Huang, Jianying Lin, Chengchun Shu, Xuezhong Qiu, Zhishan Liu, Dongying Kong, Lei Yuan, Hai Yu, Sen Yang, Ce Zhang, & Ji Liu. (2021). Persia: A Hybrid System Scaling Deep Learning Based Recommenders up to 100 Trillion Parameters.