프로그래밍 공부

파이썬 9일차

3452 2025. 5. 19. 17:40

의사결정 나무

트리기반 모델은 Feature를 조건 기반으로 참 거짓으로 나눠 스무고개 하듯이 학습을 이어간다.

그렇기 때문에 결과에 대한 설명이 가능하고 범주와 연속형 수치 모두 예측이 가능하지만 과대적합 발생 확률이 높고 출력변수가 연속형인 회귀 모델에서는 예측력이 떨어질수 있다.

 

import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split

df = pd.read_csv('./국민건강보험공단_건강검진정보_20211229.CSV', encoding='cp949')

sample_df = df[['신장(5Cm단위)', '성별코드', '체중(5Kg 단위)', '음주여부']]

sample_df[:10]

실습 할 데이터

 

 

print('Info 정보 확인')
sample_df.info()

데이터의 결측치를 확인한다.

 

 

sample = sample_df.dropna()

print('Drop 후 Info 정보 확인')
sample.info()

계산에 방해되는 결측치는 제거한다.

 

sample = sample.astype('str')

y = sample.음주여부

X = sample.drop('음주여부', axis=1)

y.value_counts()

원핫 인코딩을 위해 데이터를 오브젝트로 변화시킨 후 데이터의 편향성을 확인한다.

 

X_train, X_valid, y_train, y_valid = train_test_split(X, y, test_size=0.2, shuffle=True, random_state=34)

학습용 데이터와 검증용 데이터를 분리한다.

 

from sklearn.tree import DecisionTreeClassifier

dt = DecisionTreeClassifier(random_state = 1001, max_depth=2)

dt_model = dt.fit(X_train, y_train)

print('학습 정확도 = ', dt_model.score(X_train, y_train))
print('검증 정확도 = ', dt_model.score(X_valid, y_valid))

주어진 데이터로 학습 시켰더니 학습 정확도와 검증 정확도가 거의 동일한것을 알수있다.

따라서 과적합이 적고 데이터 분포가 비슷하며 모델이 적절히 학습 되었다고 볼수있다.

 

import matplotlib.pyplot as plt

!pip install graphviz

import graphviz

from sklearn.tree import export_graphviz

tree_graph = graphviz.Source(export_graphviz(dt_model, feature_names=['height', 'sex', 'weight'], class_names=['X','O'], impurity=True, filled=True))

tree_graph

tree_graph.render('tree_depth5', format='png')

해당 결과를 바탕으로 의사결정나무 그래프를 나타냈다.

 

이 결과는 성별코드가 1.5 보다 큰지를 비교해서 크면 2(여성) 작으면1(남성)으로 분류하고, 그 후 각각 여성과 남성에 대해서 키를 기준으로 키가 일정치보다 크다 작다로 분류해서 남성이고 키가 162.5보다 작은 사람은 전체 799843명중에 88379명이더라 정도의 통계 요약 도구 정도로 쓰인다.

 

 

불순도

불순도는 얼마나 다양한 데이터가 섞여있는지를 나타내는것이다.

 

예를 들어 박스안에 전부 사과로만 채워져있다면 이 박스의 불순도는 0이다.

바나나로만 채워져있어도 불순도는 0이다.

 

하지만 사과랑 바나나가 한 박스에 함께 있다면 이 박스는 순수한 사과박스나 바나나박스가 아니게 된다.

 

이때 얼마나 섞여있는지를 불순도라고 볼수있다.

 

불순도를 구하는 방법은 지니계수(Gini)를 사용한다.

 

1 - {(항목1갯수 / 전체갯수)^2 + (항목2갯수 / 전체갯수)^2 + (항목3갯수 / 전체갯수)^2 ... } 의 방법으로 계산한다.

 

만약 전체 9개의 과일중에 사과가 2개, 바나나 3개, 복숭아가 4개라고 가정했을 때

1 - ({2/9}^2 + {3/9}^2 + {4/9}^2) = 0.6419... 즉 0.642가 된다.

이 과일들의 불순도는 0.642가 되는것이다.

 

def gini(x):
  n = x.sum()
  gini_sum = 0
  for key in x.keys():
    gini_sum = gini_sum + (x[key] / n) * (x[key] / n)
  gini = 1 - gini_sum
  return gini

과일바구니1 = ['사과'] * 9
과일바구니2 = ['사과', '바나나', '사과', '바나나', '바나나', '바나나', '복숭아', '복숭아', '복숭아']
과일바구니3 = ['사과', '바나나', '사과', '바나나', '사과', '복숭아', '복숭아', '사과', '복숭아']

print(round(gini(pd.DataFrame(과일바구니1).value_counts()),3))
print(round(gini(pd.DataFrame(과일바구니2).value_counts()),3))
print(round(gini(pd.DataFrame(과일바구니3).value_counts()),3))

사과만 존재할때는 불순도가 0

사과, 바나나, 복숭아가 각각 2개 4개 3개 존재할때 0.642

사과, 바나나, 복숭아가 각각 4개, 2개, 3개 존재 할 때 0.642

 

즉 불순도는 전체 갯수보다 각 항목간의 비율에 따라 달라진다.

'프로그래밍 공부' 카테고리의 다른 글

파이썬 11일차  (0) 2025.05.21
파이썬 10일차  (0) 2025.05.20
파이썬 8일차  (0) 2025.05.16
파이썬 7일차  (0) 2025.05.15
파이썬 6일차  (0) 2025.05.14