Neural Ordinary Differential Equations 神經常微分方程

0 摘要

我們引入了一個新的深度神經網絡模型家族. 我們沒有用非連續(xù)的隱藏層, 而是用神經網絡把隱狀態(tài)的導數參數化. 網絡的輸出是通過黑盒微分方程求解器來計算的. 這些連續(xù)層的網絡的內存消耗是穩(wěn)定不變的, 針對每個輸入來設計估計方法的話, 就能做計算精度和計算速度的權衡. 我們通過連續(xù)層的ResNet和連續(xù)時間的隱變量模型展示了這些特性. 我們還構造了連續(xù)正則化流, 該生成模型可以直接用極大似然來訓練, 不需要對數據維度進行分區(qū)或者排序. 我們展示了訓練過程其實不需要了解ODE求解器內部的實現, 也能對ODE求解器的反向進行計算. 這就允許我們構造大規(guī)模模型, 并進行端到端的訓練.

1 引言

像ResNet, RNN解碼器, 正則化流, 它們都組合了隱狀態(tài)的一系列變換, 構建出一個復雜的變換, 如下:

image.png

其中t屬于[0…T], ht屬于Rd, 這些迭代更新可以看作是連續(xù)變換的歐拉離散化. 當t趨于0, step趨于無窮時, 可以得到如下的常微分方程(ODE, ordinary differential equation):

image.png

給定h(0), 我們可以把h(T) 作為該方程在T時刻的解. 該解可以用黑盒ODE求解器計算得到, 求解器還能根據需要的精度自行決定在何處對f進行擬合. 圖1對比這一過程:

image.png

圖1

左: ResNet定義了一系列非連續(xù)的有限轉換.

右: ODE網絡定義了一個向量場, 可以對隱狀態(tài)進行連續(xù)的轉換.

黑點表示估計點.

定義 使用ODE的模型有如下好處:

內存優(yōu)化:

在第2節(jié), 我們展示了如何在不涉及ODE求解器黑盒內部操作的情況下, 對任意ODE求解過程求反向, 得到標量損失的梯度. 不儲存任何前向計算結果, 就可以讓我們在內存占用不變的情況下訓練任意深度的模型. 這就解決了深度神經網絡模型訓練的主要瓶頸---模型深度.

自適應計算法

歐拉法求ODE是比較古老的方法了, 現代ODE求解器可以做到根據誤差精度要求來調整求解過程, 監(jiān)控誤差來獲得需要的精度. 這就可以根據問題復雜度來調整模型估值的消耗. 在模型訓練結束后, 還能降低計算精度來滿足程序實時性的要求.

可拓展和可逆的標準化流

連續(xù)變換帶了一個意想不到的好處, 變量方程式的變化更加容易計算了. 在第4節(jié), 我們提出這個結論并組建了一個可逆的密度模型, 該模型可以避免正則化流中單單元的瓶頸, 可以直接用極大似然來進行訓練.

連續(xù)時間序列模型

RNN需要離散的觀測和發(fā)射間隔, 而定義連續(xù)的模型可以接收任意時間得到的數據. 此種模型的構建和展示詳見第5節(jié).

2 ODE求解器的反向自動微分

訓練連續(xù)層網絡的主要問題就是對ODE求解器的反向微分(也叫反向傳播). 直接根據求解器內部操作來求微分的內存占用過大, 并且會引入額外的誤差.

我們把ODE求解器當做黑盒, 用”伴隨靈敏度法”(adjoint sensitivity method)來求梯度. 這種計算法是通過計算另一個參數化的ODE來實現的. 這種方法的復雜度會根據問題的規(guī)模線性變化, 內存占用也很低, 并且可以顯式的控制計算精度.

假設標量的損失函數為L, 輸入是ODE求解器的結果:

image.png

為最小化L, 就需要求L對θ的梯度, 第一步就是要求L在每一個時刻對隱狀態(tài)z(t)的梯度. 這部分被稱為”伴隨”:

image.png

它也是一個ODE, 可以視作瞬時的鏈式法則:

image.png

這樣, 再調一次求解器就可以解出
image.png

. 這個求解是反向進行的, 初始狀態(tài)是
image.png

解這個ODE就需要知道從t0到t1軌跡上的所有z(t). 所以在求伴隨的過程中需要把z(t)也一并解出, 就可以在中間的軌跡上使用z(t)的值來求a(t)了.

計算L對θ的偏導則需要求第三個積分式:

image.png

這個式子需要知道z(t)和a(t)的值.

image.png

image.png

