Waakvlam blog

Waakvlam blog

「創りたい気持ち」に火をつけるWaakvlamの活動記録

Random Forestが生成した決定木からルールを抽出する

UEQareer Advent Calendar 2018 | 23日目
こんにちは!
Appleが大好きななかりんです٩( ᐛ )و
2日前とは別のアドベントカレンダーです.

今回は機械学習のRandom Forestを題材にして書きます.
(実はiOSアプリ開発以外もやってます!)
PythonのsklearnのRandomForestClassifierを使います.

RFへの入力データがどのように分類されたかを調べたかったのですが,
画像で出力する以外の記事がなかなか見当たらなかったため,
試行錯誤の結果をまとめてみました.

【目次】

あらすじ

学習後のRandom Forestにデータを入力して
出力した結果がどのように分類されたかを分析します.

とりあえずRandom Forestで学習させよう

sklearnが用意しているirisのデータセットを用います.
コードは以下のようになります.

import numpy as np
from sklearn import datasets
from sklearn import __version__ as sklearn_version
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split

def main():
    use_feature_index = [2, 3]
    iris = datasets.load_iris()
    X = iris.data[:, use_feature_index]
    y = iris.target
    feature_names = np.array(iris.feature_names)[use_feature_index]
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.3, random_state=0)

    rf = RandomForestClassifier(n_estimators=50, max_depth=3, random_state=0)
    rf.fit(X_train, y_train)
    score_s = rf.score(X_test, y_test)

    # scoreの出力
    print("sklearn random forest score:" + str(score_s))

if __name__ == "__main__":
    main()

出力は以下のようになります.

sklearn random forest score:0.977777777778

RFが生成した木の中身を見よう

よく見る可視化の定番のやつです.
以下のものが必要になります.

導入に関して,
Windowsの方は こちら
Macの方は こちらの下の方

コードは以下のようになります.

import numpy as np
from sklearn import datasets
from sklearn import __version__ as sklearn_version
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn import tree
import pydotplus as pdp

def main():
    use_feature_index = [2, 3]
    iris = datasets.load_iris()
    X = iris.data[:, use_feature_index]
    y = iris.target
    features = np.array(iris.feature_names)[use_feature_index]
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.3, random_state=0)

    rf = RandomForestClassifier(n_estimators=50, max_depth=3, random_state=0)
    rf.fit(X_train, y_train)
    score_s = rf.score(X_test, y_test)

    # scoreの出力
    print("-" * 50)
    print("sklearn random forest score:" + str(score_s))

    # 生成された木の1個目を可視化
    estimator = rf.estimators_[0]
    filename = "./tree.png"
    dot_data = tree.export_graphviz(
				estimator,
				out_file=None,
				filled=True,
				rounded=True,
				feature_names=features,
				class_names=iris.target_names,
				special_characters=True
				)
    graph = pdp.graph_from_dot_data(dot_data)
    graph.write_png(filename)

if __name__ == "__main__":
    main()

50個生成された木の1本目は下図のようになっています.

各ノードの不等式が条件で,valuesが各種の花の数となってます.

values[0] → setosaの数
values[1] → versicolorの数
values[2] → virginicaの数

さて,出力できた木の画像ですが,これは学習させたデータの分類の過程を示していて,
実際に入力したデータがどのように分類できたかがわかりません.

決定木にあるメソッドを使う

RFで生成されるものは結局ただの決定木です.
その決定木ですが,当然ながら,sklearnの決定木のように扱えます.
ドキュメントを読んでいたら,

decision_path(X, check_input=True)

こんなものがありました.
入力データの分類の過程が見れるらしい!
ということで実行.

>>> path = estimator.decision_path(X_test[0])
>>> print(path)
  (0, 0)	1
  (0, 6)	1
  (0, 10)	1

一瞬見た時,「なんだこれは?🤔」となりましたが,
それぞれの括弧の中の2つ目の数字がノード番号を表しています.
ノードの順番は以下のようになっています.

