Spark Mlib- Decision Tree
Q:決策樹(shù)是什么?
A:決策樹(shù)是模擬人類(lèi)決策過(guò)程,將判斷一件事情所要做的一系列決策的各種可能的集合,以數(shù)的形式展現(xiàn)出來(lái),的一中樹(shù)形圖。
Q:決策樹(shù)的結(jié)構(gòu)是怎樣的?
A:決策樹(shù)與普通樹(shù)一樣,由節(jié)點(diǎn)和邊組成。樹(shù)中每一個(gè)節(jié)點(diǎn)都是一個(gè)屬性(特征),或者說(shuō)是對(duì)特征的判斷。根據(jù)一個(gè)節(jié)點(diǎn)的判斷結(jié)果,決策(預(yù)測(cè))流程走向不同的子節(jié)點(diǎn),或者直接到達(dá)葉節(jié)點(diǎn),即決策(預(yù)測(cè))結(jié)束,得到結(jié)果。
Q:決策樹(shù)是怎么訓(xùn)練出來(lái)的?
A:典型的決策樹(shù)的訓(xùn)練過(guò)程如下,以根據(jù)色澤、根蒂、敲聲預(yù)測(cè)一個(gè)西瓜是否好瓜為例——
- {failImgCache = [];}if(failImgCache.indexOf(src) == -1 && src.trim().length){failImgCache.push(src);}$(this).closest('.md-image').addClass('md-img-error').removeClass('md-img-loaded'); " onload="var src = window.removeLastModifyQuery(this.getAttribute('src'));if(!src.trim()) return;if(loadedImgCache.indexOf(src) == -1 && src.trim().length){loadedImgCache.push(src);}$(this).closest('.md-image').addClass('md-img-loaded').removeClass('md-img-error');" style="box-sizing: border-box; border-width: 0px 4px 0px 2px; border-right-style: solid; border-left-style: solid; border-right-color: transparent; border-left-color: transparent; vertical-align: middle; max-width: 100%; cursor: default;">
1、輸入一個(gè)數(shù)據(jù)集D2、生成一個(gè)節(jié)點(diǎn)node3、如果數(shù)據(jù)集中的樣本全部屬于同一類(lèi),比如西瓜樣本全部是“好瓜”,那么node就是葉子節(jié)點(diǎn)(好瓜)4、如果樣本中的數(shù)據(jù)集不屬于同一類(lèi),比如西瓜中既有好瓜也有壞瓜,那就選擇一個(gè)屬性,把西瓜根據(jù)選好的屬性分類(lèi)。比如按照“紋理”屬性,把西瓜分為清晰、模糊、稍糊三類(lèi)。 - {failImgCache = [];}if(failImgCache.indexOf(src) == -1 && src.trim().length){failImgCache.push(src);}$(this).closest('.md-image').addClass('md-img-error').removeClass('md-img-loaded'); " onload="var src = window.removeLastModifyQuery(this.getAttribute('src'));if(!src.trim()) return;if(loadedImgCache.indexOf(src) == -1 && src.trim().length){loadedImgCache.push(src);}$(this).closest('.md-image').addClass('md-img-loaded').removeClass('md-img-error');" style="box-sizing: border-box; border-width: 0px 4px 0px 2px; border-right-style: solid; border-left-style: solid; border-right-color: transparent; border-left-color: transparent; vertical-align: middle; max-width: 100%; cursor: default;">
5、把第4步得到的幾個(gè)數(shù)據(jù)子集作為輸入數(shù)據(jù),分別執(zhí)行上面的第1到5步、直到不再執(zhí)行第4步為止(也就是葉子節(jié)點(diǎn)全部構(gòu)建完成,算法結(jié)束)。
Q:整個(gè)決策樹(shù)的訓(xùn)練算法很簡(jiǎn)潔,但是第4步的“選擇一個(gè)屬性,把西瓜根據(jù)選好的屬性分類(lèi)”,怎樣來(lái)選擇合適的屬性呢?
A:假設(shè)我們現(xiàn)在只有兩種屬性選擇,一種是“色澤”、另一種種是“觸感”。選擇“色澤”,我們可以把西瓜分成三類(lèi)——“淺白”全是壞瓜;“墨綠”全是好瓜;“青綠”85%是好瓜,15%是壞瓜。選擇“觸感”,我們可以把西瓜分為兩類(lèi)——“硬滑”一半是好瓜,一半是壞瓜;“軟粘”40%是好瓜、60%是壞瓜。
稍一思考,我們自然會(huì)選“色澤”作為本次的屬性。因?yàn)椤吧珴伞笨梢砸幌掳押芏嗪霉虾蛪墓蠀^(qū)分開(kāi)來(lái),也就是說(shuō)我們知道了一個(gè)西瓜的色澤后有很大幾率正確判斷它是好瓜還是壞瓜。而目前來(lái)說(shuō),知道“觸感”卻對(duì)我們的判斷沒(méi)什么幫助。因此,我們選擇“能給我們帶來(lái)更多信息”的屬性,或者說(shuō)能夠“減少混淆程度”的屬性。
假設(shè)Pk是指當(dāng)前集合中第k類(lèi)樣本占的比例。比如10個(gè)西瓜中紋理清晰、模糊、稍糊的各有3,3,4個(gè),那么p1=3/10,p2=3/10,p3=4/10。于是根據(jù)信息的定義,我們有信息熵: - {failImgCache = [];}if(failImgCache.indexOf(src) == -1 && src.trim().length){failImgCache.push(src);}$(this).closest('.md-image').addClass('md-img-error').removeClass('md-img-loaded'); " onload="var src = window.removeLastModifyQuery(this.getAttribute('src'));if(!src.trim()) return;if(loadedImgCache.indexOf(src) == -1 && src.trim().length){loadedImgCache.push(src);}$(this).closest('.md-image').addClass('md-img-loaded').removeClass('md-img-error');" style="box-sizing: border-box; border-width: 0px 4px 0px 2px; border-right-style: solid; border-left-style: solid; border-right-color: transparent; border-left-color: transparent; vertical-align: middle; max-width: 100%; cursor: default;">
衡量那一個(gè)屬性“能給我們帶來(lái)更多信息”,我們用“信息增益”這個(gè)指標(biāo): - {failImgCache = [];}if(failImgCache.indexOf(src) == -1 && src.trim().length){failImgCache.push(src);}$(this).closest('.md-image').addClass('md-img-error').removeClass('md-img-loaded'); " onload="var src = window.removeLastModifyQuery(this.getAttribute('src'));if(!src.trim()) return;if(loadedImgCache.indexOf(src) == -1 && src.trim().length){loadedImgCache.push(src);}$(this).closest('.md-image').addClass('md-img-loaded').removeClass('md-img-error');" style="box-sizing: border-box; border-width: 0px 4px 0px 2px; border-right-style: solid; border-left-style: solid; border-right-color: transparent; border-left-color: transparent; vertical-align: middle; max-width: 100%; cursor: default;">
信息增益越大,越能帶來(lái)信息。因此每一次在上述算法第4步選擇屬性作為節(jié)點(diǎn)時(shí),我們對(duì)待選屬性都做一次信息增益的計(jì)算,選擇信息增益最大的屬性。
Q:有哪些對(duì)上述決策樹(shù)算法改進(jìn)的方案呢?
A:有兩種思路——1、改進(jìn)算法第4步中選擇屬性用的指標(biāo),比如用“增益率”或者“基尼系數(shù)”來(lái)代替“信息增益”。2、用“剪枝處理”,也就是去掉一些無(wú)用的分枝來(lái)降低“過(guò)擬合”的風(fēng)險(xiǎn)。
Q:什么是增益率?什么是基尼系數(shù)?
A:信息增益對(duì)于選項(xiàng)多的屬性有偏好。如果我們把訓(xùn)練樣本中每個(gè)西瓜編號(hào),然后把編號(hào)也作為待選擇屬性,那么編號(hào)肯定能帶來(lái)最多的信息,最大程度降低混淆程度,所以編號(hào)這個(gè)屬性肯定會(huì)被選中。但是這然并卵,因?yàn)檫@樣訓(xùn)練出來(lái)的算法對(duì)于新的樣本根本不起作用。所以我們不會(huì)選擇編號(hào)作為屬性。類(lèi)似的,如果有些屬性的可選項(xiàng)特別多,比如色澤現(xiàn)在有淺白、白、青綠、綠、墨綠五個(gè)可選項(xiàng),那么色澤被選中的幾率比其他屬性要大。所以可以考慮用增益率代替信息增益: - {failImgCache = [];}if(failImgCache.indexOf(src) == -1 && src.trim().length){failImgCache.push(src);}$(this).closest('.md-image').addClass('md-img-error').removeClass('md-img-loaded'); " onload="var src = window.removeLastModifyQuery(this.getAttribute('src'));if(!src.trim()) return;if(loadedImgCache.indexOf(src) == -1 && src.trim().length){loadedImgCache.push(src);}$(this).closest('.md-image').addClass('md-img-loaded').removeClass('md-img-error');" style="box-sizing: border-box; border-width: 0px 4px 0px 2px; border-right-style: solid; border-left-style: solid; border-right-color: transparent; border-left-color: transparent; vertical-align: middle; max-width: 100%; cursor: default;">
基尼系數(shù)則是另一項(xiàng)可以考慮的指標(biāo): - {failImgCache = [];}if(failImgCache.indexOf(src) == -1 && src.trim().length){failImgCache.push(src);}$(this).closest('.md-image').addClass('md-img-error').removeClass('md-img-loaded'); " onload="var src = window.removeLastModifyQuery(this.getAttribute('src'));if(!src.trim()) return;if(loadedImgCache.indexOf(src) == -1 && src.trim().length){loadedImgCache.push(src);}$(this).closest('.md-image').addClass('md-img-loaded').removeClass('md-img-error');" style="box-sizing: border-box; border-width: 0px 4px 0px 2px; border-right-style: solid; border-left-style: solid; border-right-color: transparent; border-left-color: transparent; vertical-align: middle; max-width: 100%; cursor: default;">
基尼系數(shù)表示從一堆樣本中隨機(jī)抽取兩個(gè)樣本,這 兩個(gè)樣本不同類(lèi)的概率,也就剛是一個(gè)是好瓜一個(gè)是壞瓜的概率。這樣,基尼系數(shù)越低,越意味著樣本集中某一類(lèi)比另一類(lèi)樣本多,也就是說(shuō),混淆程度越低。若選擇某個(gè)屬性后各個(gè)劃分子集的基尼系數(shù)最低,那么就選擇這個(gè)屬性。
Q:剪枝處理的過(guò)程是怎樣進(jìn)行的?
A:剪枝分為預(yù)剪枝和后剪枝。預(yù)剪枝過(guò)程就是雜生成決策樹(shù)時(shí),用一些新的樣本去測(cè)試剛剛訓(xùn)練好的節(jié)點(diǎn),如果這個(gè)節(jié)點(diǎn)的存在并不能讓決策樹(shù)的泛化性能提高,也就是說(shuō)分類(lèi)精度或者其他衡量指標(biāo)提高了,就把這個(gè)剛訓(xùn)練出來(lái)的節(jié)點(diǎn)舍棄。后剪枝過(guò)程是用一些新的樣本來(lái)測(cè)試這課剛剛訓(xùn)練出來(lái)的決策樹(shù),從下往上開(kāi)始測(cè)試,如果某個(gè)節(jié)點(diǎn)被換成葉節(jié)點(diǎn)后整一棵決策樹(shù)的泛化性能提高了,,也就是說(shuō)分類(lèi)精度或者其他衡量指標(biāo)提高了那么就直接替換掉。
實(shí)現(xiàn)
import org.apache.spark.ml.Pipelineimport org.apache.spark.ml.classification.DecisionTreeClassificationModelimport org.apache.spark.ml.classification.DecisionTreeClassifierimport org.apache.spark.ml.evaluation.MulticlassClassificationEvaluatorimport org.apache.spark.ml.feature.{IndexToString, StringIndexer, VectorIndexer}?// Load the data stored in LIBSVM format as a DataFrame.val data = spark.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt")?// Index labels, adding metadata to the label column.// Fit on whole dataset to include all labels in index.val labelIndexer = new StringIndexer() .setInputCol("label") .setOutputCol("indexedLabel") .fit(data)// Automatically identify categorical features, and index them.val featureIndexer = new VectorIndexer() .setInputCol("features") .setOutputCol("indexedFeatures") .setMaxCategories(4) // features with > 4 distinct values are treated as continuous. .fit(data)?// Split the data into training and test sets (30% held out for testing).val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3))?// Train a DecisionTree model.val dt = new DecisionTreeClassifier() .setLabelCol("indexedLabel") .setFeaturesCol("indexedFeatures")?// Convert indexed labels back to original labels.val labelConverter = new IndexToString() .setInputCol("prediction") .setOutputCol("predictedLabel") .setLabels(labelIndexer.labels)?// Chain indexers and tree in a Pipeline.val pipeline = new Pipeline() .setStages(Array(labelIndexer, featureIndexer, dt, labelConverter))?// Train model. This also runs the indexers.val model = pipeline.fit(trainingData)?// Make predictions.val predictions = model.transform(testData)?// Select example rows to display.predictions.select("predictedLabel", "label", "features").show(5)?// Select (prediction, true label) and compute test error.val evaluator = new MulticlassClassificationEvaluator() .setLabelCol("indexedLabel") .setPredictionCol("prediction") .setMetricName("accuracy")val accuracy = evaluator.evaluate(predictions)println("Test Error = " + (1.0 - accuracy))?val treeModel = model.stages(2).asInstanceOf[DecisionTreeClassificationModel]println("Learned classification tree model:\n" + treeModel.toDebugString)