這兩個向量-jacobian 乘積可以通過一次自動微分直接得到, 時間消耗跟對f的估值差不多. 只要把初始狀態(tài), 伴隨和另一個偏導 concat 到一個向量中, 所有求解z,a和
image.png

的積分, 都可以通過調用一次ODE求解器計算得出. 如下算法1的偽代碼:
image.png

大多數的ODE求解器都可以輸出中間計算結果z(t), 當loss取決于這些中間狀態(tài)時, 反向偏導的計算也必須拆成一系列的求解. 如圖2所示:


image.png

圖2: ODE求解器的反向過程.

伴隨敏感度法求反向是分時刻實時求解的. 參數化的系統(tǒng)包括了初始狀態(tài)以及l(fā)oss對狀態(tài)的靈敏度. 如果損失直接依賴于多個時刻的隱狀態(tài)的觀測, 伴隨狀態(tài)也必須在loss對觀測的偏導方向上更新.

在每個觀測處, 伴隨都必須跟著偏導
image.png

的方向調整.

在附錄C中給出了L關于t0, t1偏導的解法. 附錄B中給出上面公式的詳細推導過程. 附錄D給出了上述算法scipy實現, 這部分代碼也支持更高階的微分.

https://github.com/rtqichen/torchdiffeq中還給出了pytorch版本的實現.

3 用ODE來取代ResNet進行有監(jiān)督的訓練

本節(jié)嘗試用神經ODE進行有監(jiān)督訓練.

軟件: (作者說自己選取了某某ODE求解器, 還用一個第三方框架實現了求反向, 但是在pytorch版代碼中這些都對不上)

模型結構: 使用了一個小的殘差網絡, 對輸入進行了2次下采樣, 然后疊了6個標準殘差鏈接層, 這6個殘差連接層替換成ODE求解器模塊. 還測試了一下同樣結構, 但是反向直接用鏈式法則求解的網絡, 記為RK-Net. 各網絡的表現如下:

image.png

可以看到, ODE網絡和RK網絡可以達到和ResNet相同的性能.

ODE****網絡的誤差控制: ODE求解器可以保證計算誤差在真實解的某個誤差限內. 更改這個誤差限會改變網絡的性能表現. 圖3a展示了誤差是可控的. 圖3b展示了前向計算時間是跟著函數估值次數成比例增加的. 所以降低誤差限可以在計算速度和精度之間做取舍. 你可以在訓練時用高精度, 但是在推理時用低精度來加快速度..

image.png

圖3c表明: 反向計算的消耗只有前向計算的一半左右. 這就表明, 伴隨法不但節(jié)省內存, 還比直接求反向更加高效.

網絡深度: 在ODE中不太好直接定義網絡層數這個概念. 有點類似的是隱狀態(tài)方程估值所需的次數, 這依賴于ODE求解器的輸入和初始狀態(tài). 圖3d展示了訓練過程中估值次數的增加, 這對應了模型復雜度的增長.

4 連續(xù)正則化流

還有一個模型也出現了類似式1的非連續(xù)型方程, 那就是正則化流(NF, normalization flows)和NICE framework. 這些模式使用變量代換定理來計算可逆變換之后的概率密度.


image.png

經典的正則化流模型: planar normalization flows的公式如下:

image.png

一般來說, 使用變量代換公式的瓶頸是計算雅克比矩陣
image.png

, 它的計算復雜度要么是z維度的立方, 要么是隱藏單元數量的立方. 最近的研究都是在NF模型的表達能力和計算復雜度做取舍.

令人驚訝的是, 我們把非連續(xù)的模型公式, 用第3節(jié)同樣的思路來轉換成連續(xù)模型可以減少計算量.

定理1: 變量瞬時變化

設z(t)是一個有限連續(xù)隨機變量,概率p(z(t))依賴于時間. 則下式是z(t)隨時間連續(xù)變化的微分方程:

image.png

假設f在z上均勻Lipschitz連續(xù),在t上連續(xù),那么對數概率密度的變化也遵循微分方程:


image.png

證明見附錄A. 與式6的log計算不同, 本式只需要計算跡(trace)的操作. 另外, 不像標準的NF模型, 本式不要求f是可逆的, 因為如果滿足唯一性,那么整個轉換自然就是可逆的.

應用變量瞬時變化定理,我們可以看一下planar normalization flows的連續(xù)模擬版本:

image.png

