Notice
Recent Posts
Recent Comments
Link
«   2024/10   »
1 2 3 4 5
6 7 8 9 10 11 12
13 14 15 16 17 18 19
20 21 22 23 24 25 26
27 28 29 30 31
Tags
more
Archives
Today
Total
관리 메뉴

BASEMENT

의사결정트리 본문

Programming/Machine Learning

의사결정트리

2_34 2020. 10. 10. 18:47

의사결정트리 (decision tree)

 

 

1. 개념

 

- 데이터네 내재되어 있는 패턴을 변수의 조합으로 나타내는 예측 분류 모델을 나무 형태로 만드는 방법

- 어떤 항목에 대한 관측값과 목표값을 연결시켜주는 예측 모델로서 결정트리를 사용

- 지도학습 방법, 분류회귀에 사용

- 의사 결정에 이르기까지 yes/no로 분류하여 사용하고 질문을 던져 대상을 좁히는 방법

 

 

1) 장점

 

자료를 가공할 필요가 거의 없고 다른 기법들의 경우 자료를 정규화하거나 임의의 변수를 생성하거나 값이 없는 변수를 제거해야 하는 경우

- 수치 자료와 범주 자료 모두 적용 가능

- 대규모 데이터 셋에서도 잘 동작함. 방대한 분량의 데이터를 일반적인 컴퓨터 환경에서 합리적인 시간 안에 분석 가능

 

2) 단점

 

- 결정 트리 학습자가 훈련 데이터를 제대로 일반화하지 못할 경우 너무 복합해짐

- 데이터의 특성이 특정 변수에 수직/수평적으로 구분되지 못할 때 분류율이 떨어지고, 트리가 복잡해짐

- 연속형 변수들에 대한 분리 경계점에서 예측오류가 클 가능성이 있음

 

 

 

2. 종류

 

1) Classification Tree

 

- 종속 변수가 이산형인 경우, 각각 범주에 속하는 빈도에 기초해 분리발생

- 분류 트리 분석은 예측된 결과로 입력 데이터가 분류되는 클래스를 출력

 

2) Regression Tree

 

- 종속 변수가 연속형인 경우, 평균과 표준편차에 의해 노드 분리

- 회귀 트리 분석은 예측된 결과로 특정 의미를 지니는 실수 값을 출력

(예: 주택의 가격, 환자의 입원 기간)

 

 

3. 의사결정트리 구성요소

 

- 노드(node) : 나무에서 분할되는 부분

- root node : 처음 노드

- terminal node/leaf node : 마지막 노드 -> 분리된 집합의 개수

- 부모노드(parent node)

- 자식노드(child node)

- 가지(Branch) : 뿌리 마디로부터 끝마디까지 연결된 마디

- 깊이(Depth) : 뿌리마디로부터 끝마디를 이루는 층

 

 

 

4. 지니불순도 & 엔트로피

 

1) 지니불순도

 

- 집합에 이질적인 것이 얼마나 섞였는지 측정하는 지표. CART 알고리즘에서 사용

- 집합에서 한 항목을 뽑아 무작위로 라벨을 추정할 때 틀릴 확률

- 집합에 있는 항목이 모두 같다면 지니 불순도는 최솟값(0)을 갖게 되며, 이 집합은 완전히 순수함

 

2) 엔트로피(entropy)

 

- m개의 레코드가 속하는 A영역에 대한 엔트로피(log)

 

=> 각 영역의 순도가 증가 / 불확실성(엔트로피)가 최소가 되도록 학습해야함 

 

 

 

5. 가지치기(Pruning)

 

- 오버피팅을 막기 위한 전략으로 불필요한 가지 제거

- Full tree를 생성한 후 적적할 수준에서 terminal node를 결합함

- 분기 수가 증가할 때 처음에는 새로운 데이터에 대한 오분류율이 감소하나, 일정 수준 이상이 되면 오분류율이 오히려 증가하는 현상

 

 

 

6. 재귀적 분기

 

특정 영역인 하나의 노드 내에서 하나의 변수 값을 기준으로 분기하여 새로 생성된 자식 노드들의 동질성이 최대화 되도록 분기점을 선택함 (동질성 최대화 = 불순도 최소화)

- 범주형 변수 : 지니 계수

- 수치형 변수 : 분산

 

 

 

7. 의사결정트리 사용

from sklearn.tree import DecisionTreeClassifier

tree = DecisionTreeClassifier(random_state=0)
tree.fit(X_train, y_train)

print("훈련 세트 정확도 : {:.3f}".format(tree.score(X_train, y_train)))
print("테스트 세트 정확도 : {:.3f}".format(tree.score(X_test, y_test)))

- max_depth : 최대 깊이 설정

- min_samples_split : 분할되기 위해 노드가 가져야 하는 최소 샘플 수

