SyntaxHighlighter

2016年6月21日火曜日

scikit-learnで作成した分類器をエクスポートする

目的

scikit-learnで作成した分類器(決定木/ランダムフォレスト)を、外部ファイルとしてエクスポートしたい。 つまり、他のプログラムで読み込める、scikit-learnやpythonを使わずとも実装できる形式で出力したい。

決定木

作成した分類器をそのままexport_graphvizするだけ。
from sklearn.datasets import load_iris
from sklearn import tree

clf = tree.DecisionTreeClassifier()
iris = load_iris()
clf = clf.fit(iris.data, iris.target)
tree.export_graphviz(clf, out_file='tree.dot')

ランダムフォレスト

そのままexport_graphvizすることはできず、一手間必要。
ランダムフォレストで作成した分類器のclf.estimators_で決定木のリストを取得できるので、それぞれのリストに対して上記同様にexport_graphvizする。
この場合、木の本数分dotファイルが生成される。下記の例では100本の木を作成するので、tree_0.dot~tree_99.dotが出力される。
from sklearn.datasets import load_iris
from sklearn import tree
from sklearn.ensemble import RandomForestClassifier

clf = RandomForestClassifier(n_estimators=100)
iris = load_iris()
clf = clf.fit(iris.data, iris.target)      
for i,val in enumerate(clf.estimators_):
    tree.export_graphviz(clf.estimators_[i], out_file='tree_%d.dot'%i)

出力例

digraph Tree {
0 [label="X[3] <= 0.8000\ngini = 0.666666666667\nsamples = 150", shape="box"] ;
1 [label="gini = 0.0000\nsamples = 50\nvalue = [ 50.   0.   0.]", shape="box"] ;
0 -> 1 ;
2 [label="X[3] <= 1.7500\ngini = 0.5\nsamples = 100", shape="box"] ;
0 -> 2 ;
3 [label="X[2] <= 4.9500\ngini = 0.168038408779\nsamples = 54", shape="box"] ;
2 -> 3 ;
4 [label="X[3] <= 1.6500\ngini = 0.0407986111111\nsamples = 48", shape="box"] ;
3 -> 4 ;
5 [label="gini = 0.0000\nsamples = 47\nvalue = [  0.  47.   0.]", shape="box"] ;
4 -> 5 ;
6 [label="gini = 0.0000\nsamples = 1\nvalue = [ 0.  0.  1.]", shape="box"] ;
4 -> 6 ;
7 [label="X[3] <= 1.5500\ngini = 0.444444444444\nsamples = 6", shape="box"] ;
3 -> 7 ;
8 [label="gini = 0.0000\nsamples = 3\nvalue = [ 0.  0.  3.]", shape="box"] ;
7 -> 8 ;
9 [label="X[2] <= 5.4500\ngini = 0.444444444444\nsamples = 3", shape="box"] ;
7 -> 9 ;
10 [label="gini = 0.0000\nsamples = 2\nvalue = [ 0.  2.  0.]", shape="box"] ;
9 -> 10 ;
11 [label="gini = 0.0000\nsamples = 1\nvalue = [ 0.  0.  1.]", shape="box"] ;
9 -> 11 ;
12 [label="X[2] <= 4.8500\ngini = 0.0425330812854\nsamples = 46", shape="box"] ;
2 -> 12 ;
13 [label="X[1] <= 3.1000\ngini = 0.444444444444\nsamples = 3", shape="box"] ;
12 -> 13 ;
14 [label="gini = 0.0000\nsamples = 2\nvalue = [ 0.  0.  2.]", shape="box"] ;
13 -> 14 ;
15 [label="gini = 0.0000\nsamples = 1\nvalue = [ 0.  1.  0.]", shape="box"] ;
13 -> 15 ;
16 [label="gini = 0.0000\nsamples = 43\nvalue = [  0.   0.  43.]", shape="box"] ;
12 -> 16 ;
}

参考リンク