給定一個初始分布p(z(0),我們可以從p(z(T))中采樣,并通過求解這組ODE來評估其概率密度。

使用多個線性成本的隱藏單元

當det(行列式)不是線性方程時, 跡的方程還是線性的, 并且滿足:

image.png

這樣我們的方程就可以由一系列的求和得到, 概率密度的微分方程也是一個求和:

image.png

這意味著我們可以很簡便的評估多隱藏單元的流模型,其成本僅與隱藏單元M的數量呈線性關系。使用標準的NF模型評估這種“寬”層的成本是O(M3),這意味著標準NF體系結構的多個層只使用單個隱藏單元.

依賴于時間的動態(tài)方程

我們可以將流的參數指定為t的函數,使微分方程f(z(t)、t)隨t而變化。這種參數化的方法是一種超網絡. 我們還為每個隱藏層引入了門機制:

image.png

其中:
image.png

, 是一個神經網絡, 可以學習到何時使用fn. 我們把該模型稱之為連續(xù)正則化流(CNF, continuous normalizing flows)

4.1 CNF試驗

我們首先比較連續(xù)的和離散的planar正則化流在學習樣本從一個已知的分布。我們證明了一個具有M個隱藏單元的連續(xù) planar CNF至少可以與一個具有K層(M = K)的離散 planar NF具有同樣的擬合能力,某些情況下CNF的擬合能力甚至更強.

擬合概率密度

設置一個前述的CNF, 用adam優(yōu)化器訓練10000個step. 對應的NF使用RMSprop訓練500000個step. 此任務中損失函數為KL (q(x)||p(x)), 最小化這個損失函數, 來用q(x)擬合目標概率分布p(x). 圖4表明, CNF可以得到更低的損失.

[圖片上傳失敗...(image-7d47a5-1616472352555)]

極大似然訓練

CNF一個有用的特性是: 計算反向轉換和正向的成本差不多, 這一點是NF模型做不到的. 這樣在用CNF模型做概率密度估計任務時, 我們可以通過極大似然估計來進行訓練 也就是最大化log(q(x))的期望值. 其中q是變量代換之后的函數. 然后反向轉換CNF來從q(x)中進行采樣.

該任務中, 我們使用64個隱藏單元的CNF和64層的NF來進行對比. 圖5展示了最終的訓練結果. 從最初的高斯分布, 到最終學到的分布, 每一個圖代表時間t的某一步. 有趣的是: 為了擬合兩個圓圈, CNF把planar 流 進行了旋轉, 這樣粒子會均分到兩個圓中. 跟 CNF的平滑可解釋相對的是, NF模型比較反直覺, 并且很難擬合雙月牙的概率分布(見圖5.b)

[圖片上傳失敗...(image-687aaa-1616472365302)]

5 生成式隱方程時間序列模型

將神經網絡應用于不規(guī)則采樣的數據,如醫(yī)療記錄、網絡流量或神經尖峰數據是困難的。 通常,觀測被放入固定持續(xù)時間的桶中,隱方程(變量?原文是dynamic)以同樣的方式進行離散。如果存在數據缺失或隱變量定義不當的情況, 問題就比較困難. 數據缺失可以用數據填充和生成時間序列模型來進行標記. 還有一種方式是給RNN的輸入加時間戳信息.

我們提出了一種連續(xù)時間,生成的方法來建模時間序列。我們的模型用一個隱軌跡來表示每個時間序列。每個軌跡都是由一個局部初始狀態(tài)zt0和跨所有時間序列共享的全局隱方程組來確定。給定觀測時間t0、t1、……tN和初始狀態(tài)zt0,ODE求解算器產生zt1,…ztN,描述每個觀測的潛在狀態(tài)。我們通過一個采樣程序正式地定義了這個生成模型:

image.png

函數f是一個時間無關的函數,在當前時間步長取z并輸出梯度:

image.png

我們用神經網絡來參數化這個方程. 因為f是時間無關的, 給定隱狀態(tài)z(t), 整個隱軌跡就是唯一確定的. 推斷隱軌跡可以讓我們在時間上任意向前或后退做出預測

image.png

訓練與預測

我們可以用觀測的序列將這個潛變量模型訓練為變分自動編碼器. 我們的判別模型RNN倒序的接收時間序列數據, 輸出q φ (z 0 |x 1 ,x 2 ,...,x N ). 詳見附錄E. 使用ODE來做生成模型, 我們就能在已知時間序列的情況下, 在任意時間點做出預測.

泊松過程似然

觀測本身就給出了一些隱狀態(tài)的信息, 比如說: 得病的人更傾向于做藥物測試. 事件發(fā)生率可以用隱方程來進行參數化:

image.png

給定這個概率函數,非均勻泊松過程給出了區(qū)間[tstart,tend]中獨立觀測的可能性:

image.png

我們可以使用另一個神經網絡來參數化λ(·)。因此,我們可以調用一次ODE求解器就評估出隱軌跡和泊松過程概率值。圖7為該模型在數據集上學習到的事件發(fā)生率。


image.png

觀測時間上的泊松過程似然可以與數據似然相結合,共同模擬所有觀測和時間。

5.1 事件序列隱ODE試驗

我們研究了隱ODE模型的擬合和推斷時間序列的能力。該判別網絡是一個有25個隱藏單元的RNN。我們使用一個四維的隱空間。我們用一個具有20個隱藏單元的單隱藏層網絡來參數化函數f。解碼器是一個神經網絡, 只有一個隱藏層, 20個隱藏單元, 用于計算p(x t i |z t i )。我們的基線是一個有25個隱藏單元的RNN,用最小化負高斯對數似然為目標函數訓練。我們訓練了這個RNN的第二個版本,其輸入與下一個觀測的時間差連接,以幫助RNN進行不規(guī)則的觀測。

雙向螺旋數據集

我們生成了一個1000個二維螺旋的數據集,每個螺旋從一個不同的點開始,在100個相同間隔的時間步長采樣。 數據集包含兩種類型的螺旋:一半是順時針方向,另一半是逆時針方向。 為了模擬真實情況,我們在觀測中加入高斯噪聲。

具有不規(guī)則時間點的時間序列

為了生成不規(guī)則的時間戳,我們不替換的從每個軌跡隨機采樣 (n={30,50,100}). 訓練數據之外, 我們展示了100個時間點的預測均方根誤差(RMSE)。 表2顯示,隱ODE預測時的RMSE明顯較低.

image.png

圖8展示了用下采樣的30個點來擬合螺旋的結果.

[圖片上傳失敗...(image-53050-1616472422476)]

隱ODE的重構是通過對潛在軌跡的后驗采樣并將其解碼為數據空間得到的. 附錄F展示了更多不同數據點的情況. 我們發(fā)現, 不管多少個點的下采樣, 不管有沒有高斯噪聲, 重建和推斷都和真實情況一致.

隱空間推斷

圖8c展示了隱軌跡投影到隱空間前2個維度的結果. 這是兩個軌跡群, 一個順時針一個逆時針. 圖9展示了: 初始狀態(tài)隱軌跡方程為順時針, 而后轉變?yōu)槟鏁r針, 這一轉變過程是非常連續(xù)的.


