Multiple Regression

Predicting Perch Weight with Multiple Regression

The book 혼자 공부하는 머신러닝+딥러닝 (Self-Study Machine Learning + Deep Learning) provides the multiple regression example below, which applies when a prediction relies on several features (length, height, and thickness, the last stored as the width column).

The steps below walk through the perch weight prediction example from multiple_regression.ipynb.

  1. Load the data with pandas as shown below. The CSV file holds the perch data in three columns: length, height, and width.
import pandas as pd

df = pd.read_csv('https://bit.ly/perch_csv_data')
perch_full = df.to_numpy()
print(perch_full)

The first rows of the loaded array look like this:

[[ 8.4   2.11  1.41]
 [13.7   3.53  2.  ]
 [15.    3.82  2.43]
 [16.2   4.59  2.63]
 [17.4   4.59  2.94]
 [18.    5.22  3.32]
 [18.7   5.2   3.12]
 [19.    5.64  3.05]
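
Before going further, it can help to confirm the column names and the array shape. The following is a minimal sketch, not part of the original notebook: it parses an in-memory CSV built from the first three rows printed above, and it assumes the header is spelled `length,height,width` in that order, as the description in step 1 suggests.

```python
import io

import pandas as pd

# First three rows copied from the printed output above. The header line is an
# assumption based on the column description in step 1.
sample_csv = io.StringIO(
    "length,height,width\n"
    "8.4,2.11,1.41\n"
    "13.7,3.53,2.0\n"
    "15.0,3.82,2.43\n"
)

df = pd.read_csv(sample_csv)
perch_sample = df.to_numpy()

print(df.columns.tolist())  # column names as parsed from the header
print(perch_sample.shape)   # (number of rows, number of features)
```

Running the same two `print` calls on the full `df` and `perch_full` shows whether the downloaded file matches this layout.
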
  2. Prepare the weight data, then split the data into training and test sets.
import numpy as np
perch_weight = np.array(
    [5.9, 32.0, 40.0, 51.5, 70.0, 100.0, 78.0, 80.0, 85.0, 85.0, 
     110.0, 115.0, 125.0, 130.0, 120.0, 120.0, 130.0, 135.0, 110.0, 
     130.0, 150.0, 145.0, 150.0, 170.0, 225.0, 145.0, 188.0, 180.0, 
     197.0, 218.0, 300.0, 260.0, 265.0, 250.0, 250.0, 300.0, 320.0, 
     514.0, 556.0, 840.0, 685.0, 700.0, 700.0, 690.0, 900.0, 650.0, 
     820.0, 850.0, 900.0, 1015.0, 820.0, 1100.0, 1000.0, 1100.0, 
     1000.0, 1000.0]
     )

from sklearn.model_selection import train_test_split

# Split the data into a training set and a test set
train_input, test_input, train_target, test_target = train_test_split(
    perch_full, perch_weight, random_state=42)
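
The split sizes follow from the defaults of `train_test_split`: when neither `test_size` nor `train_size` is given, 25% of the samples go to the test set. The sketch below illustrates this with synthetic arrays (an assumption made so the example runs without downloading the CSV) that have the same shape as the 56 perch samples.

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Synthetic stand-ins with the same shape as perch_full (56 x 3) and
# perch_weight (56,). The values are random and carry no meaning.
rng = np.random.default_rng(0)
X = rng.random((56, 3))
y = rng.random(56)

X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)

print(X_train.shape, X_test.shape)  # (42, 3) (14, 3)
print(y_train.shape, y_test.shape)  # (42,) (14,)
```
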
  3. Generate polynomial features for the multiple regression.
from sklearn.preprocessing import PolynomialFeatures

poly = PolynomialFeatures(degree=2, include_bias=False)  # default degree=2

poly.fit(train_input)

train_poly = poly.transform(train_input)
test_poly = poly.transform(test_input)
  4. Apply regularization with Ridge.
from sklearn.linear_model import Ridge

ridge = Ridge()
ridge.fit(train_poly, train_target)
print(ridge.score(train_poly, train_target))
print(ridge.score(test_poly, test_target))

์ด๋•Œ์˜ ๊ฒฐ๊ณผ๋Š” ์•„๋ž˜์™€ ๊ฐ™์Šต๋‹ˆ๋‹ค.

0.9894023149360563
0.9853164821839827
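
The code above fits Ridge directly on the unscaled polynomial features. Because the Ridge penalty acts on coefficient magnitudes, features on very different scales are penalized unevenly, so a common variant standardizes the features first. The sketch below shows one possible arrangement with `make_pipeline`; it runs on synthetic data (an assumption), and it is not the procedure used to obtain the scores above.

```python
import numpy as np
from sklearn.linear_model import Ridge
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import PolynomialFeatures, StandardScaler

# Synthetic stand-in for the perch data (hypothetical values).
rng = np.random.default_rng(0)
X = rng.uniform(1.0, 40.0, size=(56, 3))
y = X[:, 0] * X[:, 1] + rng.normal(0.0, 1.0, size=56)

# Expand to degree-2 terms, standardize each column, then fit Ridge.
model = make_pipeline(
    PolynomialFeatures(degree=2, include_bias=False),
    StandardScaler(),
    Ridge(alpha=1.0),
)
model.fit(X, y)

scaled_score = model.score(X, y)
print(scaled_score)
```

Keeping the scaler inside the pipeline means it is fitted only on the data passed to `fit`, so test data is transformed with the training statistics.
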

Adjust the regularization strength with alpha.

train_score = []
test_score = []

alpha_list = [0.001, 0.01, 0.1, 1, 10, 100]
for alpha in alpha_list:
    # Create the Ridge model
    ridge = Ridge(alpha=alpha)
    # Train the Ridge model
    ridge.fit(train_poly, train_target)
    # Store the training and test scores
    train_score.append(ridge.score(train_poly, train_target))
    test_score.append(ridge.score(test_poly, test_target))

import matplotlib.pyplot as plt

plt.plot(np.log10(alpha_list), train_score)
plt.plot(np.log10(alpha_list), test_score)
plt.xlabel('log10(alpha)')
plt.ylabel('R^2')
plt.show()

The resulting plot is shown below.

[Figure: training and test R^2 versus log10(alpha) for Ridge]
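
The best alpha can also be read off programmatically rather than from the plot. The sketch below repeats the alpha sweep on synthetic data (an assumption, so the example runs on its own) and picks the alpha with the highest test score using `np.argmax`. This mirrors the plot-based choice; a separate validation set would avoid tuning on the test data.

```python
import numpy as np
from sklearn.linear_model import Ridge
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import PolynomialFeatures

# Synthetic stand-in data (hypothetical), shaped like the perch arrays.
rng = np.random.default_rng(42)
X = rng.uniform(1.0, 10.0, size=(56, 3))
y = 2.0 * X[:, 0] * X[:, 1] + rng.normal(0.0, 1.0, size=56)

X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)
poly = PolynomialFeatures(degree=2, include_bias=False).fit(X_train)
X_train_poly = poly.transform(X_train)
X_test_poly = poly.transform(X_test)

# Score one Ridge model per alpha on the test set.
alpha_list = [0.001, 0.01, 0.1, 1, 10, 100]
test_score = [
    Ridge(alpha=a).fit(X_train_poly, y_train).score(X_test_poly, y_test)
    for a in alpha_list
]

# Index of the highest test score gives the selected alpha.
best_alpha = alpha_list[int(np.argmax(test_score))]
print(best_alpha, max(test_score))
```
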

  5. Apply regularization with Lasso.
from sklearn.linear_model import Lasso

lasso = Lasso()
lasso.fit(train_poly, train_target)
print(lasso.score(train_poly, train_target))
print(lasso.score(test_poly, test_target))

The scores on the training and test sets are as follows.

0.9886724935434511
0.9851391569633392

Try adjusting the regularization strength with alpha.

train_score = []
test_score = []

alpha_list = [0.001, 0.01, 0.1, 1, 10, 100]
for alpha in alpha_list:
    # ๋ผ์˜ ๋ชจ๋ธ์„ ๋งŒ๋“ญ๋‹ˆ๋‹ค
    lasso = Lasso(alpha=alpha, max_iter=10000)
    # ๋ผ์˜ ๋ชจ๋ธ์„ ํ›ˆ๋ จํ•ฉ๋‹ˆ๋‹ค
    lasso.fit(train_poly, train_target)
    # ํ›ˆ๋ จ ์ ์ˆ˜์™€ ํ…Œ์ŠคํŠธ ์ ์ˆ˜๋ฅผ ์ €์žฅํ•ฉ๋‹ˆ๋‹ค
    train_score.append(lasso.score(train_poly, train_target))
    test_score.append(lasso.score(test_poly, test_target))

plt.plot(np.log10(alpha_list), train_score)
plt.plot(np.log10(alpha_list), test_score)
plt.xlabel('log10(alpha)')
plt.ylabel('R^2')
plt.show()

์ด๋•Œ์˜ ๊ฒฐ๊ณผ๋Š” ์•„๋ž˜์™€ ๊ฐ™์Šต๋‹ˆ๋‹ค.

image
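
One practical difference from Ridge is that Lasso can drive some coefficients exactly to zero, which effectively drops those features. The sketch below counts the zeroed coefficients on synthetic data (an assumption); the target depends only on the first feature, and `alpha=0.1` is an arbitrary illustrative value. `max_iter` is raised as in the loop above, since the coordinate descent solver may otherwise stop before converging.

```python
import numpy as np
from sklearn.linear_model import Lasso
from sklearn.preprocessing import PolynomialFeatures

# Synthetic data (hypothetical): only the first feature influences the target.
rng = np.random.default_rng(0)
X = rng.uniform(0.0, 1.0, size=(56, 3))
y = 3.0 * X[:, 0] + rng.normal(0.0, 0.1, size=56)

X_poly = PolynomialFeatures(degree=2, include_bias=False).fit_transform(X)

lasso = Lasso(alpha=0.1, max_iter=10000)
lasso.fit(X_poly, y)

# Coefficients that Lasso set exactly to zero correspond to unused features.
n_zero = int(np.sum(lasso.coef_ == 0))
print(n_zero, "of", lasso.coef_.size, "coefficients are zero")
```

The same `np.sum(lasso.coef_ == 0)` check can be applied to the model fitted on `train_poly` to see how many polynomial features it actually uses.
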

Reference

혼자 공부하는 머신러닝+딥러닝 (Self-Study Machine Learning + Deep Learning)