文章目錄
1.梯度
2.多元線性回歸參數求解
3.梯度下降
4.梯度下降法求解多元線性回歸
梯度下降算法在機器學習中出現頻率特別高,是非常常用的優化算法。
本文借多元線性回歸,用人話解釋清楚梯度下降的原理和步驟。
1.梯度
梯度是什么呢?
我們還是從最簡單的情況說起,對于一元函數來講,梯度就是函數的導數。
而對于多元函數而言,梯度是一個向量,也就是說,把求得的偏導數以向量的形式寫出來,就是梯度。
例如,我們在用人話講明白線性回歸LinearRegression一文中,求未知參數 和
時,對損失函數求偏導,此時的梯度向量為
,其中:
那篇文章中,因為一元線性回歸中只有2個參數,因此令兩個偏導數為0,能很容易求得 和
的解。
但是,這種求導的方法在多元回歸的參數求解中就不太實用了,為什么呢?
2.多元線性回歸參數求解
多元線性回歸方程的一般形式為:
可以簡寫為矩陣形式(一般加粗表示矩陣或向量):
其中,
之前我們介紹過一元線性回歸的損失函數可以用殘差平方和:
代入多元線性回歸方程就是:
用矩陣形式表示:
上面的展開過程涉及矩陣轉置,這里簡單提一下矩陣轉置相關運算,以免之前學過但是現在忘了:
好了,按照一元線性回歸求解析解的思路,現在我們要對Q求導并令導數為0(原諒我懶,后面寫公式就不對向量或矩陣加粗了,大家能理解就行):
上面的推導過程涉及矩陣求導,這里以求導為例展開講下,為什么
,其他幾項留給大家舉一反三。
首先:
為了直觀點,我們將記為A,因為Y是n維列向量,X是n×(p+1)的矩陣,因此
是(p+1)維行向量:
那么上面求導可以簡寫為:
這種形式的矩陣求導屬于分母布局,即分子為行向量或者分母為列向量(這里屬于后者)。
搞不清楚的可以看看這篇:矩陣求導實例,這里我直接寫出標量/列向量求導的公式,如下(y表示標量,X表示列向量):
根據上式,顯然有:
前面我們將記為A,
,那么上面算出來的結果就是
,即
。
說了這么多有的沒的,最終我想說是的,里面涉及到矩陣求逆,但實際問題中可能X沒有逆矩陣,這時計算的結果就不夠精確。
第二個問題就是,如果維度多、樣本多,即便有逆矩陣,計算機求解的速度也會很慢。
所以,基于上面這兩點,一般情況下我們不會用解析解求解法求多元線性回歸參數,而是采用梯度下降法,它的計算代價相對更低。
3.梯度下降
好了,重點來了,本文真正要講的東西終于登場了。
梯度下降,就是通過一步步迭代,讓所有偏導函數都下降到最低。如果覺得不好理解,我們就還是以最簡單的一元函數為例開始講。
下圖是我用Excel簡單畫的二次函數圖像(看起來有點歪,原諒我懶……懶得調整了……),函數為,它的導數為y=2x。
如果我們初始化的點在x=1處,它的導函數值,也就是梯度值是2,為正,那就讓它往左移一點,繼續計算它的梯度值,若為正,就繼續往左移。
如果我們初始化的點在x=-1處,該處的梯度值是-2,為負,那就讓它往右移。
多元函數的邏輯也一樣,先初始化一個點,也就是隨便選擇一個位置,計算它的梯度,然后往梯度相反的方向,每次移動一點點,直到達到停止條件。
這個停止條件,可以是足夠大的迭代步數,也可以是一個比較小的閾值,當兩次迭代之間的差值小于該閾值時,認為梯度已經下降到最低點附近了。
二元函數的梯度下降示例如上圖(圖片來自梯度下降),對于這種非凸函數,可能會出現這種情況:初始化的點不同,最后的結果也不同,也就是陷入局部最小值。
這種問題比較有效的解決方法,就是多取幾個初始點。不過對于我們接下來講的多元線性回歸,以及后面要講的邏輯回歸,都不存在這個問題,因為他們的損失函數都是凸函數,有全局最小值。
用數學公式來描述梯度下降的步驟,就是:
解釋下公式含義:
-
為k時刻的點坐標,
為下一刻要移動到的點的坐標,例如
就代表初始化的點坐標,
就代表第一步到移動到的位置;
- g代表梯度,前面有個負號,就代表梯度下降,即朝著梯度相反的反向移動;
-
被稱為步長,用它乘以梯度值來控制每次移動的距離,這個值的設定也是一門學問,設定的過小,迭代的次數就會過多,設定的過大,容易一步跨太遠,直接跳過了最小值。
在這里插入圖片描述
4.梯度下降法求解多元線性回歸
回到前面的多元線性回歸,我們用梯度下降算法求損失函數的最小值。
首先,求梯度,也就是前面我們已經給出的求偏導的公式:
將梯度代入隨機梯度下降公式:
這個式子中,X矩陣和Y向量都是已知的,步長是人為設定的一個值,只有參數是未知的,而每一步的
是由
決定的,也就是每一步的點坐標。
算法過程:
- 初始化
向量的值,即
,將其代入
得到當前位置的梯度;
- 用步長
乘以當前梯度,得到從當前位置下降的距離;
- 更新
,其更新表達式為
;
- 重復以上步驟,直到更新到某個
,達到停止條件,這個
就是我們求解的參數向量。
參考鏈接:
深入淺出--梯度下降法及其實現
梯度下降與隨機梯度下降概念及推導過程
文中圖片水印為本人博客地址:https://blog.csdn.net/simplification