image.png

6 應用范圍與限制.

Mini-Batch

Mini-Batch的使用不如標準神經網絡那么直觀。我們仍然可以通過將每個batch的狀態(tài)連接在一起,創(chuàng)建維度D×K的ODE方程組,通過ODE求解器來計算。In some cases, controlling

error on all batch elements together might require evaluating the combined system K times more

often than if each system was solved individually(不太懂什么意思)。不過,在實踐中使用Mini-Batch時,計算量并沒有大幅增加.

唯一性

什么情況下連續(xù)方程有唯一解? 皮卡存在定理限定了, 當微分方程Lipschitz連續(xù)并且z在t上連續(xù)時, 初值問題的解存在且唯一. 這就對我們使用的神經網絡有所限制, 模型的權重有限, 且不能使用非Lipschitz連續(xù)的激活函數, 比如tanh或者relu.

設置計算精度

模型允許用戶在計算精度和速度之間做trade-off, 需要用戶在訓練的前向和反向中設置誤差限. 對于序列模型, 默認值為1.5e-8. 在分類和概率密度擬合問題中, 不降低模型性能的情況下, 默認值可設置為1e-3和1e-5.

重建前向軌跡

如果重建的軌跡偏離了原軌跡,則通過向后運行的方程來重建狀態(tài)軌跡會帶來額外的數值誤差。這個問題可以通過checkpoint來解決:將z的中間值存儲在前向過程中,并通過從這些點重新積分來重建精確的前向軌跡。不過在實際計算中這不是一個問題,多層CNF的反向可以恢復到初始狀態(tài).

7 相關工作

8 結語

我們探索了黑盒ODE求解器作為模型的一部分, 并用它開發(fā)了新模型可以用于時間序列問題, 監(jiān)督學習問題, 概率密度估計問題. 這些模型可以自適應的進行估值計算, 并且允許用戶顯式的在計算速度和精度之間做取舍. 最終, 我們提出了連續(xù)版本的變量代換模型, 命名為CNF, 該模型的層可以擴展到比較大的尺度.

