-
[Machine Learning] Linear Regression - 단순선형회귀Data Science/Machine Learning & Deep Learning 2021. 2. 4. 00:31
지도학습(Supervised Learning)은 머신 러닝의 한 방법입니다.
데이터의 속성을 알려주고 학습을 하는 것인데, 정답을 알려주면서 학습을 시키는 것이라고 할 수 있습니다.
지도 학습에는 회귀(Regression)와 분류(Classification) 두 가지 방법이 있습니다.
분류는 말 그대로 이것이 A이냐 B이냐 Lable로 데이터를 분류하며, 회귀는 연속적인 값으로 얼마나 될 지를 수로 알려줍니다.
scikit-learn을 이용하여
하나의 feature만 사용하는 단순 선형 회귀모델(Simple Linear Regression)을 만들어 보겠습니다.
먼저 사용한 데이터는 보험료를 예측하기 위한 데이터입니다.
나이, 성별, BMI, 어린이, 흡연자, 지역, 요금 columns로 이루어진 데이터입니다.
www.kaggle.com/sonujha090/insurance-prediction
import pandas as pd df = pd.read_csv('insurance.csv')
이번에는 하나의 변수만을 사용한 단순선형회귀모델을 만들어 보기로 했기 때문에
간단하게 숫자형 변수인 나이(age)와 비만도(bmi) 중
df.corrwith(df['charges'])
보험 청구 비용과 상관 관계가 높은 나이(age)로 특성(feature)을 선택하겠습니다.
이제 특성(feature)은 나이(age), 타겟(target)은 청구비용(charges)로 하는 단순선형회귀모델을 만들어 보겠습니다.
#scikit-learn 라이브러리에서 사용할 예측 모델 클래스 import from sklearn.linear_model import LinearRegression model = LinearRegression() # 예측모델 인스턴스 만들기 feature = ['age'] target = ['charges'] X_train = df[feature] Y_train = df[target] model.fit(X_train, Y_train) # 모델 학습
만들어진 모델에 50살인 사람의 경우로 테스트를 해보겠습니다.
X_test = [[50]] y_pred = model.predict(X_test) >> 16912
이 결과로 50세인 사람의 예상 청구 금액은 16912임을 알 수 있습니다.
다음 전체 데이터로 예측을 하고
X_test = [[x] for x in df['age']] y_pred = model.predict(X_test)
그래프를 그려보면
import matplotlib.pyplot as plt ## train data : blue plt.scatter(X_train, Y_train) ## test data : red plt.scatter(X_test, y_pred, color='red');
seaborn에서의 regplot으로 그려보면
import seaborn as sns sns.regplot(x=df['age'], y=df['charges']);
두 그래프가 유사한 형태를 보이고 있음을 알 수 있습니다.
선형회귀모델의 계수(coefficients)를 통해서는 모델이 target과 feature에서 어떤 관계를 학습했는지를 알수 있습니다.
다음은 모델의 계수와 절편을 알아보는 방법입니다.
model.coef_ # 계수 model.intercept_ # 절편
위에서 만든 모델에 적용해 계수를 알아보았고, 계수는 [278.23411643]입니다.
이를 통해 나이와 보험 청구 금액은 양의 상관관계를 가지며,
나이가 1살 증가할 수록 보험 청구 금액은 약 278달러 증가함을 예측할 수 있습니다.
단순선형회귀모델을 설명하기에 적합하지 않은 데이터를 가지고 온 것 같아 당황스럽지만
다음 포스팅에서 해결해보도록 하겠습니다!
'Data Science > Machine Learning & Deep Learning' 카테고리의 다른 글
[Machine Learning] 모델 해석 - Feature Importance, Permutation Importance, PDP, SHAP (0) 2021.03.03 [Machine Learning] 교차검증 - Cross Validation, RandomizedSearchCV, GridSearchCV (0) 2021.02.21 [Machine Learning] RandomForest 랜덤 포레스트 & Threshold, ROC Curve, AUC (0) 2021.02.20 [Machine Learning] Logistic Regression 로지스틱회귀 (0) 2021.02.13 [Machine Learning] Linear Regression - 다중선형회귀 (0) 2021.02.07