データを入力した時,そいつがどのように分類されるかがわかるようになりました.
しかし,大量の入力データ,大量な木があるため,
このように出力してもとても扱いづらいです.
せめて,(0, 6, 10)という感じにしたいですね.

ではやりましょう.

出力結果に一工夫

以下のようにすると0と1だけのリストに変換することができます.
1が通ったノードです.

>>> np.array(path.todense())[0]
array([1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1])

これを以下のように処理してあげれば
(0,6,10)というように変換できます.

>>> new_array = np.array(path.todense())[0]
>>> new_path = []
>>> for i in range(len(new_array)):
...     if new_array[i] == 1:
...             new_path.append(i)
... 
>>> new_path
[0, 6, 10]

これで,扱いやすくなりました.

応用編

せっかく扱いやすくしたので,色々試してみましょう.

複数の入力データの分類経路とその経路が使われた回数を出力

コードを見せるのが一番早いですね.

import collections # これ必要!
import csv # これ必要!

with open("./tree_analyze.csv", "w") as f:
    writer = csv.writer(f, lineterminator='\n')
    for tree in rf.estimators_:
        all_path = []
        for input in X_test:
            new_array = np.array(tree.decision_path(input).todense())[0]
            new_path = ""
            for i in range(len(new_array)):
                if new_array[i] == 1:
                    new_path = new_path + "," + str(i)
            all_path.append(new_path)
    
        counted_array = collections.Counter(all_path)
        values, counts = zip(*counted_array.most_common())
        output_line = []
        for i in range(len(values)):
            output_line.append(values[i])
            output_line.append(counts[i])
        writer.writerow(output_line)

各木に対して,使われた経路が多い順に
並べたものをCSVとして出力するプログラムです.

決定木のpredictというメソッドを使えば,出力結果がわかります.
こいつと正解データを組み合わせれば,
正しく分類できた経路だけを抽出することも可能ですね.

抽出した経路の判定条件を抽出する

経路がわかったら,各ノードの分類条件も知りたいですねよ.
sklearnの決定木にtree_というものがあります.
このtree_の中の

  • feature (各ノードで使用している特徴量)
  • threshold (各ノードで使用している特徴量の閾値)

を使ってあげれば各ノードの分類条件を抽出できます.

>>> tree.tree_.feature
array([ 0,  1, -2,  0, -2, -2,  1,  1, -2, -2, -2])
>>> tree.tree_.threshold
array([ 4.94999981,  0.75      , -2.        ,  4.44999981, -2.        ,
       -2.        ,  1.75      ,  1.60000002, -2.        , -2.        , -2.        ])

こいつらを使ったプログラムが以下のものです.

tree = rf.estimators_[0]
features = np.array(iris.feature_names)[use_feature_index]

path = estimator.decision_path(X_test[0])
new_array = np.array(path.todense())[0]
new_path = []
for i in range(len(new_array)):
    if new_array[i] == 1:
        new_path.append(i)

for i in range(1, len(new_path)):
    label = features[tree.tree_.feature[new_path[i-1]]]
    if new_path[i] - new_path[i-1] == 1:
        label = label + " ≦ "
    else:
        label = label + " > "
    print(label + str(tree.tree_.threshold[new_path[i-1]]))

(出力)
petal length (cm) > 4.94999980927
petal width (cm) > 1.75

先に示した決定木の画像と見比べると,
各ノードにある条件をしっかり抽出できてますね.

おわりに

決定木を画像として出力する記事はたくさんあります.
しかし,画像データではちょっと不便ですよね.
RFで生成された決定木の中身を見る場合,数が膨大になるし,
深さが深いほど分析が大変になります.
よって今回は画像ではなく扱いやすいリスト型とかで
抽出するということに試みました.
これで頑張ろうと思えば各分類経路をクラスタリングして,
入力データがどんな特徴を持って何種類あるのか
ということとかもわかるかもしれませんね.

※方法は上記の他にもあると思います


P.S.
機械学習に関してまだまだ未熟なので
頑張って勉強します...
世界のRFを使っている研究者と繋がりたい.

今後も宜しくお願いします٩( ᐛ )و
twitter.com
github.com
ほしい物リスト