目的
scikit-learnで作成した分類器(決定木/ランダムフォレスト)を、外部ファイルとしてエクスポートしたい。 つまり、他のプログラムで読み込める、scikit-learnやpythonを使わずとも実装できる形式で出力したい。決定木
作成した分類器をそのままexport_graphvizするだけ。from sklearn.datasets import load_iris from sklearn import tree clf = tree.DecisionTreeClassifier() iris = load_iris() clf = clf.fit(iris.data, iris.target) tree.export_graphviz(clf, out_file='tree.dot')
ランダムフォレスト
そのままexport_graphvizすることはできず、一手間必要。ランダムフォレストで作成した分類器のclf.estimators_で決定木のリストを取得できるので、それぞれのリストに対して上記同様にexport_graphvizする。
この場合、木の本数分dotファイルが生成される。下記の例では100本の木を作成するので、tree_0.dot~tree_99.dotが出力される。
from sklearn.datasets import load_iris from sklearn import tree from sklearn.ensemble import RandomForestClassifier clf = RandomForestClassifier(n_estimators=100) iris = load_iris() clf = clf.fit(iris.data, iris.target) for i,val in enumerate(clf.estimators_): tree.export_graphviz(clf.estimators_[i], out_file='tree_%d.dot'%i)
出力例
digraph Tree { 0 [label="X[3] <= 0.8000\ngini = 0.666666666667\nsamples = 150", shape="box"] ; 1 [label="gini = 0.0000\nsamples = 50\nvalue = [ 50. 0. 0.]", shape="box"] ; 0 -> 1 ; 2 [label="X[3] <= 1.7500\ngini = 0.5\nsamples = 100", shape="box"] ; 0 -> 2 ; 3 [label="X[2] <= 4.9500\ngini = 0.168038408779\nsamples = 54", shape="box"] ; 2 -> 3 ; 4 [label="X[3] <= 1.6500\ngini = 0.0407986111111\nsamples = 48", shape="box"] ; 3 -> 4 ; 5 [label="gini = 0.0000\nsamples = 47\nvalue = [ 0. 47. 0.]", shape="box"] ; 4 -> 5 ; 6 [label="gini = 0.0000\nsamples = 1\nvalue = [ 0. 0. 1.]", shape="box"] ; 4 -> 6 ; 7 [label="X[3] <= 1.5500\ngini = 0.444444444444\nsamples = 6", shape="box"] ; 3 -> 7 ; 8 [label="gini = 0.0000\nsamples = 3\nvalue = [ 0. 0. 3.]", shape="box"] ; 7 -> 8 ; 9 [label="X[2] <= 5.4500\ngini = 0.444444444444\nsamples = 3", shape="box"] ; 7 -> 9 ; 10 [label="gini = 0.0000\nsamples = 2\nvalue = [ 0. 2. 0.]", shape="box"] ; 9 -> 10 ; 11 [label="gini = 0.0000\nsamples = 1\nvalue = [ 0. 0. 1.]", shape="box"] ; 9 -> 11 ; 12 [label="X[2] <= 4.8500\ngini = 0.0425330812854\nsamples = 46", shape="box"] ; 2 -> 12 ; 13 [label="X[1] <= 3.1000\ngini = 0.444444444444\nsamples = 3", shape="box"] ; 12 -> 13 ; 14 [label="gini = 0.0000\nsamples = 2\nvalue = [ 0. 0. 2.]", shape="box"] ; 13 -> 14 ; 15 [label="gini = 0.0000\nsamples = 1\nvalue = [ 0. 1. 0.]", shape="box"] ; 13 -> 15 ; 16 [label="gini = 0.0000\nsamples = 43\nvalue = [ 0. 0. 43.]", shape="box"] ; 12 -> 16 ; }
参考リンク
- sklearn.tree.export_graphviz — scikit-learn 0.17.1 documentation
- 3.2.4.3.1. sklearn.ensemble.RandomForestClassifier — scikit-learn 0.17.1 documentation
- 3.2.4.3.2. sklearn.ensemble.RandomForestRegressor — scikit-learn 0.17.1 documentation
- python - Random Forest interpretation in scikit-learn - Stack Overflow