02 첫 번째 머신러닝 만들어 보기 - 붓꽃 품종 예측하기¶
sklearn.datasets
: 사이킷런에서 자체적으로 제공하는 데이터 세트 생성sklearn.tree
: 트리 기반 ML 알고리즘 구현. ML 알고리즘은 의사 결정 트리(Decision Tree) 알고리즘으로, 이를 구현한DecisionTreeClassifier
적용sklearn.model_selection
: 학습 데이터와 검증 데이터, 예측 데이터로 데이터를 분리하거나 최적의 하이퍼 파라미터로 평가load_iris()
: 붓꽃 데이터 세트 생성train_test_split()
: 데이터 세트를 학습 데이터와 테스트 데이터로 분리
피처는 속성(컬럼, 열), 레이블은 품종
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split
load_iris() 함수를 이용해 붓꽃 데이터 세트를 로딩한 후, 피처들과 데이터 값이 어떻게 구성돼 있는지 확인하기 위해 DataFrame으로 변환
import pandas as pd
#붓꽃 데이터 세트 로딩
iris = load_iris()
iris
{'data': array([[5.1, 3.5, 1.4, 0.2], [4.9, 3. , 1.4, 0.2], [4.7, 3.2, 1.3, 0.2], [4.6, 3.1, 1.5, 0.2], [5. , 3.6, 1.4, 0.2], [5.4, 3.9, 1.7, 0.4], [4.6, 3.4, 1.4, 0.3], [5. , 3.4, 1.5, 0.2], [4.4, 2.9, 1.4, 0.2], [4.9, 3.1, 1.5, 0.1], [5.4, 3.7, 1.5, 0.2], [4.8, 3.4, 1.6, 0.2], [4.8, 3. , 1.4, 0.1], [4.3, 3. , 1.1, 0.1], [5.8, 4. , 1.2, 0.2], [5.7, 4.4, 1.5, 0.4], [5.4, 3.9, 1.3, 0.4], [5.1, 3.5, 1.4, 0.3], [5.7, 3.8, 1.7, 0.3], [5.1, 3.8, 1.5, 0.3], [5.4, 3.4, 1.7, 0.2], [5.1, 3.7, 1.5, 0.4], [4.6, 3.6, 1. , 0.2], [5.1, 3.3, 1.7, 0.5], [4.8, 3.4, 1.9, 0.2], [5. , 3. , 1.6, 0.2], [5. , 3.4, 1.6, 0.4], [5.2, 3.5, 1.5, 0.2], [5.2, 3.4, 1.4, 0.2], [4.7, 3.2, 1.6, 0.2], [4.8, 3.1, 1.6, 0.2], [5.4, 3.4, 1.5, 0.4], [5.2, 4.1, 1.5, 0.1], [5.5, 4.2, 1.4, 0.2], [4.9, 3.1, 1.5, 0.2], [5. , 3.2, 1.2, 0.2], [5.5, 3.5, 1.3, 0.2], [4.9, 3.6, 1.4, 0.1], [4.4, 3. , 1.3, 0.2], [5.1, 3.4, 1.5, 0.2], [5. , 3.5, 1.3, 0.3], [4.5, 2.3, 1.3, 0.3], [4.4, 3.2, 1.3, 0.2], [5. , 3.5, 1.6, 0.6], [5.1, 3.8, 1.9, 0.4], [4.8, 3. , 1.4, 0.3], [5.1, 3.8, 1.6, 0.2], [4.6, 3.2, 1.4, 0.2], [5.3, 3.7, 1.5, 0.2], [5. , 3.3, 1.4, 0.2], [7. , 3.2, 4.7, 1.4], [6.4, 3.2, 4.5, 1.5], [6.9, 3.1, 4.9, 1.5], [5.5, 2.3, 4. , 1.3], [6.5, 2.8, 4.6, 1.5], [5.7, 2.8, 4.5, 1.3], [6.3, 3.3, 4.7, 1.6], [4.9, 2.4, 3.3, 1. ], [6.6, 2.9, 4.6, 1.3], [5.2, 2.7, 3.9, 1.4], [5. , 2. , 3.5, 1. ], [5.9, 3. , 4.2, 1.5], [6. , 2.2, 4. , 1. ], [6.1, 2.9, 4.7, 1.4], [5.6, 2.9, 3.6, 1.3], [6.7, 3.1, 4.4, 1.4], [5.6, 3. , 4.5, 1.5], [5.8, 2.7, 4.1, 1. ], [6.2, 2.2, 4.5, 1.5], [5.6, 2.5, 3.9, 1.1], [5.9, 3.2, 4.8, 1.8], [6.1, 2.8, 4. , 1.3], [6.3, 2.5, 4.9, 1.5], [6.1, 2.8, 4.7, 1.2], [6.4, 2.9, 4.3, 1.3], [6.6, 3. , 4.4, 1.4], [6.8, 2.8, 4.8, 1.4], [6.7, 3. , 5. , 1.7], [6. , 2.9, 4.5, 1.5], [5.7, 2.6, 3.5, 1. ], [5.5, 2.4, 3.8, 1.1], [5.5, 2.4, 3.7, 1. ], [5.8, 2.7, 3.9, 1.2], [6. , 2.7, 5.1, 1.6], [5.4, 3. , 4.5, 1.5], [6. , 3.4, 4.5, 1.6], [6.7, 3.1, 4.7, 1.5], [6.3, 2.3, 4.4, 1.3], [5.6, 3. , 4.1, 1.3], [5.5, 2.5, 4. , 1.3], [5.5, 2.6, 4.4, 1.2], [6.1, 3. , 4.6, 1.4], [5.8, 2.6, 4. , 1.2], [5. , 2.3, 3.3, 1. ], [5.6, 2.7, 4.2, 1.3], [5.7, 3. , 4.2, 1.2], [5.7, 2.9, 4.2, 1.3], [6.2, 2.9, 4.3, 1.3], [5.1, 2.5, 3. , 1.1], [5.7, 2.8, 4.1, 1.3], [6.3, 3.3, 6. , 2.5], [5.8, 2.7, 5.1, 1.9], [7.1, 3. , 5.9, 2.1], [6.3, 2.9, 5.6, 1.8], [6.5, 3. , 5.8, 2.2], [7.6, 3. , 6.6, 2.1], [4.9, 2.5, 4.5, 1.7], [7.3, 2.9, 6.3, 1.8], [6.7, 2.5, 5.8, 1.8], [7.2, 3.6, 6.1, 2.5], [6.5, 3.2, 5.1, 2. ], [6.4, 2.7, 5.3, 1.9], [6.8, 3. , 5.5, 2.1], [5.7, 2.5, 5. , 2. ], [5.8, 2.8, 5.1, 2.4], [6.4, 3.2, 5.3, 2.3], [6.5, 3. , 5.5, 1.8], [7.7, 3.8, 6.7, 2.2], [7.7, 2.6, 6.9, 2.3], [6. , 2.2, 5. , 1.5], [6.9, 3.2, 5.7, 2.3], [5.6, 2.8, 4.9, 2. ], [7.7, 2.8, 6.7, 2. ], [6.3, 2.7, 4.9, 1.8], [6.7, 3.3, 5.7, 2.1], [7.2, 3.2, 6. , 1.8], [6.2, 2.8, 4.8, 1.8], [6.1, 3. , 4.9, 1.8], [6.4, 2.8, 5.6, 2.1], [7.2, 3. , 5.8, 1.6], [7.4, 2.8, 6.1, 1.9], [7.9, 3.8, 6.4, 2. ], [6.4, 2.8, 5.6, 2.2], [6.3, 2.8, 5.1, 1.5], [6.1, 2.6, 5.6, 1.4], [7.7, 3. , 6.1, 2.3], [6.3, 3.4, 5.6, 2.4], [6.4, 3.1, 5.5, 1.8], [6. , 3. , 4.8, 1.8], [6.9, 3.1, 5.4, 2.1], [6.7, 3.1, 5.6, 2.4], [6.9, 3.1, 5.1, 2.3], [5.8, 2.7, 5.1, 1.9], [6.8, 3.2, 5.9, 2.3], [6.7, 3.3, 5.7, 2.5], [6.7, 3. , 5.2, 2.3], [6.3, 2.5, 5. , 1.9], [6.5, 3. , 5.2, 2. ], [6.2, 3.4, 5.4, 2.3], [5.9, 3. , 5.1, 1.8]]), 'target': array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]), 'frame': None, 'target_names': array(['setosa', 'versicolor', 'virginica'], dtype='<U10'), 'DESCR': '.. _iris_dataset:\n\nIris plants dataset\n--------------------\n\n**Data Set Characteristics:**\n\n :Number of Instances: 150 (50 in each of three classes)\n :Number of Attributes: 4 numeric, predictive attributes and the class\n :Attribute Information:\n - sepal length in cm\n - sepal width in cm\n - petal length in cm\n - petal width in cm\n - class:\n - Iris-Setosa\n - Iris-Versicolour\n - Iris-Virginica\n \n :Summary Statistics:\n\n ============== ==== ==== ======= ===== ====================\n Min Max Mean SD Class Correlation\n ============== ==== ==== ======= ===== ====================\n sepal length: 4.3 7.9 5.84 0.83 0.7826\n sepal width: 2.0 4.4 3.05 0.43 -0.4194\n petal length: 1.0 6.9 3.76 1.76 0.9490 (high!)\n petal width: 0.1 2.5 1.20 0.76 0.9565 (high!)\n ============== ==== ==== ======= ===== ====================\n\n :Missing Attribute Values: None\n :Class Distribution: 33.3% for each of 3 classes.\n :Creator: R.A. Fisher\n :Donor: Michael Marshall (MARSHALL%PLU@io.arc.nasa.gov)\n :Date: July, 1988\n\nThe famous Iris database, first used by Sir R.A. Fisher. The dataset is taken\nfrom Fisher\'s paper. Note that it\'s the same as in R, but not as in the UCI\nMachine Learning Repository, which has two wrong data points.\n\nThis is perhaps the best known database to be found in the\npattern recognition literature. Fisher\'s paper is a classic in the field and\nis referenced frequently to this day. (See Duda & Hart, for example.) The\ndata set contains 3 classes of 50 instances each, where each class refers to a\ntype of iris plant. One class is linearly separable from the other 2; the\nlatter are NOT linearly separable from each other.\n\n.. topic:: References\n\n - Fisher, R.A. "The use of multiple measurements in taxonomic problems"\n Annual Eugenics, 7, Part II, 179-188 (1936); also in "Contributions to\n Mathematical Statistics" (John Wiley, NY, 1950).\n - Duda, R.O., & Hart, P.E. (1973) Pattern Classification and Scene Analysis.\n (Q327.D83) John Wiley & Sons. ISBN 0-471-22361-1. See page 218.\n - Dasarathy, B.V. (1980) "Nosing Around the Neighborhood: A New System\n Structure and Classification Rule for Recognition in Partially Exposed\n Environments". IEEE Transactions on Pattern Analysis and Machine\n Intelligence, Vol. PAMI-2, No. 1, 67-71.\n - Gates, G.W. (1972) "The Reduced Nearest Neighbor Rule". IEEE Transactions\n on Information Theory, May 1972, 431-433.\n - See also: 1988 MLC Proceedings, 54-64. Cheeseman et al"s AUTOCLASS II\n conceptual clustering system finds 3 classes in the data.\n - Many, many more ...', 'feature_names': ['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)'], 'filename': 'iris.csv', 'data_module': 'sklearn.datasets.data'}
# iris.data는 Iris 데이터 세트에서 피처(feature)만으로 된 데이터를 numpy로 가짐
iris_data = iris.data
iris_data
array([[5.1, 3.5, 1.4, 0.2], [4.9, 3. , 1.4, 0.2], [4.7, 3.2, 1.3, 0.2], [4.6, 3.1, 1.5, 0.2], [5. , 3.6, 1.4, 0.2], [5.4, 3.9, 1.7, 0.4], [4.6, 3.4, 1.4, 0.3], [5. , 3.4, 1.5, 0.2], [4.4, 2.9, 1.4, 0.2], [4.9, 3.1, 1.5, 0.1], [5.4, 3.7, 1.5, 0.2], [4.8, 3.4, 1.6, 0.2], [4.8, 3. , 1.4, 0.1], [4.3, 3. , 1.1, 0.1], [5.8, 4. , 1.2, 0.2], [5.7, 4.4, 1.5, 0.4], [5.4, 3.9, 1.3, 0.4], [5.1, 3.5, 1.4, 0.3], [5.7, 3.8, 1.7, 0.3], [5.1, 3.8, 1.5, 0.3], [5.4, 3.4, 1.7, 0.2], [5.1, 3.7, 1.5, 0.4], [4.6, 3.6, 1. , 0.2], [5.1, 3.3, 1.7, 0.5], [4.8, 3.4, 1.9, 0.2], [5. , 3. , 1.6, 0.2], [5. , 3.4, 1.6, 0.4], [5.2, 3.5, 1.5, 0.2], [5.2, 3.4, 1.4, 0.2], [4.7, 3.2, 1.6, 0.2], [4.8, 3.1, 1.6, 0.2], [5.4, 3.4, 1.5, 0.4], [5.2, 4.1, 1.5, 0.1], [5.5, 4.2, 1.4, 0.2], [4.9, 3.1, 1.5, 0.2], [5. , 3.2, 1.2, 0.2], [5.5, 3.5, 1.3, 0.2], [4.9, 3.6, 1.4, 0.1], [4.4, 3. , 1.3, 0.2], [5.1, 3.4, 1.5, 0.2], [5. , 3.5, 1.3, 0.3], [4.5, 2.3, 1.3, 0.3], [4.4, 3.2, 1.3, 0.2], [5. , 3.5, 1.6, 0.6], [5.1, 3.8, 1.9, 0.4], [4.8, 3. , 1.4, 0.3], [5.1, 3.8, 1.6, 0.2], [4.6, 3.2, 1.4, 0.2], [5.3, 3.7, 1.5, 0.2], [5. , 3.3, 1.4, 0.2], [7. , 3.2, 4.7, 1.4], [6.4, 3.2, 4.5, 1.5], [6.9, 3.1, 4.9, 1.5], [5.5, 2.3, 4. , 1.3], [6.5, 2.8, 4.6, 1.5], [5.7, 2.8, 4.5, 1.3], [6.3, 3.3, 4.7, 1.6], [4.9, 2.4, 3.3, 1. ], [6.6, 2.9, 4.6, 1.3], [5.2, 2.7, 3.9, 1.4], [5. , 2. , 3.5, 1. ], [5.9, 3. , 4.2, 1.5], [6. , 2.2, 4. , 1. ], [6.1, 2.9, 4.7, 1.4], [5.6, 2.9, 3.6, 1.3], [6.7, 3.1, 4.4, 1.4], [5.6, 3. , 4.5, 1.5], [5.8, 2.7, 4.1, 1. ], [6.2, 2.2, 4.5, 1.5], [5.6, 2.5, 3.9, 1.1], [5.9, 3.2, 4.8, 1.8], [6.1, 2.8, 4. , 1.3], [6.3, 2.5, 4.9, 1.5], [6.1, 2.8, 4.7, 1.2], [6.4, 2.9, 4.3, 1.3], [6.6, 3. , 4.4, 1.4], [6.8, 2.8, 4.8, 1.4], [6.7, 3. , 5. , 1.7], [6. , 2.9, 4.5, 1.5], [5.7, 2.6, 3.5, 1. ], [5.5, 2.4, 3.8, 1.1], [5.5, 2.4, 3.7, 1. ], [5.8, 2.7, 3.9, 1.2], [6. , 2.7, 5.1, 1.6], [5.4, 3. , 4.5, 1.5], [6. , 3.4, 4.5, 1.6], [6.7, 3.1, 4.7, 1.5], [6.3, 2.3, 4.4, 1.3], [5.6, 3. , 4.1, 1.3], [5.5, 2.5, 4. , 1.3], [5.5, 2.6, 4.4, 1.2], [6.1, 3. , 4.6, 1.4], [5.8, 2.6, 4. , 1.2], [5. , 2.3, 3.3, 1. ], [5.6, 2.7, 4.2, 1.3], [5.7, 3. , 4.2, 1.2], [5.7, 2.9, 4.2, 1.3], [6.2, 2.9, 4.3, 1.3], [5.1, 2.5, 3. , 1.1], [5.7, 2.8, 4.1, 1.3], [6.3, 3.3, 6. , 2.5], [5.8, 2.7, 5.1, 1.9], [7.1, 3. , 5.9, 2.1], [6.3, 2.9, 5.6, 1.8], [6.5, 3. , 5.8, 2.2], [7.6, 3. , 6.6, 2.1], [4.9, 2.5, 4.5, 1.7], [7.3, 2.9, 6.3, 1.8], [6.7, 2.5, 5.8, 1.8], [7.2, 3.6, 6.1, 2.5], [6.5, 3.2, 5.1, 2. ], [6.4, 2.7, 5.3, 1.9], [6.8, 3. , 5.5, 2.1], [5.7, 2.5, 5. , 2. ], [5.8, 2.8, 5.1, 2.4], [6.4, 3.2, 5.3, 2.3], [6.5, 3. , 5.5, 1.8], [7.7, 3.8, 6.7, 2.2], [7.7, 2.6, 6.9, 2.3], [6. , 2.2, 5. , 1.5], [6.9, 3.2, 5.7, 2.3], [5.6, 2.8, 4.9, 2. ], [7.7, 2.8, 6.7, 2. ], [6.3, 2.7, 4.9, 1.8], [6.7, 3.3, 5.7, 2.1], [7.2, 3.2, 6. , 1.8], [6.2, 2.8, 4.8, 1.8], [6.1, 3. , 4.9, 1.8], [6.4, 2.8, 5.6, 2.1], [7.2, 3. , 5.8, 1.6], [7.4, 2.8, 6.1, 1.9], [7.9, 3.8, 6.4, 2. ], [6.4, 2.8, 5.6, 2.2], [6.3, 2.8, 5.1, 1.5], [6.1, 2.6, 5.6, 1.4], [7.7, 3. , 6.1, 2.3], [6.3, 3.4, 5.6, 2.4], [6.4, 3.1, 5.5, 1.8], [6. , 3. , 4.8, 1.8], [6.9, 3.1, 5.4, 2.1], [6.7, 3.1, 5.6, 2.4], [6.9, 3.1, 5.1, 2.3], [5.8, 2.7, 5.1, 1.9], [6.8, 3.2, 5.9, 2.3], [6.7, 3.3, 5.7, 2.5], [6.7, 3. , 5.2, 2.3], [6.3, 2.5, 5. , 1.9], [6.5, 3. , 5.2, 2. ], [6.2, 3.4, 5.4, 2.3], [5.9, 3. , 5.1, 1.8]])
# iris.target은 붓꽃 데이터 세트에서 레이블(결정 값) 데이터를 numpy로 가짐
iris_label = iris.target
print('iris target 값: ', iris_label)
print('iris target 명: ', iris.target_names)
iris target 값: [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2] iris target 명: ['setosa' 'versicolor' 'virginica']
# 붓꽃 데이터 세트를 자세히 보기 위해 DataFrame으로 변환
iris_df = pd.DataFrame(data=iris_data, columns=iris.feature_names)
iris_df['label'] = iris.target
iris_df.head()
sepal length (cm) | sepal width (cm) | petal length (cm) | petal width (cm) | label | |
---|---|---|---|---|---|
0 | 5.1 | 3.5 | 1.4 | 0.2 | 0 |
1 | 4.9 | 3.0 | 1.4 | 0.2 | 0 |
2 | 4.7 | 3.2 | 1.3 | 0.2 | 0 |
3 | 4.6 | 3.1 | 1.5 | 0.2 | 0 |
4 | 5.0 | 3.6 | 1.4 | 0.2 | 0 |
피처
: sepal_length, sepal_width, petal length, petal width레이블(Label, 결정값)
: 0,1,2 세 가지 값0
: Setosa 품종1
: versicolor 품종2
: virginica 품종
학습용 데이터와 테스트용 데이터 분리¶
train_test_split() : 데이터 세트 분리¶
학습 데이터와 테스트 데이터를 test_size 파라미터 입력값의 비율로 쉽게 분할
예시) test_size=0.2로 입력 파라미터를 설정
전체 데이터 중 테스트 데이터
가 20%, 학습 데이터
가 80%로 데이터 분할
X_train, X_test, y_train, y_test = train_test_split(iris_data, iris_label, test_size=0.2, random_state=11)
iris_data
: 피처 데이터 세트iris_label
: 레이블(Label) 데이터 세트test_size=0.2
: 전체 데이터 세트 중 테스트 데이터 세트의 비율random_state
: 호출할 때마다 같은 학습/테스트 용 데이터 세트를 생성하기 위해 주어지는 난수 발생 값.train_test_split()
는 호출 시 무작위로 데이터를 분리하므로radom_state
를 지정하지 않으면 수행할 때마다 다른 학습/테스트 용 데이터를 만들 수 있음. 본 예제는 실습용 예제이므로 수행할 때마다 동일한 데이터 세트로 분리하기 위해random_state
를 일정한 숫자 값으로 부여. (random_state
는 random 값을 만드는 seed와 같은 의미. 숫자 자체는 어떤 값을 지정해도 상관 X)
X_train
: 학습용 피처 데이터 세트X_test
: 테스트용 피처 데이터 세트y_train
: 학습용 레이블 데이터 세트y_test
: 테스트용 레이블 데이터 세트
의사 결정 트리를 이용해 학습과 예측 수행¶
이제 학습 데이터를 확보했으니 이 데이터를 기반으로 머신러닝 분류 알고리즘의 하나인 의사 결정 트리를 이용해 학습과 예측을 수행하자.
먼저 사이킷런의 의사 결정 트리 클래스인 DecisionTreeClassifier
를 객체로 생성
(DecisionTreeClassifier 객체 생성 시 입력된 random_state=11
역시 예제 코드를 수행할 때마다 동일한 학습/예츨 결과를 출력하기 위한 용도로만 사용됨)
fit() : 학습¶
생성된 DecisionTreeClassifier 객체의 fit()
메서드에 학습용 피처 데이터 속성과 결정값 데이터 세트를 입력해 호출하면 학습을 수행
# DecisionTreeClassifier 객체 생성
dt_clf = DecisionTreeClassifier(random_state=11)
# 학습 수행
dt_clf.fit(X_train, y_train)
DecisionTreeClassifier(random_state=11)In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
DecisionTreeClassifier(random_state=11)
predict() : 예측 수행¶
이제 의사 결정 트리 기반의 DecisionTreeClassifier
객체는 학습 데이터를 기반으로 학습 완료
이렇게 학습된 DecisionTreeClassifier 객체를 이용해 예측 수행
예측은 반드시 학습 데이터가 아닌 다른 데이터를 이용해야 하며, 일반적으로 테스트 데이터 세트를 이용
DecisionTreeClassifier 객체의 predict()
메서드에 테스트용 피처 데이터 세트를 입력해 호출하면 학습된 모델 기반에서 테스트 데이터 세트에 대한 예측값 반환
# 학습이 완료된 DecisionTreeClassifier 객체에서 테스트 데이터 세트로 예측 수행.
pred = dt_clf.predict(X_test)
예측 결과를 기반으로 의사 결정 트리 기반의 DecisionTreeClassifier의 예측 성능을 평가해보자.
일반적으로 머신러닝 모델의 성능 평가 방법은 여러 가지가 있으나, 여기서는 정확도를 측정해 보자.
정확도
: 예측 결과가 실제 레이블 값과 얼마나 정확하게 맞는지를 평가하는 지표
예측한 붓꽃 품종과 실제 테스트 데이터 세트의 붓꽃 품종이 얼마나 일치하는지 확인해 보자.
accuracy_score() : 평가¶
사이킷런은 정확도 측정을 위해 accuracy_score() 함수 제공
accuracy_score()의 첫 번째 파라미터
: 실제 레이블 데이터 세트accuracy_score()의 두 번째 파라미터
: 예측 레이블 데이터 세트
from sklearn.metrics import accuracy_score
print('예측 정확도: {0:4f}'.format(accuracy_score(y_test, pred)))
예측 정확도: 0.933333
학습한 의사 결정 트리의 알고리즘 예측 정확도가 약 0.9333(93.33%)로 측정됨
붓꽃 데이터 세트로 분류를 예측한 프로세스¶
- 데이터 세트 분리: 데이터를 학습 데이터와 테스트 데이터로 분리
- 모델 학습 : 학습 데이터를 기반으로 ML 알고리즘을 적용해 모델을 학습시킴
- 예측 수행 : 학습된 ML 모델을 이용해 테스트 데이터의 분류(즉, 붓꽃 종류)를 예측
- 평가 : 이렇게 에측된 결괏값과 테스트 데이터의 실제 결괏값을 비교해 ML 모델 성능 평가
퀴즈
iris 데이터를 가지고
- train과 test 비율을 8:2로 해서
- DecisionTreeClassifier를 이용해서
- 정확도를 구해 봅시다.
import pydataset
iris = pydataset.data('iris')
iris
Sepal.Length | Sepal.Width | Petal.Length | Petal.Width | Species | |
---|---|---|---|---|---|
1 | 5.1 | 3.5 | 1.4 | 0.2 | setosa |
2 | 4.9 | 3.0 | 1.4 | 0.2 | setosa |
3 | 4.7 | 3.2 | 1.3 | 0.2 | setosa |
4 | 4.6 | 3.1 | 1.5 | 0.2 | setosa |
5 | 5.0 | 3.6 | 1.4 | 0.2 | setosa |
6 | 5.4 | 3.9 | 1.7 | 0.4 | setosa |
7 | 4.6 | 3.4 | 1.4 | 0.3 | setosa |
8 | 5.0 | 3.4 | 1.5 | 0.2 | setosa |
9 | 4.4 | 2.9 | 1.4 | 0.2 | setosa |
10 | 4.9 | 3.1 | 1.5 | 0.1 | setosa |
11 | 5.4 | 3.7 | 1.5 | 0.2 | setosa |
12 | 4.8 | 3.4 | 1.6 | 0.2 | setosa |
13 | 4.8 | 3.0 | 1.4 | 0.1 | setosa |
14 | 4.3 | 3.0 | 1.1 | 0.1 | setosa |
15 | 5.8 | 4.0 | 1.2 | 0.2 | setosa |
16 | 5.7 | 4.4 | 1.5 | 0.4 | setosa |
17 | 5.4 | 3.9 | 1.3 | 0.4 | setosa |
18 | 5.1 | 3.5 | 1.4 | 0.3 | setosa |
19 | 5.7 | 3.8 | 1.7 | 0.3 | setosa |
20 | 5.1 | 3.8 | 1.5 | 0.3 | setosa |
21 | 5.4 | 3.4 | 1.7 | 0.2 | setosa |
22 | 5.1 | 3.7 | 1.5 | 0.4 | setosa |
23 | 4.6 | 3.6 | 1.0 | 0.2 | setosa |
24 | 5.1 | 3.3 | 1.7 | 0.5 | setosa |
25 | 4.8 | 3.4 | 1.9 | 0.2 | setosa |
26 | 5.0 | 3.0 | 1.6 | 0.2 | setosa |
27 | 5.0 | 3.4 | 1.6 | 0.4 | setosa |
28 | 5.2 | 3.5 | 1.5 | 0.2 | setosa |
29 | 5.2 | 3.4 | 1.4 | 0.2 | setosa |
30 | 4.7 | 3.2 | 1.6 | 0.2 | setosa |
31 | 4.8 | 3.1 | 1.6 | 0.2 | setosa |
32 | 5.4 | 3.4 | 1.5 | 0.4 | setosa |
33 | 5.2 | 4.1 | 1.5 | 0.1 | setosa |
34 | 5.5 | 4.2 | 1.4 | 0.2 | setosa |
35 | 4.9 | 3.1 | 1.5 | 0.2 | setosa |
36 | 5.0 | 3.2 | 1.2 | 0.2 | setosa |
37 | 5.5 | 3.5 | 1.3 | 0.2 | setosa |
38 | 4.9 | 3.6 | 1.4 | 0.1 | setosa |
39 | 4.4 | 3.0 | 1.3 | 0.2 | setosa |
40 | 5.1 | 3.4 | 1.5 | 0.2 | setosa |
41 | 5.0 | 3.5 | 1.3 | 0.3 | setosa |
42 | 4.5 | 2.3 | 1.3 | 0.3 | setosa |
43 | 4.4 | 3.2 | 1.3 | 0.2 | setosa |
44 | 5.0 | 3.5 | 1.6 | 0.6 | setosa |
45 | 5.1 | 3.8 | 1.9 | 0.4 | setosa |
46 | 4.8 | 3.0 | 1.4 | 0.3 | setosa |
47 | 5.1 | 3.8 | 1.6 | 0.2 | setosa |
48 | 4.6 | 3.2 | 1.4 | 0.2 | setosa |
49 | 5.3 | 3.7 | 1.5 | 0.2 | setosa |
50 | 5.0 | 3.3 | 1.4 | 0.2 | setosa |
51 | 7.0 | 3.2 | 4.7 | 1.4 | versicolor |
52 | 6.4 | 3.2 | 4.5 | 1.5 | versicolor |
53 | 6.9 | 3.1 | 4.9 | 1.5 | versicolor |
54 | 5.5 | 2.3 | 4.0 | 1.3 | versicolor |
55 | 6.5 | 2.8 | 4.6 | 1.5 | versicolor |
56 | 5.7 | 2.8 | 4.5 | 1.3 | versicolor |
57 | 6.3 | 3.3 | 4.7 | 1.6 | versicolor |
58 | 4.9 | 2.4 | 3.3 | 1.0 | versicolor |
59 | 6.6 | 2.9 | 4.6 | 1.3 | versicolor |
60 | 5.2 | 2.7 | 3.9 | 1.4 | versicolor |
61 | 5.0 | 2.0 | 3.5 | 1.0 | versicolor |
62 | 5.9 | 3.0 | 4.2 | 1.5 | versicolor |
63 | 6.0 | 2.2 | 4.0 | 1.0 | versicolor |
64 | 6.1 | 2.9 | 4.7 | 1.4 | versicolor |
65 | 5.6 | 2.9 | 3.6 | 1.3 | versicolor |
66 | 6.7 | 3.1 | 4.4 | 1.4 | versicolor |
67 | 5.6 | 3.0 | 4.5 | 1.5 | versicolor |
68 | 5.8 | 2.7 | 4.1 | 1.0 | versicolor |
69 | 6.2 | 2.2 | 4.5 | 1.5 | versicolor |
70 | 5.6 | 2.5 | 3.9 | 1.1 | versicolor |
71 | 5.9 | 3.2 | 4.8 | 1.8 | versicolor |
72 | 6.1 | 2.8 | 4.0 | 1.3 | versicolor |
73 | 6.3 | 2.5 | 4.9 | 1.5 | versicolor |
74 | 6.1 | 2.8 | 4.7 | 1.2 | versicolor |
75 | 6.4 | 2.9 | 4.3 | 1.3 | versicolor |
76 | 6.6 | 3.0 | 4.4 | 1.4 | versicolor |
77 | 6.8 | 2.8 | 4.8 | 1.4 | versicolor |
78 | 6.7 | 3.0 | 5.0 | 1.7 | versicolor |
79 | 6.0 | 2.9 | 4.5 | 1.5 | versicolor |
80 | 5.7 | 2.6 | 3.5 | 1.0 | versicolor |
81 | 5.5 | 2.4 | 3.8 | 1.1 | versicolor |
82 | 5.5 | 2.4 | 3.7 | 1.0 | versicolor |
83 | 5.8 | 2.7 | 3.9 | 1.2 | versicolor |
84 | 6.0 | 2.7 | 5.1 | 1.6 | versicolor |
85 | 5.4 | 3.0 | 4.5 | 1.5 | versicolor |
86 | 6.0 | 3.4 | 4.5 | 1.6 | versicolor |
87 | 6.7 | 3.1 | 4.7 | 1.5 | versicolor |
88 | 6.3 | 2.3 | 4.4 | 1.3 | versicolor |
89 | 5.6 | 3.0 | 4.1 | 1.3 | versicolor |
90 | 5.5 | 2.5 | 4.0 | 1.3 | versicolor |
91 | 5.5 | 2.6 | 4.4 | 1.2 | versicolor |
92 | 6.1 | 3.0 | 4.6 | 1.4 | versicolor |
93 | 5.8 | 2.6 | 4.0 | 1.2 | versicolor |
94 | 5.0 | 2.3 | 3.3 | 1.0 | versicolor |
95 | 5.6 | 2.7 | 4.2 | 1.3 | versicolor |
96 | 5.7 | 3.0 | 4.2 | 1.2 | versicolor |
97 | 5.7 | 2.9 | 4.2 | 1.3 | versicolor |
98 | 6.2 | 2.9 | 4.3 | 1.3 | versicolor |
99 | 5.1 | 2.5 | 3.0 | 1.1 | versicolor |
100 | 5.7 | 2.8 | 4.1 | 1.3 | versicolor |
101 | 6.3 | 3.3 | 6.0 | 2.5 | virginica |
102 | 5.8 | 2.7 | 5.1 | 1.9 | virginica |
103 | 7.1 | 3.0 | 5.9 | 2.1 | virginica |
104 | 6.3 | 2.9 | 5.6 | 1.8 | virginica |
105 | 6.5 | 3.0 | 5.8 | 2.2 | virginica |
106 | 7.6 | 3.0 | 6.6 | 2.1 | virginica |
107 | 4.9 | 2.5 | 4.5 | 1.7 | virginica |
108 | 7.3 | 2.9 | 6.3 | 1.8 | virginica |
109 | 6.7 | 2.5 | 5.8 | 1.8 | virginica |
110 | 7.2 | 3.6 | 6.1 | 2.5 | virginica |
111 | 6.5 | 3.2 | 5.1 | 2.0 | virginica |
112 | 6.4 | 2.7 | 5.3 | 1.9 | virginica |
113 | 6.8 | 3.0 | 5.5 | 2.1 | virginica |
114 | 5.7 | 2.5 | 5.0 | 2.0 | virginica |
115 | 5.8 | 2.8 | 5.1 | 2.4 | virginica |
116 | 6.4 | 3.2 | 5.3 | 2.3 | virginica |
117 | 6.5 | 3.0 | 5.5 | 1.8 | virginica |
118 | 7.7 | 3.8 | 6.7 | 2.2 | virginica |
119 | 7.7 | 2.6 | 6.9 | 2.3 | virginica |
120 | 6.0 | 2.2 | 5.0 | 1.5 | virginica |
121 | 6.9 | 3.2 | 5.7 | 2.3 | virginica |
122 | 5.6 | 2.8 | 4.9 | 2.0 | virginica |
123 | 7.7 | 2.8 | 6.7 | 2.0 | virginica |
124 | 6.3 | 2.7 | 4.9 | 1.8 | virginica |
125 | 6.7 | 3.3 | 5.7 | 2.1 | virginica |
126 | 7.2 | 3.2 | 6.0 | 1.8 | virginica |
127 | 6.2 | 2.8 | 4.8 | 1.8 | virginica |
128 | 6.1 | 3.0 | 4.9 | 1.8 | virginica |
129 | 6.4 | 2.8 | 5.6 | 2.1 | virginica |
130 | 7.2 | 3.0 | 5.8 | 1.6 | virginica |
131 | 7.4 | 2.8 | 6.1 | 1.9 | virginica |
132 | 7.9 | 3.8 | 6.4 | 2.0 | virginica |
133 | 6.4 | 2.8 | 5.6 | 2.2 | virginica |
134 | 6.3 | 2.8 | 5.1 | 1.5 | virginica |
135 | 6.1 | 2.6 | 5.6 | 1.4 | virginica |
136 | 7.7 | 3.0 | 6.1 | 2.3 | virginica |
137 | 6.3 | 3.4 | 5.6 | 2.4 | virginica |
138 | 6.4 | 3.1 | 5.5 | 1.8 | virginica |
139 | 6.0 | 3.0 | 4.8 | 1.8 | virginica |
140 | 6.9 | 3.1 | 5.4 | 2.1 | virginica |
141 | 6.7 | 3.1 | 5.6 | 2.4 | virginica |
142 | 6.9 | 3.1 | 5.1 | 2.3 | virginica |
143 | 5.8 | 2.7 | 5.1 | 1.9 | virginica |
144 | 6.8 | 3.2 | 5.9 | 2.3 | virginica |
145 | 6.7 | 3.3 | 5.7 | 2.5 | virginica |
146 | 6.7 | 3.0 | 5.2 | 2.3 | virginica |
147 | 6.3 | 2.5 | 5.0 | 1.9 | virginica |
148 | 6.5 | 3.0 | 5.2 | 2.0 | virginica |
149 | 6.2 | 3.4 | 5.4 | 2.3 | virginica |
150 | 5.9 | 3.0 | 5.1 | 1.8 | virginica |
import numpy as np
iris = iris.iloc[np.random.permutation(len(iris)), :]
iris
Sepal.Length | Sepal.Width | Petal.Length | Petal.Width | Species | |
---|---|---|---|---|---|
123 | 7.7 | 2.8 | 6.7 | 2.0 | virginica |
138 | 6.4 | 3.1 | 5.5 | 1.8 | virginica |
117 | 6.5 | 3.0 | 5.5 | 1.8 | virginica |
30 | 4.7 | 3.2 | 1.6 | 0.2 | setosa |
134 | 6.3 | 2.8 | 5.1 | 1.5 | virginica |
118 | 7.7 | 3.8 | 6.7 | 2.2 | virginica |
100 | 5.7 | 2.8 | 4.1 | 1.3 | versicolor |
59 | 6.6 | 2.9 | 4.6 | 1.3 | versicolor |
139 | 6.0 | 3.0 | 4.8 | 1.8 | virginica |
124 | 6.3 | 2.7 | 4.9 | 1.8 | virginica |
82 | 5.5 | 2.4 | 3.7 | 1.0 | versicolor |
3 | 4.7 | 3.2 | 1.3 | 0.2 | setosa |
88 | 6.3 | 2.3 | 4.4 | 1.3 | versicolor |
146 | 6.7 | 3.0 | 5.2 | 2.3 | virginica |
114 | 5.7 | 2.5 | 5.0 | 2.0 | virginica |
26 | 5.0 | 3.0 | 1.6 | 0.2 | setosa |
13 | 4.8 | 3.0 | 1.4 | 0.1 | setosa |
142 | 6.9 | 3.1 | 5.1 | 2.3 | virginica |
47 | 5.1 | 3.8 | 1.6 | 0.2 | setosa |
20 | 5.1 | 3.8 | 1.5 | 0.3 | setosa |
54 | 5.5 | 2.3 | 4.0 | 1.3 | versicolor |
62 | 5.9 | 3.0 | 4.2 | 1.5 | versicolor |
102 | 5.8 | 2.7 | 5.1 | 1.9 | virginica |
37 | 5.5 | 3.5 | 1.3 | 0.2 | setosa |
22 | 5.1 | 3.7 | 1.5 | 0.4 | setosa |
85 | 5.4 | 3.0 | 4.5 | 1.5 | versicolor |
119 | 7.7 | 2.6 | 6.9 | 2.3 | virginica |
112 | 6.4 | 2.7 | 5.3 | 1.9 | virginica |
43 | 4.4 | 3.2 | 1.3 | 0.2 | setosa |
132 | 7.9 | 3.8 | 6.4 | 2.0 | virginica |
19 | 5.7 | 3.8 | 1.7 | 0.3 | setosa |
91 | 5.5 | 2.6 | 4.4 | 1.2 | versicolor |
58 | 4.9 | 2.4 | 3.3 | 1.0 | versicolor |
41 | 5.0 | 3.5 | 1.3 | 0.3 | setosa |
69 | 6.2 | 2.2 | 4.5 | 1.5 | versicolor |
90 | 5.5 | 2.5 | 4.0 | 1.3 | versicolor |
143 | 5.8 | 2.7 | 5.1 | 1.9 | virginica |
101 | 6.3 | 3.3 | 6.0 | 2.5 | virginica |
66 | 6.7 | 3.1 | 4.4 | 1.4 | versicolor |
104 | 6.3 | 2.9 | 5.6 | 1.8 | virginica |
76 | 6.6 | 3.0 | 4.4 | 1.4 | versicolor |
113 | 6.8 | 3.0 | 5.5 | 2.1 | virginica |
103 | 7.1 | 3.0 | 5.9 | 2.1 | virginica |
126 | 7.2 | 3.2 | 6.0 | 1.8 | virginica |
65 | 5.6 | 2.9 | 3.6 | 1.3 | versicolor |
136 | 7.7 | 3.0 | 6.1 | 2.3 | virginica |
44 | 5.0 | 3.5 | 1.6 | 0.6 | setosa |
28 | 5.2 | 3.5 | 1.5 | 0.2 | setosa |
64 | 6.1 | 2.9 | 4.7 | 1.4 | versicolor |
51 | 7.0 | 3.2 | 4.7 | 1.4 | versicolor |
99 | 5.1 | 2.5 | 3.0 | 1.1 | versicolor |
70 | 5.6 | 2.5 | 3.9 | 1.1 | versicolor |
39 | 4.4 | 3.0 | 1.3 | 0.2 | setosa |
98 | 6.2 | 2.9 | 4.3 | 1.3 | versicolor |
105 | 6.5 | 3.0 | 5.8 | 2.2 | virginica |
122 | 5.6 | 2.8 | 4.9 | 2.0 | virginica |
115 | 5.8 | 2.8 | 5.1 | 2.4 | virginica |
116 | 6.4 | 3.2 | 5.3 | 2.3 | virginica |
83 | 5.8 | 2.7 | 3.9 | 1.2 | versicolor |
127 | 6.2 | 2.8 | 4.8 | 1.8 | virginica |
93 | 5.8 | 2.6 | 4.0 | 1.2 | versicolor |
34 | 5.5 | 4.2 | 1.4 | 0.2 | setosa |
53 | 6.9 | 3.1 | 4.9 | 1.5 | versicolor |
15 | 5.8 | 4.0 | 1.2 | 0.2 | setosa |
7 | 4.6 | 3.4 | 1.4 | 0.3 | setosa |
40 | 5.1 | 3.4 | 1.5 | 0.2 | setosa |
63 | 6.0 | 2.2 | 4.0 | 1.0 | versicolor |
71 | 5.9 | 3.2 | 4.8 | 1.8 | versicolor |
77 | 6.8 | 2.8 | 4.8 | 1.4 | versicolor |
24 | 5.1 | 3.3 | 1.7 | 0.5 | setosa |
130 | 7.2 | 3.0 | 5.8 | 1.6 | virginica |
125 | 6.7 | 3.3 | 5.7 | 2.1 | virginica |
29 | 5.2 | 3.4 | 1.4 | 0.2 | setosa |
78 | 6.7 | 3.0 | 5.0 | 1.7 | versicolor |
56 | 5.7 | 2.8 | 4.5 | 1.3 | versicolor |
96 | 5.7 | 3.0 | 4.2 | 1.2 | versicolor |
1 | 5.1 | 3.5 | 1.4 | 0.2 | setosa |
135 | 6.1 | 2.6 | 5.6 | 1.4 | virginica |
17 | 5.4 | 3.9 | 1.3 | 0.4 | setosa |
110 | 7.2 | 3.6 | 6.1 | 2.5 | virginica |
79 | 6.0 | 2.9 | 4.5 | 1.5 | versicolor |
60 | 5.2 | 2.7 | 3.9 | 1.4 | versicolor |
6 | 5.4 | 3.9 | 1.7 | 0.4 | setosa |
8 | 5.0 | 3.4 | 1.5 | 0.2 | setosa |
81 | 5.5 | 2.4 | 3.8 | 1.1 | versicolor |
92 | 6.1 | 3.0 | 4.6 | 1.4 | versicolor |
18 | 5.1 | 3.5 | 1.4 | 0.3 | setosa |
140 | 6.9 | 3.1 | 5.4 | 2.1 | virginica |
148 | 6.5 | 3.0 | 5.2 | 2.0 | virginica |
89 | 5.6 | 3.0 | 4.1 | 1.3 | versicolor |
137 | 6.3 | 3.4 | 5.6 | 2.4 | virginica |
38 | 4.9 | 3.6 | 1.4 | 0.1 | setosa |
23 | 4.6 | 3.6 | 1.0 | 0.2 | setosa |
128 | 6.1 | 3.0 | 4.9 | 1.8 | virginica |
144 | 6.8 | 3.2 | 5.9 | 2.3 | virginica |
52 | 6.4 | 3.2 | 4.5 | 1.5 | versicolor |
27 | 5.0 | 3.4 | 1.6 | 0.4 | setosa |
9 | 4.4 | 2.9 | 1.4 | 0.2 | setosa |
97 | 5.7 | 2.9 | 4.2 | 1.3 | versicolor |
108 | 7.3 | 2.9 | 6.3 | 1.8 | virginica |
42 | 4.5 | 2.3 | 1.3 | 0.3 | setosa |
145 | 6.7 | 3.3 | 5.7 | 2.5 | virginica |
74 | 6.1 | 2.8 | 4.7 | 1.2 | versicolor |
48 | 4.6 | 3.2 | 1.4 | 0.2 | setosa |
36 | 5.0 | 3.2 | 1.2 | 0.2 | setosa |
5 | 5.0 | 3.6 | 1.4 | 0.2 | setosa |
45 | 5.1 | 3.8 | 1.9 | 0.4 | setosa |
31 | 4.8 | 3.1 | 1.6 | 0.2 | setosa |
129 | 6.4 | 2.8 | 5.6 | 2.1 | virginica |
84 | 6.0 | 2.7 | 5.1 | 1.6 | versicolor |
141 | 6.7 | 3.1 | 5.6 | 2.4 | virginica |
133 | 6.4 | 2.8 | 5.6 | 2.2 | virginica |
75 | 6.4 | 2.9 | 4.3 | 1.3 | versicolor |
106 | 7.6 | 3.0 | 6.6 | 2.1 | virginica |
12 | 4.8 | 3.4 | 1.6 | 0.2 | setosa |
16 | 5.7 | 4.4 | 1.5 | 0.4 | setosa |
131 | 7.4 | 2.8 | 6.1 | 1.9 | virginica |
147 | 6.3 | 2.5 | 5.0 | 1.9 | virginica |
46 | 4.8 | 3.0 | 1.4 | 0.3 | setosa |
10 | 4.9 | 3.1 | 1.5 | 0.1 | setosa |
4 | 4.6 | 3.1 | 1.5 | 0.2 | setosa |
50 | 5.0 | 3.3 | 1.4 | 0.2 | setosa |
86 | 6.0 | 3.4 | 4.5 | 1.6 | versicolor |
61 | 5.0 | 2.0 | 3.5 | 1.0 | versicolor |
68 | 5.8 | 2.7 | 4.1 | 1.0 | versicolor |
33 | 5.2 | 4.1 | 1.5 | 0.1 | setosa |
35 | 4.9 | 3.1 | 1.5 | 0.2 | setosa |
21 | 5.4 | 3.4 | 1.7 | 0.2 | setosa |
149 | 6.2 | 3.4 | 5.4 | 2.3 | virginica |
87 | 6.7 | 3.1 | 4.7 | 1.5 | versicolor |
107 | 4.9 | 2.5 | 4.5 | 1.7 | virginica |
94 | 5.0 | 2.3 | 3.3 | 1.0 | versicolor |
109 | 6.7 | 2.5 | 5.8 | 1.8 | virginica |
72 | 6.1 | 2.8 | 4.0 | 1.3 | versicolor |
150 | 5.9 | 3.0 | 5.1 | 1.8 | virginica |
55 | 6.5 | 2.8 | 4.6 | 1.5 | versicolor |
11 | 5.4 | 3.7 | 1.5 | 0.2 | setosa |
67 | 5.6 | 3.0 | 4.5 | 1.5 | versicolor |
121 | 6.9 | 3.2 | 5.7 | 2.3 | virginica |
49 | 5.3 | 3.7 | 1.5 | 0.2 | setosa |
25 | 4.8 | 3.4 | 1.9 | 0.2 | setosa |
111 | 6.5 | 3.2 | 5.1 | 2.0 | virginica |
14 | 4.3 | 3.0 | 1.1 | 0.1 | setosa |
80 | 5.7 | 2.6 | 3.5 | 1.0 | versicolor |
95 | 5.6 | 2.7 | 4.2 | 1.3 | versicolor |
57 | 6.3 | 3.3 | 4.7 | 1.6 | versicolor |
32 | 5.4 | 3.4 | 1.5 | 0.4 | setosa |
73 | 6.3 | 2.5 | 4.9 | 1.5 | versicolor |
2 | 4.9 | 3.0 | 1.4 | 0.2 | setosa |
120 | 6.0 | 2.2 | 5.0 | 1.5 | virginica |
iris_data = iris[['Sepal.Length','Sepal.Width', 'Petal.Length', 'Petal.Width']]
iris_data
Sepal.Length | Sepal.Width | Petal.Length | Petal.Width | |
---|---|---|---|---|
123 | 7.7 | 2.8 | 6.7 | 2.0 |
138 | 6.4 | 3.1 | 5.5 | 1.8 |
117 | 6.5 | 3.0 | 5.5 | 1.8 |
30 | 4.7 | 3.2 | 1.6 | 0.2 |
134 | 6.3 | 2.8 | 5.1 | 1.5 |
118 | 7.7 | 3.8 | 6.7 | 2.2 |
100 | 5.7 | 2.8 | 4.1 | 1.3 |
59 | 6.6 | 2.9 | 4.6 | 1.3 |
139 | 6.0 | 3.0 | 4.8 | 1.8 |
124 | 6.3 | 2.7 | 4.9 | 1.8 |
82 | 5.5 | 2.4 | 3.7 | 1.0 |
3 | 4.7 | 3.2 | 1.3 | 0.2 |
88 | 6.3 | 2.3 | 4.4 | 1.3 |
146 | 6.7 | 3.0 | 5.2 | 2.3 |
114 | 5.7 | 2.5 | 5.0 | 2.0 |
26 | 5.0 | 3.0 | 1.6 | 0.2 |
13 | 4.8 | 3.0 | 1.4 | 0.1 |
142 | 6.9 | 3.1 | 5.1 | 2.3 |
47 | 5.1 | 3.8 | 1.6 | 0.2 |
20 | 5.1 | 3.8 | 1.5 | 0.3 |
54 | 5.5 | 2.3 | 4.0 | 1.3 |
62 | 5.9 | 3.0 | 4.2 | 1.5 |
102 | 5.8 | 2.7 | 5.1 | 1.9 |
37 | 5.5 | 3.5 | 1.3 | 0.2 |
22 | 5.1 | 3.7 | 1.5 | 0.4 |
85 | 5.4 | 3.0 | 4.5 | 1.5 |
119 | 7.7 | 2.6 | 6.9 | 2.3 |
112 | 6.4 | 2.7 | 5.3 | 1.9 |
43 | 4.4 | 3.2 | 1.3 | 0.2 |
132 | 7.9 | 3.8 | 6.4 | 2.0 |
19 | 5.7 | 3.8 | 1.7 | 0.3 |
91 | 5.5 | 2.6 | 4.4 | 1.2 |
58 | 4.9 | 2.4 | 3.3 | 1.0 |
41 | 5.0 | 3.5 | 1.3 | 0.3 |
69 | 6.2 | 2.2 | 4.5 | 1.5 |
90 | 5.5 | 2.5 | 4.0 | 1.3 |
143 | 5.8 | 2.7 | 5.1 | 1.9 |
101 | 6.3 | 3.3 | 6.0 | 2.5 |
66 | 6.7 | 3.1 | 4.4 | 1.4 |
104 | 6.3 | 2.9 | 5.6 | 1.8 |
76 | 6.6 | 3.0 | 4.4 | 1.4 |
113 | 6.8 | 3.0 | 5.5 | 2.1 |
103 | 7.1 | 3.0 | 5.9 | 2.1 |
126 | 7.2 | 3.2 | 6.0 | 1.8 |
65 | 5.6 | 2.9 | 3.6 | 1.3 |
136 | 7.7 | 3.0 | 6.1 | 2.3 |
44 | 5.0 | 3.5 | 1.6 | 0.6 |
28 | 5.2 | 3.5 | 1.5 | 0.2 |
64 | 6.1 | 2.9 | 4.7 | 1.4 |
51 | 7.0 | 3.2 | 4.7 | 1.4 |
99 | 5.1 | 2.5 | 3.0 | 1.1 |
70 | 5.6 | 2.5 | 3.9 | 1.1 |
39 | 4.4 | 3.0 | 1.3 | 0.2 |
98 | 6.2 | 2.9 | 4.3 | 1.3 |
105 | 6.5 | 3.0 | 5.8 | 2.2 |
122 | 5.6 | 2.8 | 4.9 | 2.0 |
115 | 5.8 | 2.8 | 5.1 | 2.4 |
116 | 6.4 | 3.2 | 5.3 | 2.3 |
83 | 5.8 | 2.7 | 3.9 | 1.2 |
127 | 6.2 | 2.8 | 4.8 | 1.8 |
93 | 5.8 | 2.6 | 4.0 | 1.2 |
34 | 5.5 | 4.2 | 1.4 | 0.2 |
53 | 6.9 | 3.1 | 4.9 | 1.5 |
15 | 5.8 | 4.0 | 1.2 | 0.2 |
7 | 4.6 | 3.4 | 1.4 | 0.3 |
40 | 5.1 | 3.4 | 1.5 | 0.2 |
63 | 6.0 | 2.2 | 4.0 | 1.0 |
71 | 5.9 | 3.2 | 4.8 | 1.8 |
77 | 6.8 | 2.8 | 4.8 | 1.4 |
24 | 5.1 | 3.3 | 1.7 | 0.5 |
130 | 7.2 | 3.0 | 5.8 | 1.6 |
125 | 6.7 | 3.3 | 5.7 | 2.1 |
29 | 5.2 | 3.4 | 1.4 | 0.2 |
78 | 6.7 | 3.0 | 5.0 | 1.7 |
56 | 5.7 | 2.8 | 4.5 | 1.3 |
96 | 5.7 | 3.0 | 4.2 | 1.2 |
1 | 5.1 | 3.5 | 1.4 | 0.2 |
135 | 6.1 | 2.6 | 5.6 | 1.4 |
17 | 5.4 | 3.9 | 1.3 | 0.4 |
110 | 7.2 | 3.6 | 6.1 | 2.5 |
79 | 6.0 | 2.9 | 4.5 | 1.5 |
60 | 5.2 | 2.7 | 3.9 | 1.4 |
6 | 5.4 | 3.9 | 1.7 | 0.4 |
8 | 5.0 | 3.4 | 1.5 | 0.2 |
81 | 5.5 | 2.4 | 3.8 | 1.1 |
92 | 6.1 | 3.0 | 4.6 | 1.4 |
18 | 5.1 | 3.5 | 1.4 | 0.3 |
140 | 6.9 | 3.1 | 5.4 | 2.1 |
148 | 6.5 | 3.0 | 5.2 | 2.0 |
89 | 5.6 | 3.0 | 4.1 | 1.3 |
137 | 6.3 | 3.4 | 5.6 | 2.4 |
38 | 4.9 | 3.6 | 1.4 | 0.1 |
23 | 4.6 | 3.6 | 1.0 | 0.2 |
128 | 6.1 | 3.0 | 4.9 | 1.8 |
144 | 6.8 | 3.2 | 5.9 | 2.3 |
52 | 6.4 | 3.2 | 4.5 | 1.5 |
27 | 5.0 | 3.4 | 1.6 | 0.4 |
9 | 4.4 | 2.9 | 1.4 | 0.2 |
97 | 5.7 | 2.9 | 4.2 | 1.3 |
108 | 7.3 | 2.9 | 6.3 | 1.8 |
42 | 4.5 | 2.3 | 1.3 | 0.3 |
145 | 6.7 | 3.3 | 5.7 | 2.5 |
74 | 6.1 | 2.8 | 4.7 | 1.2 |
48 | 4.6 | 3.2 | 1.4 | 0.2 |
36 | 5.0 | 3.2 | 1.2 | 0.2 |
5 | 5.0 | 3.6 | 1.4 | 0.2 |
45 | 5.1 | 3.8 | 1.9 | 0.4 |
31 | 4.8 | 3.1 | 1.6 | 0.2 |
129 | 6.4 | 2.8 | 5.6 | 2.1 |
84 | 6.0 | 2.7 | 5.1 | 1.6 |
141 | 6.7 | 3.1 | 5.6 | 2.4 |
133 | 6.4 | 2.8 | 5.6 | 2.2 |
75 | 6.4 | 2.9 | 4.3 | 1.3 |
106 | 7.6 | 3.0 | 6.6 | 2.1 |
12 | 4.8 | 3.4 | 1.6 | 0.2 |
16 | 5.7 | 4.4 | 1.5 | 0.4 |
131 | 7.4 | 2.8 | 6.1 | 1.9 |
147 | 6.3 | 2.5 | 5.0 | 1.9 |
46 | 4.8 | 3.0 | 1.4 | 0.3 |
10 | 4.9 | 3.1 | 1.5 | 0.1 |
4 | 4.6 | 3.1 | 1.5 | 0.2 |
50 | 5.0 | 3.3 | 1.4 | 0.2 |
86 | 6.0 | 3.4 | 4.5 | 1.6 |
61 | 5.0 | 2.0 | 3.5 | 1.0 |
68 | 5.8 | 2.7 | 4.1 | 1.0 |
33 | 5.2 | 4.1 | 1.5 | 0.1 |
35 | 4.9 | 3.1 | 1.5 | 0.2 |
21 | 5.4 | 3.4 | 1.7 | 0.2 |
149 | 6.2 | 3.4 | 5.4 | 2.3 |
87 | 6.7 | 3.1 | 4.7 | 1.5 |
107 | 4.9 | 2.5 | 4.5 | 1.7 |
94 | 5.0 | 2.3 | 3.3 | 1.0 |
109 | 6.7 | 2.5 | 5.8 | 1.8 |
72 | 6.1 | 2.8 | 4.0 | 1.3 |
150 | 5.9 | 3.0 | 5.1 | 1.8 |
55 | 6.5 | 2.8 | 4.6 | 1.5 |
11 | 5.4 | 3.7 | 1.5 | 0.2 |
67 | 5.6 | 3.0 | 4.5 | 1.5 |
121 | 6.9 | 3.2 | 5.7 | 2.3 |
49 | 5.3 | 3.7 | 1.5 | 0.2 |
25 | 4.8 | 3.4 | 1.9 | 0.2 |
111 | 6.5 | 3.2 | 5.1 | 2.0 |
14 | 4.3 | 3.0 | 1.1 | 0.1 |
80 | 5.7 | 2.6 | 3.5 | 1.0 |
95 | 5.6 | 2.7 | 4.2 | 1.3 |
57 | 6.3 | 3.3 | 4.7 | 1.6 |
32 | 5.4 | 3.4 | 1.5 | 0.4 |
73 | 6.3 | 2.5 | 4.9 | 1.5 |
2 | 4.9 | 3.0 | 1.4 | 0.2 |
120 | 6.0 | 2.2 | 5.0 | 1.5 |
iris_species = iris[['Species']]
iris_species
Species | |
---|---|
123 | virginica |
138 | virginica |
117 | virginica |
30 | setosa |
134 | virginica |
118 | virginica |
100 | versicolor |
59 | versicolor |
139 | virginica |
124 | virginica |
82 | versicolor |
3 | setosa |
88 | versicolor |
146 | virginica |
114 | virginica |
26 | setosa |
13 | setosa |
142 | virginica |
47 | setosa |
20 | setosa |
54 | versicolor |
62 | versicolor |
102 | virginica |
37 | setosa |
22 | setosa |
85 | versicolor |
119 | virginica |
112 | virginica |
43 | setosa |
132 | virginica |
19 | setosa |
91 | versicolor |
58 | versicolor |
41 | setosa |
69 | versicolor |
90 | versicolor |
143 | virginica |
101 | virginica |
66 | versicolor |
104 | virginica |
76 | versicolor |
113 | virginica |
103 | virginica |
126 | virginica |
65 | versicolor |
136 | virginica |
44 | setosa |
28 | setosa |
64 | versicolor |
51 | versicolor |
99 | versicolor |
70 | versicolor |
39 | setosa |
98 | versicolor |
105 | virginica |
122 | virginica |
115 | virginica |
116 | virginica |
83 | versicolor |
127 | virginica |
93 | versicolor |
34 | setosa |
53 | versicolor |
15 | setosa |
7 | setosa |
40 | setosa |
63 | versicolor |
71 | versicolor |
77 | versicolor |
24 | setosa |
130 | virginica |
125 | virginica |
29 | setosa |
78 | versicolor |
56 | versicolor |
96 | versicolor |
1 | setosa |
135 | virginica |
17 | setosa |
110 | virginica |
79 | versicolor |
60 | versicolor |
6 | setosa |
8 | setosa |
81 | versicolor |
92 | versicolor |
18 | setosa |
140 | virginica |
148 | virginica |
89 | versicolor |
137 | virginica |
38 | setosa |
23 | setosa |
128 | virginica |
144 | virginica |
52 | versicolor |
27 | setosa |
9 | setosa |
97 | versicolor |
108 | virginica |
42 | setosa |
145 | virginica |
74 | versicolor |
48 | setosa |
36 | setosa |
5 | setosa |
45 | setosa |
31 | setosa |
129 | virginica |
84 | versicolor |
141 | virginica |
133 | virginica |
75 | versicolor |
106 | virginica |
12 | setosa |
16 | setosa |
131 | virginica |
147 | virginica |
46 | setosa |
10 | setosa |
4 | setosa |
50 | setosa |
86 | versicolor |
61 | versicolor |
68 | versicolor |
33 | setosa |
35 | setosa |
21 | setosa |
149 | virginica |
87 | versicolor |
107 | virginica |
94 | versicolor |
109 | virginica |
72 | versicolor |
150 | virginica |
55 | versicolor |
11 | setosa |
67 | versicolor |
121 | virginica |
49 | setosa |
25 | setosa |
111 | virginica |
14 | setosa |
80 | versicolor |
95 | versicolor |
57 | versicolor |
32 | setosa |
73 | versicolor |
2 | setosa |
120 | virginica |
X_train, X_test, y_train, y_test = train_test_split(iris_data, iris_species,
test_size=0.2, random_state=11)
dt_clf = DecisionTreeClassifier(random_state=11)
dt_clf.fit(X_train, y_train)
DecisionTreeClassifier(random_state=11)In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
DecisionTreeClassifier(random_state=11)
pred = dt_clf.predict(X_test)
accuracy_score(y_test, pred)
0.9666666666666667
03 사이킷런의 기반 프레임워크 익히기¶
p95부터
내장된 예제 데이터 세트¶
붓꽃 데이터 세트 생성
from sklearn.datasets import load_iris
iris_data = load_iris()
print(type(iris_data))
<class 'sklearn.utils._bunch.Bunch'>
load_iris() API의 반환 결과는 sklearn.utils.Bunch 클래스
Bunch 클래스는 파이썬 딕셔너리 자료형과 유사
데이터 세트에 내장되어 있는 대부분의 데이터 세트는 이와 딕셔너리 형태의 값을 반환
딕셔너리 형태이므로 load_iris() 데이터 세트의 key 값을 확인해보자.
이들 중 'data', 'target_names', 'feature_names'가 주요한 key 값
keys = iris_data.keys()
print('붓꽃 데이터 세트의 키들: ', keys)
붓꽃 데이터 세트의 키들: dict_keys(['data', 'target', 'frame', 'target_names', 'DESCR', 'feature_names', 'filename', 'data_module'])
데이터 키
: 피처들의 데이터 값
데이터 세트가 딕셔너리 형태이기 때문에 피처 데이터 값을 추출하기 위해서는 데이터 세트, data(또는 데이터 세트['data']) 이용
마찬가지로 target, feature_names, DESCR key가 가리키는 데이터 값의 추출도 동일하게 수행
load_iris()가 반환하는 객체의 키인 feature_names, target_name, data, target이 가리키는 값을 다음 예제 코드에 출력
print('\n feature_names의 type: ', type(iris_data.feature_names))
feature_names의 type: <class 'list'>
print(' feature_names의 shape: ', len(iris_data.feature_names))
feature_names의 shape: 4
print(iris_data.feature_names)
['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)']
print('\n target_names의 type: ', type(iris_data.target_names))
target_names의 type: <class 'numpy.ndarray'>
print(' target_names의 shape:', len(iris_data.target_names))
target_names의 shape: 3
print('\n data의 type: ', type(iris_data.data))
data의 type: <class 'numpy.ndarray'>
print(' data의 shape: ', iris_data.data.shape)
data의 shape: (150, 4)
print(iris_data['data'])
[[5.1 3.5 1.4 0.2] [4.9 3. 1.4 0.2] [4.7 3.2 1.3 0.2] [4.6 3.1 1.5 0.2] [5. 3.6 1.4 0.2] [5.4 3.9 1.7 0.4] [4.6 3.4 1.4 0.3] [5. 3.4 1.5 0.2] [4.4 2.9 1.4 0.2] [4.9 3.1 1.5 0.1] [5.4 3.7 1.5 0.2] [4.8 3.4 1.6 0.2] [4.8 3. 1.4 0.1] [4.3 3. 1.1 0.1] [5.8 4. 1.2 0.2] [5.7 4.4 1.5 0.4] [5.4 3.9 1.3 0.4] [5.1 3.5 1.4 0.3] [5.7 3.8 1.7 0.3] [5.1 3.8 1.5 0.3] [5.4 3.4 1.7 0.2] [5.1 3.7 1.5 0.4] [4.6 3.6 1. 0.2] [5.1 3.3 1.7 0.5] [4.8 3.4 1.9 0.2] [5. 3. 1.6 0.2] [5. 3.4 1.6 0.4] [5.2 3.5 1.5 0.2] [5.2 3.4 1.4 0.2] [4.7 3.2 1.6 0.2] [4.8 3.1 1.6 0.2] [5.4 3.4 1.5 0.4] [5.2 4.1 1.5 0.1] [5.5 4.2 1.4 0.2] [4.9 3.1 1.5 0.2] [5. 3.2 1.2 0.2] [5.5 3.5 1.3 0.2] [4.9 3.6 1.4 0.1] [4.4 3. 1.3 0.2] [5.1 3.4 1.5 0.2] [5. 3.5 1.3 0.3] [4.5 2.3 1.3 0.3] [4.4 3.2 1.3 0.2] [5. 3.5 1.6 0.6] [5.1 3.8 1.9 0.4] [4.8 3. 1.4 0.3] [5.1 3.8 1.6 0.2] [4.6 3.2 1.4 0.2] [5.3 3.7 1.5 0.2] [5. 3.3 1.4 0.2] [7. 3.2 4.7 1.4] [6.4 3.2 4.5 1.5] [6.9 3.1 4.9 1.5] [5.5 2.3 4. 1.3] [6.5 2.8 4.6 1.5] [5.7 2.8 4.5 1.3] [6.3 3.3 4.7 1.6] [4.9 2.4 3.3 1. ] [6.6 2.9 4.6 1.3] [5.2 2.7 3.9 1.4] [5. 2. 3.5 1. ] [5.9 3. 4.2 1.5] [6. 2.2 4. 1. ] [6.1 2.9 4.7 1.4] [5.6 2.9 3.6 1.3] [6.7 3.1 4.4 1.4] [5.6 3. 4.5 1.5] [5.8 2.7 4.1 1. ] [6.2 2.2 4.5 1.5] [5.6 2.5 3.9 1.1] [5.9 3.2 4.8 1.8] [6.1 2.8 4. 1.3] [6.3 2.5 4.9 1.5] [6.1 2.8 4.7 1.2] [6.4 2.9 4.3 1.3] [6.6 3. 4.4 1.4] [6.8 2.8 4.8 1.4] [6.7 3. 5. 1.7] [6. 2.9 4.5 1.5] [5.7 2.6 3.5 1. ] [5.5 2.4 3.8 1.1] [5.5 2.4 3.7 1. ] [5.8 2.7 3.9 1.2] [6. 2.7 5.1 1.6] [5.4 3. 4.5 1.5] [6. 3.4 4.5 1.6] [6.7 3.1 4.7 1.5] [6.3 2.3 4.4 1.3] [5.6 3. 4.1 1.3] [5.5 2.5 4. 1.3] [5.5 2.6 4.4 1.2] [6.1 3. 4.6 1.4] [5.8 2.6 4. 1.2] [5. 2.3 3.3 1. ] [5.6 2.7 4.2 1.3] [5.7 3. 4.2 1.2] [5.7 2.9 4.2 1.3] [6.2 2.9 4.3 1.3] [5.1 2.5 3. 1.1] [5.7 2.8 4.1 1.3] [6.3 3.3 6. 2.5] [5.8 2.7 5.1 1.9] [7.1 3. 5.9 2.1] [6.3 2.9 5.6 1.8] [6.5 3. 5.8 2.2] [7.6 3. 6.6 2.1] [4.9 2.5 4.5 1.7] [7.3 2.9 6.3 1.8] [6.7 2.5 5.8 1.8] [7.2 3.6 6.1 2.5] [6.5 3.2 5.1 2. ] [6.4 2.7 5.3 1.9] [6.8 3. 5.5 2.1] [5.7 2.5 5. 2. ] [5.8 2.8 5.1 2.4] [6.4 3.2 5.3 2.3] [6.5 3. 5.5 1.8] [7.7 3.8 6.7 2.2] [7.7 2.6 6.9 2.3] [6. 2.2 5. 1.5] [6.9 3.2 5.7 2.3] [5.6 2.8 4.9 2. ] [7.7 2.8 6.7 2. ] [6.3 2.7 4.9 1.8] [6.7 3.3 5.7 2.1] [7.2 3.2 6. 1.8] [6.2 2.8 4.8 1.8] [6.1 3. 4.9 1.8] [6.4 2.8 5.6 2.1] [7.2 3. 5.8 1.6] [7.4 2.8 6.1 1.9] [7.9 3.8 6.4 2. ] [6.4 2.8 5.6 2.2] [6.3 2.8 5.1 1.5] [6.1 2.6 5.6 1.4] [7.7 3. 6.1 2.3] [6.3 3.4 5.6 2.4] [6.4 3.1 5.5 1.8] [6. 3. 4.8 1.8] [6.9 3.1 5.4 2.1] [6.7 3.1 5.6 2.4] [6.9 3.1 5.1 2.3] [5.8 2.7 5.1 1.9] [6.8 3.2 5.9 2.3] [6.7 3.3 5.7 2.5] [6.7 3. 5.2 2.3] [6.3 2.5 5. 1.9] [6.5 3. 5.2 2. ] [6.2 3.4 5.4 2.3] [5.9 3. 5.1 1.8]]
print('\n target의 type: ', type(iris_data.target))
target의 type: <class 'numpy.ndarray'>
print(' target의 shape: ', iris_data.target.shape)
target의 shape: (150,)
print(iris_data.target)
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2]
04 Model Selection 모듈 소개¶
학습/테스트 데이터 세트 분리 - train_test_split()¶
먼저 테스트 데이터 세트를 이용하지 않고 학습 데이터 세트로만 학습하고 예측하면 무엇이 문제인지 살펴보자.
다음 예제는 학습과 예측을 동일한 데이터 세트로 수행한 결과
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import accuracy_score
iris = load_iris()
dt_clf = DecisionTreeClassifier()
train_data = iris.data
train_label = iris.target
dt_clf.fit(train_data, train_label)
# 학습 데이터 세트로 예측 수행
pred = dt_clf.predict(train_data)
print('예측 정확도:', accuracy_score(train_label, pred))
예측 정확도: 1.0
정확도가 100%이다.뭔가 이상하다
위의 예측 결과가 100% 정확한 이유는 이미 학습한 학습 데이터 세트를 기반으로 예측했기 때문
따라서 예측을 수행하는 데이터 세트는 학습을 수행한 학습용 데이터 세트가 아닌 전용의 테스트 데이터 세트여야 함
사이킷런의 train_test_split()
를 통해 원본 데이터 세트에서 학습 및 테스트 데이터 세트를 쉽게 분리 가능
train_test_split()
를 이용해 붓꽃 데이터 세트를 학습 및 테스트 데이터 세트로 분리하자
먼저 sklearn.model_selection
모듈에서 train_test_split
을 로드하기
train_test_split()
는 첫 번째 파라미터로 피처 데이터 세트
, 두 번째 파라미터로 레이블 데이터 세트
를 입력받음
그리고 선택적으로 다음 파라미터를 입력받음
test_size¶
전체 데이터에서 테스트 데이터 세트 크기를 얼마로 샘플링할 것인가를 결정
디폴트는 0.25, 즉 25%
train_size¶
전체 데이터에서 학습용 데이터 세트 크기를 얼마로 샘플링할 것인가를 결정
test_size parameter를 통상적으로 사용하기 때문에 train_size는 잘 사용되지 X
shuffle¶
데이터를 분리하기 전에 데이터를 미리 섞을지를 결정
디폴트는 True. 데이터를 분사시켜서 좀 더 효율적인 학습 및 테스트 데이터 세트를 만드는 데 사용
random_state¶
random_state는 호출할 때마다 동일한 학습/테스트용 데이터 세트를 생성하기 위해 주어지는 난수 값
train_test_split()는 호출 시 무작위로 데이터를 분리하므로 random_state를 지정하지 않으면 수행할 때마다 다른 학습/테스트용 데이터 생성
train_test_split()
의 반환값은 튜플 형태
순차적으로 학습용 데이터의 피처 데이터 세트
, 테스트용 데이터의 피처 데이터 세트
, 학습용 데이터의 레이블 데이터 세트
, 테스트용 데이터의 레이블 데이터 세트
반환
붓꽃 데이터 세트를 train_test_split()을 이용해 테스트 데이터 세트를 전체의 30%로, 학습 데이터 세트를 70%로 분리
앞의 예제와는 다르게 random_state=121로 변경해 데이터 세트 변화시키기
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import accuracy_score
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
dt_clf = DecisionTreeClassifier()
iris_data = load_iris()
X_train, X_test, y_train, y_test = train_test_split(iris_data.data, iris_data.target, test_size=0.3, random_state=121)
학습 데이터를 기반으로 DecisionTreeClassifier를 학습하고 이 모델을 이용해 예측 정확도를 측정
dt_clf.fit(X_train, y_train)
DecisionTreeClassifier()In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
DecisionTreeClassifier()
pred = dt_clf.predict(X_test)
print('예측 정확도: {0:.4f}'.format(accuracy_score(y_test, pred)))
예측 정확도: 0.9556
테스트 데이터로 예측을 수행한 결과 정확도가 약 95.56%
붓꽃 데이터는 150개의 데이터로 데이터 양이 크지 않아 전체의 30% 정도인 테스트 데이터는 45개 정도밖에 되지 않으므로 이를 통해 알고리즘의 예측 성능을 판단하기에는 그리 적절하지 X
학습을 위한 데이터의 양을 일정 수준 이상으로 보장하는 것도 중요하지만, *학습된 모델에 대해 다양한 데이터를 기반으로 예측 성능을 평가*해 보는 것도 매우 중요
교차 검증¶
K 폴드 교차 검증¶
사이킷런에서는 K 폴드 교차 검증 프로세스를 구현하기 위해 KFold
와 StratifiedKFold
클래스 제공
먼저 KFold 클래스를 이용해 붓꽃 데이터 세트를 교차 검증하고 예측 정확도를 알아보자.
붓꽃 데이터 세트와 DecisionTreeClassifier를 다시 생성. 그리고 5개의 폴드 세트로 분리하는 KFold 객체 생성
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import accuracy_score
from sklearn.model_selection import KFold
import numpy as np
iris = load_iris()
features = iris.data
label = iris.target
dt_clf = DecisionTreeClassifier(random_state=156)
# 5개의 폴드 세트로 분리하는 KFold 객체와 폴드 세트별 정확도를 담을 리스트 객체 생성
kfold = KFold(n_splits=5)
cv_accuracy = []
print('붓꽃 데이터 세트 크기: ', features.shape[0]) #행
붓꽃 데이터 세트 크기: 150
features.shape #행, 열
(150, 4)
KFold(n_splits=5)로 KFold 객체를 생성했으니 이제 생성된 KFold 객체의 split()
을 호출해 전체 붓꽃 데이터를 5개의 폴드 데이터 세트로 분리
전체 붓꽃 데이터는 모두 150개. 따라서 학습용 데이터 세트
는 이 중 4/5인 120개, 검증 데이터 세트
는 1/5인 30개로 분할
KFold 객체는 split()
을 호출하면 학습용/검증용 데이터로 분할할 수 있는 인덱스를 반환
다음 예제는 5개의 폴드 세트를 생성하는 KFold 객체의 split()
을 호출해 교차 검증 수행 시마다 학습과 검증을 반복해 예측 정확도 측정
그리고 split()
이 어떤 값을 실제로 반환하는지도 확인해 보기 위해 검증 데이터 세트의 인덱스도 추출
features
array([[5.1, 3.5, 1.4, 0.2], [4.9, 3. , 1.4, 0.2], [4.7, 3.2, 1.3, 0.2], [4.6, 3.1, 1.5, 0.2], [5. , 3.6, 1.4, 0.2], [5.4, 3.9, 1.7, 0.4], [4.6, 3.4, 1.4, 0.3], [5. , 3.4, 1.5, 0.2], [4.4, 2.9, 1.4, 0.2], [4.9, 3.1, 1.5, 0.1], [5.4, 3.7, 1.5, 0.2], [4.8, 3.4, 1.6, 0.2], [4.8, 3. , 1.4, 0.1], [4.3, 3. , 1.1, 0.1], [5.8, 4. , 1.2, 0.2], [5.7, 4.4, 1.5, 0.4], [5.4, 3.9, 1.3, 0.4], [5.1, 3.5, 1.4, 0.3], [5.7, 3.8, 1.7, 0.3], [5.1, 3.8, 1.5, 0.3], [5.4, 3.4, 1.7, 0.2], [5.1, 3.7, 1.5, 0.4], [4.6, 3.6, 1. , 0.2], [5.1, 3.3, 1.7, 0.5], [4.8, 3.4, 1.9, 0.2], [5. , 3. , 1.6, 0.2], [5. , 3.4, 1.6, 0.4], [5.2, 3.5, 1.5, 0.2], [5.2, 3.4, 1.4, 0.2], [4.7, 3.2, 1.6, 0.2], [4.8, 3.1, 1.6, 0.2], [5.4, 3.4, 1.5, 0.4], [5.2, 4.1, 1.5, 0.1], [5.5, 4.2, 1.4, 0.2], [4.9, 3.1, 1.5, 0.2], [5. , 3.2, 1.2, 0.2], [5.5, 3.5, 1.3, 0.2], [4.9, 3.6, 1.4, 0.1], [4.4, 3. , 1.3, 0.2], [5.1, 3.4, 1.5, 0.2], [5. , 3.5, 1.3, 0.3], [4.5, 2.3, 1.3, 0.3], [4.4, 3.2, 1.3, 0.2], [5. , 3.5, 1.6, 0.6], [5.1, 3.8, 1.9, 0.4], [4.8, 3. , 1.4, 0.3], [5.1, 3.8, 1.6, 0.2], [4.6, 3.2, 1.4, 0.2], [5.3, 3.7, 1.5, 0.2], [5. , 3.3, 1.4, 0.2], [7. , 3.2, 4.7, 1.4], [6.4, 3.2, 4.5, 1.5], [6.9, 3.1, 4.9, 1.5], [5.5, 2.3, 4. , 1.3], [6.5, 2.8, 4.6, 1.5], [5.7, 2.8, 4.5, 1.3], [6.3, 3.3, 4.7, 1.6], [4.9, 2.4, 3.3, 1. ], [6.6, 2.9, 4.6, 1.3], [5.2, 2.7, 3.9, 1.4], [5. , 2. , 3.5, 1. ], [5.9, 3. , 4.2, 1.5], [6. , 2.2, 4. , 1. ], [6.1, 2.9, 4.7, 1.4], [5.6, 2.9, 3.6, 1.3], [6.7, 3.1, 4.4, 1.4], [5.6, 3. , 4.5, 1.5], [5.8, 2.7, 4.1, 1. ], [6.2, 2.2, 4.5, 1.5], [5.6, 2.5, 3.9, 1.1], [5.9, 3.2, 4.8, 1.8], [6.1, 2.8, 4. , 1.3], [6.3, 2.5, 4.9, 1.5], [6.1, 2.8, 4.7, 1.2], [6.4, 2.9, 4.3, 1.3], [6.6, 3. , 4.4, 1.4], [6.8, 2.8, 4.8, 1.4], [6.7, 3. , 5. , 1.7], [6. , 2.9, 4.5, 1.5], [5.7, 2.6, 3.5, 1. ], [5.5, 2.4, 3.8, 1.1], [5.5, 2.4, 3.7, 1. ], [5.8, 2.7, 3.9, 1.2], [6. , 2.7, 5.1, 1.6], [5.4, 3. , 4.5, 1.5], [6. , 3.4, 4.5, 1.6], [6.7, 3.1, 4.7, 1.5], [6.3, 2.3, 4.4, 1.3], [5.6, 3. , 4.1, 1.3], [5.5, 2.5, 4. , 1.3], [5.5, 2.6, 4.4, 1.2], [6.1, 3. , 4.6, 1.4], [5.8, 2.6, 4. , 1.2], [5. , 2.3, 3.3, 1. ], [5.6, 2.7, 4.2, 1.3], [5.7, 3. , 4.2, 1.2], [5.7, 2.9, 4.2, 1.3], [6.2, 2.9, 4.3, 1.3], [5.1, 2.5, 3. , 1.1], [5.7, 2.8, 4.1, 1.3], [6.3, 3.3, 6. , 2.5], [5.8, 2.7, 5.1, 1.9], [7.1, 3. , 5.9, 2.1], [6.3, 2.9, 5.6, 1.8], [6.5, 3. , 5.8, 2.2], [7.6, 3. , 6.6, 2.1], [4.9, 2.5, 4.5, 1.7], [7.3, 2.9, 6.3, 1.8], [6.7, 2.5, 5.8, 1.8], [7.2, 3.6, 6.1, 2.5], [6.5, 3.2, 5.1, 2. ], [6.4, 2.7, 5.3, 1.9], [6.8, 3. , 5.5, 2.1], [5.7, 2.5, 5. , 2. ], [5.8, 2.8, 5.1, 2.4], [6.4, 3.2, 5.3, 2.3], [6.5, 3. , 5.5, 1.8], [7.7, 3.8, 6.7, 2.2], [7.7, 2.6, 6.9, 2.3], [6. , 2.2, 5. , 1.5], [6.9, 3.2, 5.7, 2.3], [5.6, 2.8, 4.9, 2. ], [7.7, 2.8, 6.7, 2. ], [6.3, 2.7, 4.9, 1.8], [6.7, 3.3, 5.7, 2.1], [7.2, 3.2, 6. , 1.8], [6.2, 2.8, 4.8, 1.8], [6.1, 3. , 4.9, 1.8], [6.4, 2.8, 5.6, 2.1], [7.2, 3. , 5.8, 1.6], [7.4, 2.8, 6.1, 1.9], [7.9, 3.8, 6.4, 2. ], [6.4, 2.8, 5.6, 2.2], [6.3, 2.8, 5.1, 1.5], [6.1, 2.6, 5.6, 1.4], [7.7, 3. , 6.1, 2.3], [6.3, 3.4, 5.6, 2.4], [6.4, 3.1, 5.5, 1.8], [6. , 3. , 4.8, 1.8], [6.9, 3.1, 5.4, 2.1], [6.7, 3.1, 5.6, 2.4], [6.9, 3.1, 5.1, 2.3], [5.8, 2.7, 5.1, 1.9], [6.8, 3.2, 5.9, 2.3], [6.7, 3.3, 5.7, 2.5], [6.7, 3. , 5.2, 2.3], [6.3, 2.5, 5. , 1.9], [6.5, 3. , 5.2, 2. ], [6.2, 3.4, 5.4, 2.3], [5.9, 3. , 5.1, 1.8]])
n_iter = 0
# KFold 객체의 split()을 호출하면 폴드별 학습용, 검증용 테스트의 로우 인덱스를 array로 반환
for train_index, test_index in kfold.split(features):
# kfold.aplit()으로 반환된 인덱스를 이요앻 학습용, 검증용 테스트 데이터 추출\
X_train, X_test = features[train_index], features[test_index]
y_train, y_test = label[train_index], label[test_index]
# 학습 및 예측
dt_clf.fit(X_train, y_train)
pred = dt_clf.predict(X_test)
n_iter += 1
# 반복 시마다 정확도 측정
accuracy = np.round(accuracy_score(y_test, pred), 4)
train_size = X_train.shape[0] #행
test_size = X_test.shape[0] #행
print('\n#{0} 교차 검증 정확도: {1}, 학습 데이터 크기: {2}, 검증 데이터 크기: {3}'
.format(n_iter, accuracy, train_size, test_size))
print('#{0} 검증 세트 인덱스:{1}'.format(n_iter, test_index))
cv_accuracy.append(accuracy)
#개별 iteration별 정확도를 합하여 평균 정확도 계산
print('\n## 평균 검증 정확도: ', np.mean(cv_accuracy))
#1 교차 검증 정확도: 1.0, 학습 데이터 크기: 120, 검증 데이터 크기: 30 #1 검증 세트 인덱스:[ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29] #2 교차 검증 정확도: 0.9667, 학습 데이터 크기: 120, 검증 데이터 크기: 30 #2 검증 세트 인덱스:[30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59] #3 교차 검증 정확도: 0.8667, 학습 데이터 크기: 120, 검증 데이터 크기: 30 #3 검증 세트 인덱스:[60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89] #4 교차 검증 정확도: 0.9333, 학습 데이터 크기: 120, 검증 데이터 크기: 30 #4 검증 세트 인덱스:[ 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119] #5 교차 검증 정확도: 0.7333, 학습 데이터 크기: 120, 검증 데이터 크기: 30 #5 검증 세트 인덱스:[120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149] ## 평균 검증 정확도: 0.9
5번 교차 검증 결과 평균 검증 정확도는 0.9
그리고 교차 검증 시마다 검증 세트의 인덱스가 달라짐
검증 세트의 인덱스를 보면 교차 검증 시마다 split() 함수가 어떻게 인덱스를 할당하는지 알 수 있음
첫 번째 교차 검증에서는 0번 ~29번까지, 두 번째는 30번 ~ 59번, 세 번째는 60번 ~ 89번, 네 번째는 90 ~ 119번, 다섯 번째는 120 ~ 149번으로 각각 30개의 검증 세트 인덱스를 생성했고, 이를 기반으로 검증 세트 추출
Stratified K 폴드¶
Stratified K 폴드
는 불균형한(imbalanced) 분포도를 가진 레이블(결정 클래스) 데이터 집합을 위한 K 폴드 방식
불균형한 분포도를 가진 레이블 데이터 집합
: 특정 레이블 값이 특이하게 많거나 매우 적어서 값의 분포가 한쪽으로 치우치는 것
Stratified K 폴드
는 K 폴드가 레이블 데이터 집합이 원본 데이터 집합의 레이블 분포를 학습 및 테스트 세트에 제대로 분배하지 못하는 경우의 문제 해결
이를 위해 Stratified K 폴드는 원본 데이터의 레이블 분포를 먼저 고려한 뒤 이 분포와 동일하게 학습과 검증 데이터 세트를 분배
먼저 K 폴드가 어떤 문제를 가지고 있는지 확인해 보고 이를 사이킷런의 StratifiedKFold 클래스
를 이용해 개선해보자.
이를 위해 붓꽃 데이터 세트를 간단하게 DataFrame으로 생성하고 테이블 값의 분포도 확인
import pandas as pd
iris = load_iris()
iris_df = pd.DataFrame(data=iris.data, columns = iris.feature_names)
iris_df['label'] = iris.target
iris_df
sepal length (cm) | sepal width (cm) | petal length (cm) | petal width (cm) | label | |
---|---|---|---|---|---|
0 | 5.1 | 3.5 | 1.4 | 0.2 | 0 |
1 | 4.9 | 3.0 | 1.4 | 0.2 | 0 |
2 | 4.7 | 3.2 | 1.3 | 0.2 | 0 |
3 | 4.6 | 3.1 | 1.5 | 0.2 | 0 |
4 | 5.0 | 3.6 | 1.4 | 0.2 | 0 |
... | ... | ... | ... | ... | ... |
145 | 6.7 | 3.0 | 5.2 | 2.3 | 2 |
146 | 6.3 | 2.5 | 5.0 | 1.9 | 2 |
147 | 6.5 | 3.0 | 5.2 | 2.0 | 2 |
148 | 6.2 | 3.4 | 5.4 | 2.3 | 2 |
149 | 5.9 | 3.0 | 5.1 | 1.8 | 2 |
150 rows × 5 columns
iris_df['label'].value_counts()
0 50 1 50 2 50 Name: label, dtype: int64
레이블 값은 모두 50개로 동일. 즉 Setosa 품종, Versicolor 품종, Virginica 품종 모두가 50개
이슈가 발생하는 현상을 도출하기 위해 3개의 폴드 세트를 KFold로 생성하고, 각 교차 검증 시마다 생성되는 학습/검증 레이블 데이터 값의 분포도 확인
kfold = KFold(n_splits=3)
n_iter = 0
for train_index, test_index in kfold.split(iris_df):
n_iter += 1
label_train = iris_df['label'].iloc[train_index]
label_test = iris_df['label'].iloc[test_index]
print('## 교차 검증: {0}'.format(n_iter))
print('학습 레이블 데이터 분포:\n', label_train.value_counts())
print('검증 레이블 데이터 분포:\n', label_test.value_counts())
## 교차 검증: 1 학습 레이블 데이터 분포: 1 50 2 50 Name: label, dtype: int64 검증 레이블 데이터 분포: 0 50 Name: label, dtype: int64 ## 교차 검증: 2 학습 레이블 데이터 분포: 0 50 2 50 Name: label, dtype: int64 검증 레이블 데이터 분포: 1 50 Name: label, dtype: int64 ## 교차 검증: 3 학습 레이블 데이터 분포: 0 50 1 50 Name: label, dtype: int64 검증 레이블 데이터 분포: 2 50 Name: label, dtype: int64
교차 검증 시마다 3개의 폴드 세트로 만들어지는 학습 레이블과 검증 레이블이 완전히 다른 값으로 추출됨
예를 들어 첫 번째 교차 검증에서는 학습 레이블의 1,2 값이 각각 50개가 추출되었고, 검증 레이블의 0값이 50개 추출됨.
학습 레이블은 1,2 밖에 없으므로 0의 경우는 전혀 학습하지 못함
반대로 검증 레이블은 0밖에 없으므로 학습 모델은 절대 0을 예측하지 못함
이런 유형으로 교차 검증 데이터 세트를 분할하면 검증 예측 정확도는 0이 될 수 밖에 없음
StratifiedKFold
는 이렇게 KFold로 분할된 레이블 데이터 세트가 전체 레이블 값의 분포도를 반영하지 못하는 문제를 해결해줌
이번에는 동일한 데이터 분할을 StratifiedKFold로 수행하고 학습/검증 레이블 데이터의 분포도를 확인하자.
StratifiedKFold의 사용법은 KFold 사용법과 거의 비슷
단 하나 큰 차이는 StratifiedKFold는 레이블 데이터 분포도에 따라 학습/검증 데이터를 나누기 때문에 split() 메서드에 인자로 피처 데이터 세트
뿐만 아니라 레이블 데이터 세트
도 반드시 필요
(K 폴드의 경우 레이블 데이터 세트는 split() 메서드의 인자로 입력하지 않아도 무방)
폴드 세트는 3개로 설정
from sklearn.model_selection import StratifiedKFold
skf = StratifiedKFold(n_splits=3)
n_iter = 0
for train_index, test_index in skf.split(iris_df, iris_df['label']):
n_iter += 1
label_train = iris_df['label'].iloc[train_index]
label_test = iris_df['label'].iloc[test_index]
print('## 교차 검증: {0}'.format(n_iter))
print('학습 레이블 데이터 분포:\n', label_train.value_counts())
print('검증 레이블 데이터 분포:\n', label_test.value_counts())
## 교차 검증: 1 학습 레이블 데이터 분포: 2 34 0 33 1 33 Name: label, dtype: int64 검증 레이블 데이터 분포: 0 17 1 17 2 16 Name: label, dtype: int64 ## 교차 검증: 2 학습 레이블 데이터 분포: 1 34 0 33 2 33 Name: label, dtype: int64 검증 레이블 데이터 분포: 0 17 2 17 1 16 Name: label, dtype: int64 ## 교차 검증: 3 학습 레이블 데이터 분포: 0 34 1 33 2 33 Name: label, dtype: int64 검증 레이블 데이터 분포: 1 17 2 17 0 16 Name: label, dtype: int64
출력 결과를 보면 학습 레이블과 검증 레이블 데이터 값의 분포도가 거의 동일하게 할당됨
전체 150개 데이터에서 학습으로 100개, 검증으로 50개가 교차 검증 단계별로 할당됨
첫 번째 교차 검증에서 100개의 학습 레이블
은 0,1,2 값이 각각 34,33,33개로, 레이블값별로 거의 동일하게 할당됐고,
50개의 검증 레이블
역시 0,1,2 값이 각각 17,17,16개로, 레이블값별로 거의 동일하게 할당
이렇게 분할이 되어야 레이블 값 0,1,2를 모두 학습할 수 있고, 이에 기반해 검증 수행가능
StratifiedKFold를 이용해 붓꽃 데이터를 교차 검증해보자
다음 코드는 StratifiedKFold를 이용해 데이터 분리
피처 데이터와 레이블 데이터는 앞의 붓꽃 StratifiedKFold 예제에서 추출한 데이터를 그대로 이용
dt_clf = DecisionTreeClassifier(random_state=156)
skfold = StratifiedKFold(n_splits=3)
n_iter = 0
cv_accuracy = []
# StratifiedKFold의 split() 호출 시 반드시 레이블 데이터 세트도 추가 입력 필요
for train_index, test_index in skfold.split(features, label):
# split()로 반환된 인덱스를 이용해 학습용, 검증용 테스트 데이터 추출
X_train, X_test = features[train_index], features[test_index]
y_train, y_test = label[train_index], label[test_index]
# 학습 및 예측
dt_clf.fit(X_train, y_train)
pred = dt_clf.predict(X_test)
# 반복 시마다 정확도 측정
n_iter += 1
accuracy = np.round(accuracy_score(y_test, pred), 4)
train_size = X_train.shape[0]
test_size = X_test.shape[0]
print('\n#{0} 교차 검증 정확도:{1}, 학습 데이터 크기: {2}, 검증 데이터 크기: {3}'
.format(n_iter, accuracy, train_size, test_size))
print('#{0} 검증 세트 인덱스:{1}'.format(n_iter, test_index))
cv_accuracy.append(accuracy)
# 교차 검증별 정확도 및 평균 정확도 계산
print('\n## 교차 검증별 정확도:', np.round(cv_accuracy, 4))
print('## 평균 검증 정확도:', np.round(np.mean(cv_accuracy), 4))
#1 교차 검증 정확도:0.98, 학습 데이터 크기: 100, 검증 데이터 크기: 50 #1 검증 세트 인덱스:[ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115] #2 교차 검증 정확도:0.94, 학습 데이터 크기: 100, 검증 데이터 크기: 50 #2 검증 세트 인덱스:[ 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132] #3 교차 검증 정확도:0.98, 학습 데이터 크기: 100, 검증 데이터 크기: 50 #3 검증 세트 인덱스:[ 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149] ## 교차 검증별 정확도: [0.98 0.94 0.98] ## 평균 검증 정확도: 0.9667
3개의 Stratified K 폴드로 교차 검증한 결과 평균 검증 정확도가 약 96.67%로 측정됨
다음으로 이러한 교차 검증을 보다 간편하게 제공해주는 사이킷런의 API를 살펴보자
교차 검증을 보다 간편하게 - cross_val_score()¶
사이킷런은 교차 검증을 좀 더 편리하게 수행할 수 있게 해주는 API 제공
대표적인 것이 cross_val_score()
KFold로 데이터를 학습하고 예측하는 코드를 보면
- 폴드 세트 설정
- for 루프에서 반복으로 학습 및 테스트 데이터의 인덱스 추출
- 반복적으로 학습과 예측을 수행하고 예측 성능 반환
cross_val_score()
: 이런 일련의 과정을 한꺼번에 수행해주는 API
cross_val_score() API의 선언 형태
cross_val_score(estimator, X, y=None, scoring=None, cv=None, n_jobs=1, verbose=, fit_params=None, pre_dispatch='2*n_jobs')
이 중 estimator, X, y, scoring, cv가 주요 파라미터
estimator
: 사이킷런의 분류 알고리즘 클래스인 Classifier 또는 회귀 알고리즘 클래스인 Regressor를 의미X
: 피처 데이터 세트y
: 레이블 데이터 세트scoring
: 예측 성능 평가 지표cv
: 교차 검증 폴드 수
cross_val_score() 수행 후 반환 값은 scoring 파라미터로 지정된 성능 지표 측정값을 배열 형태로 반환
cross_val_score의 입력값에 따른 분할
classifier 입력
: Stratified K 폴드 방식으로 레이블값의 분포에 따라 학습/데이터 세트 분할회귀인 경우
: Stratified K 폴드 방식으로 분할 불가능하므로 K 폴드 방식으로 분할
다음 코드에서 cross_val_score()의 자세한 사용법을 살펴보자.
교차 검증 폴드 수는 3, 성능 평가 지표는 정확도인 accuracy로 하자
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import cross_val_score, cross_validate
from sklearn.datasets import load_iris
iris_data = load_iris()
dt_clf = DecisionTreeClassifier(random_state=156)
data = iris_data.data
label = iris_data.target
# 성능 지표는 정확도(accuracy), 교차 검증 세트는 3개
scores = cross_val_score(dt_clf,data, label, scoring='accuracy', cv=3)
print("교차 검증별 정확도: ", np.round(scores, 4))
print("평균 검증 정확도: ", np.round(np.mean(scores), 4))
교차 검증별 정확도: [0.98 0.94 0.98] 평균 검증 정확도: 0.9667
cross_val_score()는 cv로 지정된 횟수만큼 scoring 파라미터로 지정된 평가 지표로 평가 결괏값을 배열로 반환. 그리고 일반적으로 이를 평균해 평가 수치로 사용
cross_val_score() API는 내부에서 Estimator를 학습(fig), 예측(predict), 평가(evaluation)시켜주므로 간단하게 교차 검증 수행 가능
붓꽃 데이터의 cross_val_score() 수행 결과와 앞 예제의 붓꽃 데이터 StratifiedKFold의 수행 결과를 비교해 보면 각 교차 검증별 정확도
와 평균 검증 정확도
가 모두 동일
이는 cross_val_score()가 내부적으로 StratifiedKFold를 이용하기 때문
GridSearchCV - 교차 검증과 최적 하이퍼 파라미터 튜닝을 한 번에¶
사이킷런은 GridSearchCV API를 이용해 Classifier나 Regressor와 같은 알고리즘에 사용되는 하이퍼 파라미터를 순차적으로 입력하면서 편리하게 최적의 파라미터를 도출할 수 잇는 방안 제공
(Grid는 격자라는 뜻으로, 촘촘하게 파라미터를 입력하면서 테스트를 하는 방식)
예를 들어 결정 트리 알고리즘의 여러 하이퍼 파라미터를 순차적으로 변경하면서 최고 성능을 가지는 파라미터 조합을 찾고자 한다면 다음과 같이 파라미터의 집합을 만들고 이를 순차적으로 적용하면서 최적화 수행 가능
grid_parameters = {'max_depth':[1,2,3],
'min_samples_split':[2,3]}
GridSearchCV 클래스의 생성자로 들어가는 주요 파라미터
estimator
: classifier, regressor, pipeline이 사용될 수 있음param_grid
: key + 리스트 값을 가지는 딕셔너리가 주어짐. estimator의 튜닝을 위해 파라미터명과 사용될 여러 파라미터 값을 지정scoring
: 예측 성능을 측정할 평가 방법 지정. 보통은 사이킷런의 성능 평가 지표를 지정하는 문자열(예: 정확도의 경우 'accuracy')로 지정하나 별도의 성능 평가 지표 함수도 지정 가능cv
: 교차 검증을 위해 분할되는 학습/테스트 세트의 개수 지정refit
: 디폴트가 True이며 True로 생성 시 가장 최적의 하이퍼 파라미터를 찾은 뒤 입력된 estimator 객체를 해당 하이퍼 파라미터로 재학습시킴
간단한 예제를 통해서 GridSearchCV API의 사용법을 익혀보자.
결정 트리 알고리즘의 여러 가지 최적화 파라미터를 순차적으로 적용해 붓꽃 데이터를 예측 분석하는 데 GridSearchCV 이용
train_test_split()을 이용해 학습 데이터와 테스트 데이터를 먼저 분리하고 학습 데이터에서 GridSearchCV를 이용해 최적 하이퍼 파라미터 추출.
결정 트리 알고리즘을 구현한 DecisionTreeClassifier의 중요 하이퍼 파라미터인 max_depth와 min_samples_split의 값을 변화시키면서 최적화 진행
테스트할 하이퍼 파라미터 세트는 딕셔너리 형태
하이퍼 파라미터의 명칭
: 문자열 key 값하이퍼 파라미터의 값
: 리스트형
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import GridSearchCV, train_test_split
# 데이터를 로딩하고 학습 데이터와 테스트 데이터 분리
iris_data = load_iris()
X_train, X_test, y_train, y_test = train_test_split(iris_data.data, iris_data.target,
test_size=0.2, random_state=121)
dtree = DecisionTreeClassifier()
## 파라미터를 딕셔너리 형태로 설정
parameters = {'max_depth': [1,2,3], 'min_samples_split': [2,3]}
p.190 ~196참고, p.92
학습 데이터 세트를 GridSearchCV 객체의 fit(학습 데이터 세트) 메서드에 인자로 입력
GridSearchCV 객체의 fit(학습 데이터 세트) 메서드를 수행하면 학습 데이터를 cv에 기술된 폴딩 세트로 분할해 param_grid에 기술된 하이퍼 파라미터를 순차적으로 변경하면서 학습/평가를 수행하고 그 결과를 cv_results_ 속성에 기록
cv_results_는 gridsearchcv의 결과 세트로서 딕셔너리 형태로 key 값과 리스트 형태의 value 값을 가짐
cv_results_를 Pandas의 DataFrame으로 변환하면 내용을 좀 더 쉽게 볼 수 있음
import pandas as pd
# param_grid의 하이퍼 파라미터를 3개의 train, test set fold로 나누어 테스트 수행 설정
### refit=True가 default임. True이면 가장 좋은 파라미터 설정으로 재학습시킴.
grid_dtree = GridSearchCV(dtree, param_grid=parameters, cv=3, refit=True)
# 붓꽃 학습 데이터로 param_grid의 하이퍼 파라미터를 순차적으로 학습/평가
grid_dtree.fit(X_train, y_train)
# GridSearchCV 결과를 추출해 DataFrame으로 변환
scores_df = pd.DataFrame(grid_dtree.cv_results_)
scores_df[['params', 'mean_test_score', 'rank_test_score',
'split0_test_score', 'split1_test_score', 'split2_test_score']]
params | mean_test_score | rank_test_score | split0_test_score | split1_test_score | split2_test_score | |
---|---|---|---|---|---|---|
0 | {'max_depth': 1, 'min_samples_split': 2} | 0.700000 | 5 | 0.700 | 0.7 | 0.70 |
1 | {'max_depth': 1, 'min_samples_split': 3} | 0.700000 | 5 | 0.700 | 0.7 | 0.70 |
2 | {'max_depth': 2, 'min_samples_split': 2} | 0.958333 | 3 | 0.925 | 1.0 | 0.95 |
3 | {'max_depth': 2, 'min_samples_split': 3} | 0.958333 | 3 | 0.925 | 1.0 | 0.95 |
4 | {'max_depth': 3, 'min_samples_split': 2} | 0.975000 | 1 | 0.975 | 1.0 | 0.95 |
5 | {'max_depth': 3, 'min_samples_split': 3} | 0.975000 | 1 | 0.975 | 1.0 | 0.95 |
위의 결과에서 총 6개의 결과를 볼 수 있으며, 이는 하이퍼 파라미터가 max_depth와 min_samples_split을 순차적으로 총 6번 변경하면서 학습 및 평가를 수행했음을 나타냄
위 결과의 'params' 칼럼에는 수행할 때마다 적용된 하이퍼 파라미터값을 가짐
맨 마지막에서 두 번째 행(인덱스 번호 : 4)을 보면 'rank_test_score' 칼럼 값이 1
이는 해당 하이퍼 파라미터의 조합인 max_depth:3, min_samples_split: 2로 평가한 결과 예측 성능이 1위라는 의미
split0_test_score, split1_test_score, split2_test_score는 CV가 3인 경우, 즉 3개의 폴딩 세트에서 각각 테스트한 성능 수치
mean_test_score는 이 세 개의 성능 수치를 평균한 것
주요 칼럼별 의미
params 칼럼
: 수행할 때마다 적용된 개별 하이퍼 파라미터값rank_test_score
: 하이퍼 파라미터별로 성능이 좋은 score 순위. 1이 가장 뛰어난 순위이며 이때의 파라미터가 최적의 하이퍼 파라미터mean_test_score
: 개별 하이퍼 파라미터별로 CV의 폴딩 테스트 세트에 대해 총 수행한 평가 평균값
GridSearchCV 객체의 fit()
을 수행하면 최고 성능을 나타낸 하이퍼 파라미터의 값
과 그때의 평가 결과 값
이 각각 best_params_
, best_score_
속성에 기록됨
(즉, cv_results_의 rank_test_score가 1일 때의 값)
이 속성을 이용해 최적 하이퍼 파라미터의 값과 그때의 정확도를 알아보자
print('GridSearchCV 최적 파라미터:', grid_dtree.best_params_)
GridSearchCV 최적 파라미터: {'max_depth': 3, 'min_samples_split': 2}
print('GridSearchCV 최고 정확도:{0:.4f}'.format(grid_dtree.best_score_))
GridSearchCV 최고 정확도:0.9750
max_depth
가 3, min_samples_split
가 2일 때 검증용 폴드 세트에서 평균 최고 정확도가 97.50%로 측정됨
GridSearchCV 객체의 생성 파라미터로 refit=True가 디폴트
refit=True이면 GridSearchCV가 최적 성능을 나타내는 하이퍼 파라미터로 Estimator를 학습해 best_estimator_로 저장
이미 학습된 best_estimator_를 이용해 앞에서 train_test_split()으로 분리한 테스트 데이터 세트에 대해 예측하고 성능을 평가해 보자
# GridSearchCV의 refit으로 이미 학습된 estimator 반환
estimator = grid_dtree.best_estimator_
#GridSearchCV의 best_estimator_는 이미 최적 학습이 됐으므로 별도 학습이 필요 X
pred = estimator.predict(X_test)
print('테스트 데이터 세트 정확도: {0:.4f}'.format(accuracy_score(y_test, pred)))
테스트 데이터 세트 정확도: 0.9667
별도의 테스트 데이터 세트로 정확도를 측정한 결과 약 96.67%의 결과 도출
일반적으로 학습 데이터를 GridSearchCV를 이용해 최적 하이퍼 파라미터 튜닝을 수행한 뒤에 별도의 테스트 세트에서 이를 평가하는 것이 일반적인 머신러닝 모델 적용 방법
05 데이터 전처리¶
데이터 인코딩¶
사이킷런의 머신러닝 알고리즘은 문자열 값
을 입력값으로 허용하지 X
레이블 인코딩(Label encoding)¶
카테고리 피처를 코드형 숫자 값으로 변환
예를 들어 상품 데이터의 상품 구분이 TV, 냉장고, 전자레인지, 컴퓨터, 선풍기, 믹서 값으로 돼 있다면
TV: 1, 냉장고: 2, 전자레인지: 3, 컴퓨터: 4, 선풍기: 5, 믹서: 6과 같은 숫자형 값으로 변환
약간 주의해야 할 점은 '01', '02'와 같은 코드 값 역시 문자열이므로 1,2와 같은 숫자형 값으로 변환돼야 함
fit(), transform()¶
사이킷런의 레이블 인코딩(Label encoding)은 LabelEncoder 클래스로 구현
LabelEncoder를 객체를 생성한 후 fit()
과 transform()
을 호출해 레이블 인코딩 수행
from sklearn.preprocessing import LabelEncoder
items = ['TV', '냉장고', '전자레인지', '컴퓨터', '선풍기', '선풍기', '믹서', '믹서']
# LabelEncoder를 객체로 생성한 후, fit()과 transform()으로 레이블 인코딩 수행
encoder = LabelEncoder()
encoder.fit(items)
labels = encoder.transform(items)
print('인코딩 변환값:', labels)
인코딩 변환값: [0 1 4 5 3 3 2 2]
classes_¶
TV는 0, 냉장고는 1, 전자레인지는 4, 컴퓨터는 5, 선풍기는 3, 믹서는 2로 변환됨
위 예제는 데이터가 작아서 문자열 값이 어떤 숫자 값으로 인코딩됐는지 직관적으로 알 수 있지만, 많은 경우에 이를 알지 못함
이 경우에는 LabelEncoder 객체의 classes_ 속성값으로 확인
print('인코딩 클래스: ', encoder.classes_)
인코딩 클래스: ['TV' '냉장고' '믹서' '선풍기' '전자레인지' '컴퓨터']
inverse_transform()¶
classes_ 속성은 0번부터 순서대로 변환된 인코딩 값에 대한 원본값을 가지고 있음
따라서 TV가 0, 냉장고 1, 믹서 2, 선풍기 3, 전자레인지 4, 컴퓨터가 5로 인코딩됐음을 알 수 있음
inverse_transform()을 통해 인코딩된 값을 다시 디코딩 가능
print('디코딩 원본값:', encoder.inverse_transform([4,5,2,0,1,1,3,3]))
디코딩 원본값: ['전자레인지' '컴퓨터' '믹서' 'TV' '냉장고' '냉장고' '선풍기' '선풍기']
원-핫 인코딩(One-Hot Encoding)¶
원-핫 인코딩은 피처 값의 유형에 따라 새로운 피처를 추가해 고유 값에 해당하는 칼럼에만 1을 표시하고 나머지 칼럼에는 0을 표시하는 방식
즉, 행
형태로 돼 있는 피처의 고유 값을 열
형태로 차원을 변환한 뒤, 고유 값에 해당하는 칼럼에만 1을 표시하고 나머지 칼럼에는 0을 표시
원-핫 인코딩은 사이킷런에서 OneHotEncoder
클래스로 변환 가능
원-핫 인코딩 사용시 주의할 점
- 입력값으로 2차원 데이터 필요
- OneHotEncoder를 이용해 변환한 값이 희소 행렬(Sparse Matrix) 형태이므로 이를 다시 toarray() 메서드를 이용해 밀집 행렬(Dense Matrix)로 변환해야 함
OneHotEncoder를 이용해 앞의 데이터를 원-핫 인코딩으로 변환해 보자
np.array.reshape 복습
https://yganalyst.github.io/data_handling/memo_5/
from sklearn.preprocessing import OneHotEncoder
import numpy as np
items = ['TV', '냉장고', '전자레인지', '컴퓨터', '선풍기', '선풍기', '믹서', '믹서']
# 2차원 ndarray로 변환
items = np.array(items).reshape(-1,1) # 행(row)의 위치에 -1을 넣고 열의 값을 지정해주면 변환될 배열의 행의 수는 알아서 지정이 된다
# 원-핫 인코딩 적용
oh_encoder = OneHotEncoder()
oh_encoder.fit(items)
oh_labels = oh_encoder.transform(items)
# OneHotEncoder로 변환한 결과는 희소행렬이므로 toarray()를 이용해 밀집 행렬로 변환
print('원-핫 인코딩 데이터')
print(oh_labels.toarray())
print('원-핫 인코딩 데이터 차원')
print(oh_labels.shape)
원-핫 인코딩 데이터 [[1. 0. 0. 0. 0. 0.] [0. 1. 0. 0. 0. 0.] [0. 0. 0. 0. 1. 0.] [0. 0. 0. 0. 0. 1.] [0. 0. 0. 1. 0. 0.] [0. 0. 0. 1. 0. 0.] [0. 0. 1. 0. 0. 0.] [0. 0. 1. 0. 0. 0.]] 원-핫 인코딩 데이터 차원 (8, 6)
8개의 레코드와 1개의 칼럼을 가진 원본 데이터가 8개의 레코드와 6개의 칼럼을 가진 데이터로 변환됨
TV가 0, 냉장고 1, 믹서2, 선풍기 3, 전자레인지 4, 컴퓨터가 5로 인코딩됐으므로
첫 번째 칼럼이 TV, 두 번째 칼럼이 냉장고, 세 번째 칼럼이 믹서, 네 번째 칼럼이 선풍기, 다섯 번째 칼럼이 전자레인지, 여섯 번째 칼럼이 컴퓨터를 나타냄
따라서 원본 데이터의 첫 번째 레코드가 TV이므로 변환된 데이터의 첫 번째 레코드의 첫 번째 칼럼이 1이고, 나머지 칼럼은 모두 0이 됨
이어서 원본 데이터의 두 번째 레코드가 냉장고 이므로 변환된 데이터의 두 번째 레코드의 냉장고에 해당하는 칼럼인 두 번째 칼럼이 1이고, 나머지 칼럼은 모두 0
get_dummies¶
판다스에 있는 원-핫 인코딩을 더 쉽게 지원하는 API
사이킷런의 OnehotEncoder와 다르게 문자열 카테고리 값을 숫자 형으로 변환할 필요 없이 바로 변환
import pandas as pd
df = pd.DataFrame({'item': ['TV', '냉장고', '전자레인지', '컴퓨터', '선풍기', '선풍기', '믹서', '믹서']})
pd.get_dummies(df)
item_TV | item_냉장고 | item_믹서 | item_선풍기 | item_전자레인지 | item_컴퓨터 | |
---|---|---|---|---|---|---|
0 | 1 | 0 | 0 | 0 | 0 | 0 |
1 | 0 | 1 | 0 | 0 | 0 | 0 |
2 | 0 | 0 | 0 | 0 | 1 | 0 |
3 | 0 | 0 | 0 | 0 | 0 | 1 |
4 | 0 | 0 | 0 | 1 | 0 | 0 |
5 | 0 | 0 | 0 | 1 | 0 | 0 |
6 | 0 | 0 | 1 | 0 | 0 | 0 |
7 | 0 | 0 | 1 | 0 | 0 | 0 |
get_dummies()를 이용하면 숫자형 값으로 변환 없이도 바로 변환 가능
피처 스케일링과 정규화¶
StandardScaler¶
StandardScaler는 표준화를 쉽게 지원하기 위한 클래스
즉, 개별 피처를 평균이 0이고, 분산이 1인 값으로 변환
StandardScaler가 어떻게 데이터 값을 변환하는지 데이터 세트로 확인해 보자
from sklearn.datasets import load_iris
import pandas as pd
# 붓꽃 데이터 세트를 로딩하고 DataFrame으로 변환
iris = load_iris()
iris_data = iris.data
iris_df = pd.DataFrame(data=iris_data, columns=iris.feature_names)
iris_df
sepal length (cm) | sepal width (cm) | petal length (cm) | petal width (cm) | |
---|---|---|---|---|
0 | 5.1 | 3.5 | 1.4 | 0.2 |
1 | 4.9 | 3.0 | 1.4 | 0.2 |
2 | 4.7 | 3.2 | 1.3 | 0.2 |
3 | 4.6 | 3.1 | 1.5 | 0.2 |
4 | 5.0 | 3.6 | 1.4 | 0.2 |
... | ... | ... | ... | ... |
145 | 6.7 | 3.0 | 5.2 | 2.3 |
146 | 6.3 | 2.5 | 5.0 | 1.9 |
147 | 6.5 | 3.0 | 5.2 | 2.0 |
148 | 6.2 | 3.4 | 5.4 | 2.3 |
149 | 5.9 | 3.0 | 5.1 | 1.8 |
150 rows × 4 columns
print('feature 들의 평균 값')
print(iris_df.mean())
print('\nfeature 들의 분산 값')
print(iris_df.var())
feature 들의 평균 값 sepal length (cm) 5.843333 sepal width (cm) 3.057333 petal length (cm) 3.758000 petal width (cm) 1.199333 dtype: float64 feature 들의 분산 값 sepal length (cm) 0.685694 sepal width (cm) 0.189979 petal length (cm) 3.116278 petal width (cm) 0.581006 dtype: float64
이제 StandardScaler를 이용해 각 피처를 한 번에 표준화해 변환
StandardScaler 객체를 생성한 후에 fit()과 transform() 메서드에 변환 대상 피처 데이터 세트를 입력하고 호출하면 간단히 변환
transform()을 호출할 때 스케일 변환된 데이터 세트가 넘파이의 ndarray이므로 이를 DataFrame으로 변환해 평균값과 분산 값을 다시 확인해 보자
from sklearn.preprocessing import StandardScaler
# StandardScaler 객체 생성
scaler = StandardScaler()
# StandardScaler로 데이터 세트 변환. fit()과 transform() 호출
scaler.fit(iris_df)
iris_scaled = scaler.transform(iris_df)
# transform() 시 스케일 변환된 데이터 세트가 NumPy ndarray로 반환돼 이를 DataFramme으로 변환
iris_df_scaled = pd.DataFrame(data = iris_scaled, columns=iris.feature_names)
print('feature 들의 평균 값')
print(iris_df_scaled.mean())
print('\nfeature 들의 분산 값')
print(iris_df_scaled.var())
feature 들의 평균 값 sepal length (cm) -1.690315e-15 sepal width (cm) -1.842970e-15 petal length (cm) -1.698641e-15 petal width (cm) -1.409243e-15 dtype: float64 feature 들의 분산 값 sepal length (cm) 1.006711 sepal width (cm) 1.006711 petal length (cm) 1.006711 petal width (cm) 1.006711 dtype: float64
모든 칼럼 값의 평균이 0에 아주 가까운 값으로, 그리고 분산은 1에 아주 가까운 값으로 변환됨
MinMaxScaler¶
MinMaxScaler는 데이터값을 0과 1 사이의 범위 값으로 변환
(음수 값이 있으면 -1에서 1 값으로 변환)
데이터의 분포가 가우시안 분포가 아닐 경우에 Min, Max Scale 적용 가능
다음 예제를 통해 MinMaxScaler가 어떻게 동작하는지 확인
from sklearn.preprocessing import MinMaxScaler
# MinMaxScaler 객체 생성
scaler = MinMaxScaler()
# MinMaxScaler로 데이터 세트 변환. fit()과 transform() 호출
scaler.fit(iris_df)
iris_scaled = scaler.transform(iris_df)
# transform() 시 스케일 변환된 데이터 세트가 NumPy ndarray로 반환돼 이를 DataFrame으로 변환
iris_df_scaled = pd.DataFrame(data=iris_scaled, columns = iris.feature_names)
print('feature들의 최솟값')
print(iris_df_scaled.min())
print('\nfeature들의 최댓값')
print(iris_df_scaled.max())
feature들의 최솟값 sepal length (cm) 0.0 sepal width (cm) 0.0 petal length (cm) 0.0 petal width (cm) 0.0 dtype: float64 feature들의 최댓값 sepal length (cm) 1.0 sepal width (cm) 1.0 petal length (cm) 1.0 petal width (cm) 1.0 dtype: float64
모든 피처에 0에서 1사이의 값으로 변환되는 스케일링이 적용됨
학습 데이터와 테스트 데이터의 스케일링 변환 시 유의점¶
StandardScaler나 MinMaxScaler와 같은 Scaler 객체를 이용해 데이터의 스케일링 변환 시 fit(), transform(), fit_transform() 메서드 이용
fit()
: 데이터 변환을 위한 기준 정보 설정(예를 들어 데이터 세트의 최댓값/최솟값 설정 등)transform()
: 이렇게 설정된 정보를 이용해 데이터 변환fit_transform()
: fit()과 transform()을 한 번에 적용하는 기능 수행
학습 데이터 세트
와 테스트 데이터 세트
에 fit()과 transform()을 적용 시 주의점
- Scaler 객체를 이용해
학습 데이터 세트
로 fit()과 transform()을 적용하면테스트 데이터 세트
로는 다시 fit()을 수행하지 않고
학습 데이터 세트
로 fit()을 수행한 결과를 이용해 transform() 변환을 적용
- 즉,
학습 데이터
로 fit()이 적용된 스케일링 기준 정보를 그대로테스트 데이터
에 적용해야 함 - 그렇지 않고
테스트 데이터
로 다시 새로운 스케일링 기준 정보를 만들게 되면학습 데이터
와테스트 데이터
의 스케일링 기준 정보가 서로 달라지기 때문에 올바를 예측 결과 도출 X
다음 코드를 통해 테스트 데이터에 fit()을 적용할 때 어떠한 문제가 발생하는지 알아보자
먼저 np.arange()를 이용해 학습 데이터를 0부터 10까지, 테스트 데이터를 0부터 5까지 값을 가지는 ndarray로 생성
from sklearn.preprocessing import MinMaxScaler
import numpy as np
# 학습 데이터는 0부터 10까지, 테스트 데이터는 0부터 5까지 값을 가지는 데이터 세트로 생성
# Scaler 클래스의 fit(), transform()은 2차원 이상 데이터만 가능하므로 reshape(-1,1)로 차원 변경
train_array = np.arange(0,11).reshape(-1,1)
test_array = np.arange(0,6).reshape(-1,1)
학습 데이터인 train_array부터 MinMaxScaler를 이용해 변환
학습 데이터는 0부터 10까지 값을 가지는데, 이 데이터에 MinMaxScaler 객체의 fit()을 적용하면 최솟값 0, 최댓값 10이 설정되며 1/10 Scale이 적용됨
이제 transform()을 호출하면 1/10 scale로 학습 데이터를 변환하게 되며 원본 데이터 1은 0.1로 2는 0.2, 그리고 5는 0.5, 10은 1로 변환
# MinMaxScaler 객체에 별도의 feature_range 파라미터 값을 지정하지 않으면 0~1 값으로 변환
scaler = MinMaxScaler()
# fit()하게 되면 train_array 데이터의 최솟값이 0, 최댓값이 10으로 설정
scaler.fit(train_array)
# 1/10 scale로 train_array 데이터 변환함. 원본 10->1로 변환됨
train_scaled = scaler.transform(train_array)
print('원본 train_array 데이터: ', np.round(train_array.reshape(-1), 2))
print('Scale된 train_array 데이터: ', np.round(train_scaled.reshape(-1), 2))
원본 train_array 데이터: [ 0 1 2 3 4 5 6 7 8 9 10] Scale된 train_array 데이터: [0. 0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1. ]
이번에는 테스트 데이터 세트를 변환하는데, fit()을 호출해 스케일링 기준 정보를 다시 적용한 뒤
transform()을 수행한 결과 확인
# MinMaxScaler에 test_array를 fit()하게 되면 원본 데이터의 최솟값이 0, 최댓값이 5로 설정됨
scaler.fit(test_array)
# 1/5 scale로 test_array 데이터 변환. 원본 5 -> 1로 변환
test_scaled = scaler.transform(test_array)
# test_array의 scale 변환 출력
print('원본 test_array 데이터: ', np.round(test_array.reshape(-1), 2))
print('Scale된 test_array 데이터: ' ,np.round(test_scaled.reshape(-1), 2))
원본 test_array 데이터: [0 1 2 3 4 5] Scale된 test_array 데이터: [0. 0.2 0.4 0.6 0.8 1. ]
출력 결과를 확인하면 학습 데이터
와 테스트 데이터
의 스케일링이 맞지 않음을 알 수 있다.
테스트 데이터
의 경우 : 최솟값 0, 최댓값 5이므로 1/5로 스케일링됨. 따라서 원본값 1은 0.2로, 원본값 5는 1로 변환됨.학습 데이터
의 경우 : 최솟값 0, 최댓값 10이므로 1/10로 스케일링됨. 따라서 원본값 2는 0.2로, 원본값 10은 1로 변환
이렇게 되면 학습 데이터
와 테스트 데이터
의 서로 다른 원본값이 동일한 값으로 반환되는 결과를 초래
머신러닝 모델은 학습 데이터를 기반으로 학습되기 때문에 반드시 테스트 데이터는 학습 데이터의 스케일링 기준에 따라야 하며, 테스트 데이터의 1 값은 학습 데이터와 동일하게 0.1 값으로 변환돼야 함
결론
- 테스트 데이터에 다시
fit()
을 적용해서는 안 되며 학습 데이터로 이미fit()
이 적용된 Scaler 객체를 이용해 transform()으로 변환해야 함
다음 코드는 테스트 데이터에 fit()을 호출하지 않고 학습 데이터로 fit()을 수행한 MinMaxScaler 객체의 transform()을 이용해 데이터 변환
출력 결과를 확인해 보면 학습 데이터, 테스트 데이터 모두 1/10 수준으로 스케일링되어 1이 0.1로, 5가 0.5로, 학습 데이터, 테스트 데이터 모두 동일하게 변환됨
scaler = MinMaxScaler()
scaler.fit(train_array)
train_scaled = scaler.transform(train_array)
print('원본 train_array 데이터 : ', np.round(train_array.reshape(-1), 2))
print('Scale된 train_array 데이터 : ', np.round(train_scaled.reshape(-1), 2))
# test_array에 Scale 변환을 할 때는 반드시 fit()을 호출하지 않고 transform()만으로 변환해야 함
test_scaled = scaler.transform(test_array)
print('\n원본 test_array 데이터 : ', np.round(test_array.reshape(-1), 2))
print('Scale된 test_array 데이터 : ', np.round(test_scaled.reshape(-1), 2))
원본 train_array 데이터 : [ 0 1 2 3 4 5 6 7 8 9 10] Scale된 train_array 데이터 : [0. 0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1. ] 원본 test_array 데이터 : [0 1 2 3 4 5] Scale된 test_array 데이터 : [0. 0.1 0.2 0.3 0.4 0.5]
fit_transform()을 적용할 때도 마찬가지.
fit_transform()은 fit()과 transform()을 순차적으로 수행하는 메서드이므로 학습 데이터
에서는 상관없지만 테스트 데이터
에서는 절대 사용 X
학습과 테스트 데이터 세트로 분리하기 전에 먼저 전체 데이터 세트에 스케일링을 적용한 뒤 학습과 테스트 데이터 세트로 분리하는 것이 더 바람직함
학습 데이터와 테스트 데이터의 fit(), transform(), fit_transform()을 이용해 스케일링 변환 시 주의할 점
- 가능하다면 전체 데이터의 스케일링 변환을 적용한 뒤 학습과 테스트 데이터로 분리
- 1이 여의치 않다면 테스트 데이터 변환 시에는 fit()이나 fit_transform()을 적용하지 않고 학습 데이터로 이미 fit()된 Scaler 객체를 이용해 transform()으로 변환
'Python, Jupyter 🐍 > [python]파이썬 머신러닝 완벽 가이드' 카테고리의 다른 글
정밀도와 재현율 (0) | 2023.05.06 |
---|---|
[에러]`load_boston` has been removed from scikit-learn since version 1.2. (0) | 2023.05.02 |
머신러닝 개요 (0) | 2023.04.28 |