9 注:

我沒有對附錄和參考文獻做翻譯, 這部分大家請下載論文原文查看: https://arxiv.org/pdf/1806.07366

最后編輯于
?著作權歸作者所有,轉載或內容合作請聯(lián)系作者
平臺聲明:文章內容(如有圖片或視頻亦包括在內)由作者上傳并發(fā)布,文章內容僅代表作者本人觀點,簡書系信息發(fā)布平臺,僅提供信息存儲服務。
  • 序言:七十年代末,一起剝皮案震驚了整個濱河市,隨后出現的幾起案子,更是在濱河造成了極大的恐慌,老刑警劉巖,帶你破解...
    沈念sama閱讀 228,345評論 6 531
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件,死亡現場離奇詭異,居然都是意外死亡,警方通過查閱死者的電腦和手機,發(fā)現死者居然都...
    沈念sama閱讀 98,494評論 3 416
  • 文/潘曉璐 我一進店門,熙熙樓的掌柜王于貴愁眉苦臉地迎上來,“玉大人,你說我怎么就攤上這事。” “怎么了?”我有些...
    開封第一講書人閱讀 176,283評論 0 374
  • 文/不壞的土叔 我叫張陵,是天一觀的道長。 經常有香客問我,道長,這世上最難降的妖魔是什么? 我笑而不...
    開封第一講書人閱讀 62,953評論 1 309
  • 正文 為了忘掉前任,我火速辦了婚禮,結果婚禮上,老公的妹妹穿的比我還像新娘。我一直安慰自己,他們只是感情好,可當我...
    茶點故事閱讀 71,714評論 6 410
  • 文/花漫 我一把揭開白布。 她就那樣靜靜地躺著,像睡著了一般。 火紅的嫁衣襯著肌膚如雪。 梳的紋絲不亂的頭發(fā)上,一...
    開封第一講書人閱讀 55,186評論 1 324
  • 那天,我揣著相機與錄音,去河邊找鬼。 笑死,一個胖子當著我的面吹牛,可吹牛的內容都是我干的。 我是一名探鬼主播,決...
    沈念sama閱讀 43,255評論 3 441
  • 文/蒼蘭香墨 我猛地睜開眼,長吁一口氣:“原來是場噩夢啊……” “哼!你這毒婦竟也來了?” 一聲冷哼從身側響起,我...
    開封第一講書人閱讀 42,410評論 0 288
  • 序言:老撾萬榮一對情侶失蹤,失蹤者是張志新(化名)和其女友劉穎,沒想到半個月后,有當地人在樹林里發(fā)現了一具尸體,經...
    沈念sama閱讀 48,940評論 1 335
  • 正文 獨居荒郊野嶺守林人離奇死亡,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內容為張勛視角 年9月15日...
    茶點故事閱讀 40,776評論 3 354
  • 正文 我和宋清朗相戀三年,在試婚紗的時候發(fā)現自己被綠了。 大學時的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片。...
    茶點故事閱讀 42,976評論 1 369
  • 序言:一個原本活蹦亂跳的男人離奇死亡,死狀恐怖,靈堂內的尸體忽然破棺而出,到底是詐尸還是另有隱情,我是刑警寧澤,帶...
    沈念sama閱讀 38,518評論 5 359
  • 正文 年R本政府宣布,位于F島的核電站,受9級特大地震影響,放射性物質發(fā)生泄漏。R本人自食惡果不足惜,卻給世界環(huán)境...
    茶點故事閱讀 44,210評論 3 347
  • 文/蒙蒙 一、第九天 我趴在偏房一處隱蔽的房頂上張望。 院中可真熱鬧,春花似錦、人聲如沸。這莊子的主人今日做“春日...
    開封第一講書人閱讀 34,642評論 0 26
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽。三九已至,卻和暖如春,著一層夾襖步出監(jiān)牢的瞬間,已是汗流浹背。 一陣腳步聲響...
    開封第一講書人閱讀 35,878評論 1 286
  • 我被黑心中介騙來泰國打工, 沒想到剛下飛機就差點兒被人妖公主榨干…… 1. 我叫王不留,地道東北人。 一個月前我還...
    沈念sama閱讀 51,654評論 3 391
  • 正文 我出身青樓,卻偏偏與公主長得像,于是被迫代替她去往敵國和親。 傳聞我的和親對象是個殘疾皇子,可洞房花燭夜當晚...
    茶點故事閱讀 47,958評論 2 373

推薦閱讀更多精彩內容