0 摘要
我們引入了一個新的深度神經網絡模型家族. 我們沒有用非連續(xù)的隱藏層, 而是用神經網絡把隱狀態(tài)的導數參數化. 網絡的輸出是通過黑盒微分方程求解器來計算的. 這些連續(xù)層的網絡的內存消耗是穩(wěn)定不變的, 針對每個輸入來設計估計方法的話, 就能做計算精度和計算速度的權衡. 我們通過連續(xù)層的ResNet和連續(xù)時間的隱變量模型展示了這些特性. 我們還構造了連續(xù)正則化流, 該生成模型可以直接用極大似然來訓練, 不需要對數據維度進行分區(qū)或者排序. 我們展示了訓練過程其實不需要了解ODE求解器內部的實現, 也能對ODE求解器的反向進行計算. 這就允許我們構造大規(guī)模模型, 并進行端到端的訓練.
1 引言
像ResNet, RNN解碼器, 正則化流, 它們都組合了隱狀態(tài)的一系列變換, 構建出一個復雜的變換, 如下:
其中t屬于[0…T], ht屬于Rd, 這些迭代更新可以看作是連續(xù)變換的歐拉離散化. 當t趨于0, step趨于無窮時, 可以得到如下的常微分方程(ODE, ordinary differential equation):
給定h(0), 我們可以把h(T) 作為該方程在T時刻的解. 該解可以用黑盒ODE求解器計算得到, 求解器還能根據需要的精度自行決定在何處對f進行擬合. 圖1對比這一過程:
圖1
左: ResNet定義了一系列非連續(xù)的有限轉換.
右: ODE網絡定義了一個向量場, 可以對隱狀態(tài)進行連續(xù)的轉換.
黑點表示估計點.
定義 使用ODE的模型有如下好處:
內存優(yōu)化:
在第2節(jié), 我們展示了如何在不涉及ODE求解器黑盒內部操作的情況下, 對任意ODE求解過程求反向, 得到標量損失的梯度. 不儲存任何前向計算結果, 就可以讓我們在內存占用不變的情況下訓練任意深度的模型. 這就解決了深度神經網絡模型訓練的主要瓶頸---模型深度.
自適應計算法
歐拉法求ODE是比較古老的方法了, 現代ODE求解器可以做到根據誤差精度要求來調整求解過程, 監(jiān)控誤差來獲得需要的精度. 這就可以根據問題復雜度來調整模型估值的消耗. 在模型訓練結束后, 還能降低計算精度來滿足程序實時性的要求.
可拓展和可逆的標準化流
連續(xù)變換帶了一個意想不到的好處, 變量方程式的變化更加容易計算了. 在第4節(jié), 我們提出這個結論并組建了一個可逆的密度模型, 該模型可以避免正則化流中單單元的瓶頸, 可以直接用極大似然來進行訓練.
連續(xù)時間序列模型
RNN需要離散的觀測和發(fā)射間隔, 而定義連續(xù)的模型可以接收任意時間得到的數據. 此種模型的構建和展示詳見第5節(jié).
2 ODE求解器的反向自動微分
訓練連續(xù)層網絡的主要問題就是對ODE求解器的反向微分(也叫反向傳播). 直接根據求解器內部操作來求微分的內存占用過大, 并且會引入額外的誤差.
我們把ODE求解器當做黑盒, 用”伴隨靈敏度法”(adjoint sensitivity method)來求梯度. 這種計算法是通過計算另一個參數化的ODE來實現的. 這種方法的復雜度會根據問題的規(guī)模線性變化, 內存占用也很低, 并且可以顯式的控制計算精度.
假設標量的損失函數為L, 輸入是ODE求解器的結果:
為最小化L, 就需要求L對θ的梯度, 第一步就是要求L在每一個時刻對隱狀態(tài)z(t)的梯度. 這部分被稱為”伴隨”:
它也是一個ODE, 可以視作瞬時的鏈式法則:
. 這個求解是反向進行的, 初始狀態(tài)是
解這個ODE就需要知道從t0到t1軌跡上的所有z(t). 所以在求伴隨的過程中需要把z(t)也一并解出, 就可以在中間的軌跡上使用z(t)的值來求a(t)了.
計算L對θ的偏導則需要求第三個積分式:
這個式子需要知道z(t)和a(t)的值.
和
這兩個向量-jacobian 乘積可以通過一次自動微分直接得到, 時間消耗跟對f的估值差不多. 只要把初始狀態(tài), 伴隨和另一個偏導 concat 到一個向量中, 所有求解z,a和
的積分, 都可以通過調用一次ODE求解器計算得出. 如下算法1的偽代碼:
大多數的ODE求解器都可以輸出中間計算結果z(t), 當loss取決于這些中間狀態(tài)時, 反向偏導的計算也必須拆成一系列的求解. 如圖2所示:
圖2: ODE求解器的反向過程.
伴隨敏感度法求反向是分時刻實時求解的. 參數化的系統(tǒng)包括了初始狀態(tài)以及l(fā)oss對狀態(tài)的靈敏度. 如果損失直接依賴于多個時刻的隱狀態(tài)的觀測, 伴隨狀態(tài)也必須在loss對觀測的偏導方向上更新.
在每個觀測處, 伴隨都必須跟著偏導的方向調整.
在附錄C中給出了L關于t0, t1偏導的解法. 附錄B中給出上面公式的詳細推導過程. 附錄D給出了上述算法scipy實現, 這部分代碼也支持更高階的微分.
https://github.com/rtqichen/torchdiffeq中還給出了pytorch版本的實現.
3 用ODE來取代ResNet進行有監(jiān)督的訓練
本節(jié)嘗試用神經ODE進行有監(jiān)督訓練.
軟件: 略 (作者說自己選取了某某ODE求解器, 還用一個第三方框架實現了求反向, 但是在pytorch版代碼中這些都對不上)
模型結構: 使用了一個小的殘差網絡, 對輸入進行了2次下采樣, 然后疊了6個標準殘差鏈接層, 這6個殘差連接層替換成ODE求解器模塊. 還測試了一下同樣結構, 但是反向直接用鏈式法則求解的網絡, 記為RK-Net. 各網絡的表現如下:
可以看到, ODE網絡和RK網絡可以達到和ResNet相同的性能.
ODE****網絡的誤差控制: ODE求解器可以保證計算誤差在真實解的某個誤差限內. 更改這個誤差限會改變網絡的性能表現. 圖3a展示了誤差是可控的. 圖3b展示了前向計算時間是跟著函數估值次數成比例增加的. 所以降低誤差限可以在計算速度和精度之間做取舍. 你可以在訓練時用高精度, 但是在推理時用低精度來加快速度..
圖3c表明: 反向計算的消耗只有前向計算的一半左右. 這就表明, 伴隨法不但節(jié)省內存, 還比直接求反向更加高效.
網絡深度: 在ODE中不太好直接定義網絡層數這個概念. 有點類似的是隱狀態(tài)方程估值所需的次數, 這依賴于ODE求解器的輸入和初始狀態(tài). 圖3d展示了訓練過程中估值次數的增加, 這對應了模型復雜度的增長.
4 連續(xù)正則化流
還有一個模型也出現了類似式1的非連續(xù)型方程, 那就是正則化流(NF, normalization flows)和NICE framework. 這些模式使用變量代換定理來計算可逆變換之后的概率密度.
經典的正則化流模型: planar normalization flows的公式如下:
, 它的計算復雜度要么是z維度的立方, 要么是隱藏單元數量的立方. 最近的研究都是在NF模型的表達能力和計算復雜度做取舍.
令人驚訝的是, 我們把非連續(xù)的模型公式, 用第3節(jié)同樣的思路來轉換成連續(xù)模型可以減少計算量.
定理1: 變量瞬時變化
設z(t)是一個有限連續(xù)隨機變量,概率p(z(t))依賴于時間. 則下式是z(t)隨時間連續(xù)變化的微分方程:
假設f在z上均勻Lipschitz連續(xù),在t上連續(xù),那么對數概率密度的變化也遵循微分方程:
證明見附錄A. 與式6的log計算不同, 本式只需要計算跡(trace)的操作. 另外, 不像標準的NF模型, 本式不要求f是可逆的, 因為如果滿足唯一性,那么整個轉換自然就是可逆的.
應用變量瞬時變化定理,我們可以看一下planar normalization flows的連續(xù)模擬版本:
給定一個初始分布p(z(0),我們可以從p(z(T))中采樣,并通過求解這組ODE來評估其概率密度。
使用多個線性成本的隱藏單元
當det(行列式)不是線性方程時, 跡的方程還是線性的, 并且滿足:
這樣我們的方程就可以由一系列的求和得到, 概率密度的微分方程也是一個求和:
這意味著我們可以很簡便的評估多隱藏單元的流模型,其成本僅與隱藏單元M的數量呈線性關系。使用標準的NF模型評估這種“寬”層的成本是O(M3),這意味著標準NF體系結構的多個層只使用單個隱藏單元.
依賴于時間的動態(tài)方程
我們可以將流的參數指定為t的函數,使微分方程f(z(t)、t)隨t而變化。這種參數化的方法是一種超網絡. 我們還為每個隱藏層引入了門機制:
, 是一個神經網絡, 可以學習到何時使用fn. 我們把該模型稱之為連續(xù)正則化流(CNF, continuous normalizing flows)
4.1 CNF試驗
我們首先比較連續(xù)的和離散的planar正則化流在學習樣本從一個已知的分布。我們證明了一個具有M個隱藏單元的連續(xù) planar CNF至少可以與一個具有K層(M = K)的離散 planar NF具有同樣的擬合能力,某些情況下CNF的擬合能力甚至更強.
擬合概率密度
設置一個前述的CNF, 用adam優(yōu)化器訓練10000個step. 對應的NF使用RMSprop訓練500000個step. 此任務中損失函數為KL (q(x)||p(x)), 最小化這個損失函數, 來用q(x)擬合目標概率分布p(x). 圖4表明, CNF可以得到更低的損失.
[圖片上傳失敗...(image-7d47a5-1616472352555)]
極大似然訓練
CNF一個有用的特性是: 計算反向轉換和正向的成本差不多, 這一點是NF模型做不到的. 這樣在用CNF模型做概率密度估計任務時, 我們可以通過極大似然估計來進行訓練 也就是最大化log(q(x))的期望值. 其中q是變量代換之后的函數. 然后反向轉換CNF來從q(x)中進行采樣.
該任務中, 我們使用64個隱藏單元的CNF和64層的NF來進行對比. 圖5展示了最終的訓練結果. 從最初的高斯分布, 到最終學到的分布, 每一個圖代表時間t的某一步. 有趣的是: 為了擬合兩個圓圈, CNF把planar 流 進行了旋轉, 這樣粒子會均分到兩個圓中. 跟 CNF的平滑可解釋相對的是, NF模型比較反直覺, 并且很難擬合雙月牙的概率分布(見圖5.b)
[圖片上傳失敗...(image-687aaa-1616472365302)]
5 生成式隱方程時間序列模型
將神經網絡應用于不規(guī)則采樣的數據,如醫(yī)療記錄、網絡流量或神經尖峰數據是困難的。 通常,觀測被放入固定持續(xù)時間的桶中,隱方程(變量?原文是dynamic)以同樣的方式進行離散。如果存在數據缺失或隱變量定義不當的情況, 問題就比較困難. 數據缺失可以用數據填充和生成時間序列模型來進行標記. 還有一種方式是給RNN的輸入加時間戳信息.
我們提出了一種連續(xù)時間,生成的方法來建模時間序列。我們的模型用一個隱軌跡來表示每個時間序列。每個軌跡都是由一個局部初始狀態(tài)zt0和跨所有時間序列共享的全局隱方程組來確定。給定觀測時間t0、t1、……tN和初始狀態(tài)zt0,ODE求解算器產生zt1,…ztN,描述每個觀測的潛在狀態(tài)。我們通過一個采樣程序正式地定義了這個生成模型:
函數f是一個時間無關的函數,在當前時間步長取z并輸出梯度:
我們用神經網絡來參數化這個方程. 因為f是時間無關的, 給定隱狀態(tài)z(t), 整個隱軌跡就是唯一確定的. 推斷隱軌跡可以讓我們在時間上任意向前或后退做出預測
訓練與預測
我們可以用觀測的序列將這個潛變量模型訓練為變分自動編碼器. 我們的判別模型RNN倒序的接收時間序列數據, 輸出q φ (z 0 |x 1 ,x 2 ,...,x N ). 詳見附錄E. 使用ODE來做生成模型, 我們就能在已知時間序列的情況下, 在任意時間點做出預測.
泊松過程似然
觀測本身就給出了一些隱狀態(tài)的信息, 比如說: 得病的人更傾向于做藥物測試. 事件發(fā)生率可以用隱方程來進行參數化:
給定這個概率函數,非均勻泊松過程給出了區(qū)間[tstart,tend]中獨立觀測的可能性:
我們可以使用另一個神經網絡來參數化λ(·)。因此,我們可以調用一次ODE求解器就評估出隱軌跡和泊松過程概率值。圖7為該模型在數據集上學習到的事件發(fā)生率。
觀測時間上的泊松過程似然可以與數據似然相結合,共同模擬所有觀測和時間。
5.1 事件序列隱ODE試驗
我們研究了隱ODE模型的擬合和推斷時間序列的能力。該判別網絡是一個有25個隱藏單元的RNN。我們使用一個四維的隱空間。我們用一個具有20個隱藏單元的單隱藏層網絡來參數化函數f。解碼器是一個神經網絡, 只有一個隱藏層, 20個隱藏單元, 用于計算p(x t i |z t i )。我們的基線是一個有25個隱藏單元的RNN,用最小化負高斯對數似然為目標函數訓練。我們訓練了這個RNN的第二個版本,其輸入與下一個觀測的時間差連接,以幫助RNN進行不規(guī)則的觀測。
雙向螺旋數據集
我們生成了一個1000個二維螺旋的數據集,每個螺旋從一個不同的點開始,在100個相同間隔的時間步長采樣。 數據集包含兩種類型的螺旋:一半是順時針方向,另一半是逆時針方向。 為了模擬真實情況,我們在觀測中加入高斯噪聲。
具有不規(guī)則時間點的時間序列
為了生成不規(guī)則的時間戳,我們不替換的從每個軌跡隨機采樣 (n={30,50,100}). 訓練數據之外, 我們展示了100個時間點的預測均方根誤差(RMSE)。 表2顯示,隱ODE預測時的RMSE明顯較低.
圖8展示了用下采樣的30個點來擬合螺旋的結果.
[圖片上傳失敗...(image-53050-1616472422476)]
隱ODE的重構是通過對潛在軌跡的后驗采樣并將其解碼為數據空間得到的. 附錄F展示了更多不同數據點的情況. 我們發(fā)現, 不管多少個點的下采樣, 不管有沒有高斯噪聲, 重建和推斷都和真實情況一致.
隱空間推斷
圖8c展示了隱軌跡投影到隱空間前2個維度的結果. 這是兩個軌跡群, 一個順時針一個逆時針. 圖9展示了: 初始狀態(tài)隱軌跡方程為順時針, 而后轉變?yōu)槟鏁r針, 這一轉變過程是非常連續(xù)的.
6 應用范圍與限制.
Mini-Batch
Mini-Batch的使用不如標準神經網絡那么直觀。我們仍然可以通過將每個batch的狀態(tài)連接在一起,創(chuàng)建維度D×K的ODE方程組,通過ODE求解器來計算。In some cases, controlling
error on all batch elements together might require evaluating the combined system K times more
often than if each system was solved individually(不太懂什么意思)。不過,在實踐中使用Mini-Batch時,計算量并沒有大幅增加.
唯一性
什么情況下連續(xù)方程有唯一解? 皮卡存在定理限定了, 當微分方程Lipschitz連續(xù)并且z在t上連續(xù)時, 初值問題的解存在且唯一. 這就對我們使用的神經網絡有所限制, 模型的權重有限, 且不能使用非Lipschitz連續(xù)的激活函數, 比如tanh或者relu.
設置計算精度
模型允許用戶在計算精度和速度之間做trade-off, 需要用戶在訓練的前向和反向中設置誤差限. 對于序列模型, 默認值為1.5e-8. 在分類和概率密度擬合問題中, 不降低模型性能的情況下, 默認值可設置為1e-3和1e-5.
重建前向軌跡
如果重建的軌跡偏離了原軌跡,則通過向后運行的方程來重建狀態(tài)軌跡會帶來額外的數值誤差。這個問題可以通過checkpoint來解決:將z的中間值存儲在前向過程中,并通過從這些點重新積分來重建精確的前向軌跡。不過在實際計算中這不是一個問題,多層CNF的反向可以恢復到初始狀態(tài).
7 相關工作
略
8 結語
我們探索了黑盒ODE求解器作為模型的一部分, 并用它開發(fā)了新模型可以用于時間序列問題, 監(jiān)督學習問題, 概率密度估計問題. 這些模型可以自適應的進行估值計算, 并且允許用戶顯式的在計算速度和精度之間做取舍. 最終, 我們提出了連續(xù)版本的變量代換模型, 命名為CNF, 該模型的層可以擴展到比較大的尺度.
9 注:
我沒有對附錄和參考文獻做翻譯, 這部分大家請下載論文原文查看: https://arxiv.org/pdf/1806.07366