- min_samples_leaf : 리프 노드가 가지고 있어야 하는 최소 샘플 수

- min_weight_fraction_leaf : 가중치가 부여된 전체 샘플 수에서의 ㅡ비율

- max_leaf_nodes : 리프 노드의 최대 수

- max_features : 각 노드에서 분할에 사용할 특성의 최대 수

 

 

 

8. 의사결정트리 예제

 

1)

from sklearn import tree

X = [[0,0], [2,3], [2,1], [4,7], [5,4], [3,2]]
Y = [0,0,0,1,1,1]

clf = tree.DecisionTreeClassifier()
clf = clf.fit(X, Y)

clf.predict([[2,2]])
# !pip install mglearn
import mglearn

mglearn.plots.plot_tree_progressive()

 

2) iris 데이터

from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import numpy as np
from sklearn import tree

iris = datasets.load_iris()

X = iris.data[:,:4]
y = iris.target

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=0)

sc = StandardScaler()
sc.fit(X_train)

X_train_std = sc.transform(X_train)
X_test_std = sc.transform(X_test)
iris_tree = tree.DecisionTreeClassifier(criterion='entropy', max_depth=3, random_state=0)
iris_tree.fit(X_train, y_train)
from sklearn.metrics import accuracy_score

y_pred_tr = iris_tree.predict(X_test)
print('Accuracy: %.2f' % accuracy_score(y_test, y_pred_tr))
from sklearn.tree import export_graphviz
dot_data = export_graphviz(iris_tree,  out_file="iristree.dot", feature_names=['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)'],
                          class_names=iris.target_names, filled=True, rounded=True, special_characters=True)

import graphviz

with open("iristree.dot",encoding='UTF-8') as f:
    dot_data = f.read()
display(graphviz.Source(dot_data))

 

3) breast_cancer 데이터

from sklearn.tree import DecisionTreeClassifier
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split, cross_val_score

cancer = load_breast_cancer()

X_train, X_test, y_train, y_test = train_test_split(cancer.data, cancer.target, stratify=cancer.target, random_state=42)

tree = DecisionTreeClassifier(random_state=0)
tree.fit(X_train, y_train)

print("train set acc: {:.3f}".format(tree.score(X_train, y_train)))
print("test set acc: {:.3f}".format(tree.score(X_test, y_test)))
tree = DecisionTreeClassifier(max_depth=4, random_state=0)
tree.fit(X_train, y_train)

print("훈련 세트 정확도: {:.3f}".format(tree.score(X_train, y_train)))
print("테스트 세트 정확도: {:.3f}".format(tree.score(X_test, y_test)))
from sklearn.tree import export_graphviz

export_graphviz(tree, out_file="tree.dot", class_names=["악성","양성"], feature_names=cancer.feature_names, impurity=False, filled=True)

import graphviz

with open("tree.dot", encoding='UTF-8') as f:
  dot_graph = f.read()
display(graphviz.Source(dot_graph))
print("특성 중요도:\n", tree.feature_importances_)
import matplotlib.pyplot as plt
import numpy as np

def plot_feature_importances_cancer(model):
  n_features = cancer.data.shape[1]
  plt.barh(np.arange(n_features), model.feature_importances_, align='center')
  plt.yticks(np.arange(n_features), cancer.feature_names)
  plt.xlabel("feature importance")
  plt.ylabel("feature")
  plt.ylim(-1, n_features)

plot_feature_importances_cancer(tree)

 

4) boston 데이터

from sklearn.model_selection import train_test_split
from sklearn.datasets import load_boston
from sklearn.metrics import mean_squared_error

X, y = load_boston(return_X_y=True)
boston = load_boston()
X = boston.data
y = boston.target
colnames = boston.feature_names
colnames
x_train, x_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)
from sklearn.tree import DecisionTreeRegressor

model = DecisionTreeRegressor(max_depth=3)
model.fit(X=x_train, y=y_train)
y_pred = model.predict(X)
y_true = y

mse = mean_squared_error(y_true, y_pred)
print('mse: ', mse)
rmse = (np.sqrt(mse))
print('rmse: ', rmse)
plt.figure()
plt.title("Decision Tree Regressor (Model Actual vs Precited) with All Features")
plt.xlabel('TEST SET')
plt.ylabel('MEDV')
plt.plot(y_pred, 'o-', color="r", label="Predicted MEDV")
plt.plot(y_test, 'o-', color="g", label="Actual MEDV")

'Programming > Machine Learning' 카테고리의 다른 글

군집화  (0) 2020.10.10
랜덤포레스트 (Random Forest)  (0) 2020.10.10
SVM, SVR  (0) 2020.10.05
Naive Bayes (나이브 베이즈)  (0) 2020.10.05
KNN 알고리즘  (0) 2020.09.27
Comments