728x90
선형회귀 표현식
- Y는 종속변수 값
- θ₀ 는 bias
- θ₁,…,θₙ 회귀계수
- x₁, x₂,…,xₙ 는 독립변수
위의 식을 아래와 같이 벡터로 표현이 가능하다.
- θ는 model의 파라미터 벡터
- x는 Xo=1인 입력 벡터
# data 생성
np.random.seed(321)
x_old = np.random.rand(1000, 1)
# intercept항 1추가
y = 2 + 5 * x_old + np.random.rand(1000, 1)
# plot
plt.scatter(x_old, y, s=10)
plt.xlabel('x')
plt.ylabel('y')
plt.ylim(2, 8)
plt.show()
# intercept 추가
x = np.c_[np.ones(x_old.shape[0]), x_old]
비용 함수(Cost Function)
예측값에서 실제값을 뺀 잔차의 제곱의 평균을 사용한다. 제곱을 사용하는 이유는 잔차(예측값 - 실제값)의 차이가 클수록 더 많은 Penalty를 부여하기 위함이다. 추가적으로 1/2을 곱해주는 이유는 제곱을 미분할때 사라지기 위함이다.
residuals = y_pred - y
cost = np.sum(residuals ** 2) / (2 * m)
경사하강법 (Gradient Descent)
비용함수를 각 세타에 대해서 편미분을 진행한 결과는 다음과 같다.
미분에 대해서 잘 모르는 사람을 위해서 부족하지만 θ₁에 대한 편미분 예시입니다.
이해가 안되시는분들은 합성곱 미분에 대해서 구글에 검색후 찾아보세요!!!
결국 아래와 같이 Gradient 벡터를 한번에 표현하게 됩니다.
# h(x)
y_pred = np.dot(x, self.w_)
# h(x) - y
residuals = y_pred - y
# x_T * (h(x) - y)
gradient_vector = np.dot(x.T, residuals)
gradient vector를 가지고 모델 파라미터를 업데이트 해준다.
- a는 learning rate
이 또한 아래와 같이 벡터로 한번에 표현하게 된다.
learning rate를 잘 설정해줘야 하는 이유를 아래 그림을 통해서 보시면 알 수 있습니다.
learning rate가 낮으면 천천히 내려가기 때문에 데이터가 많아지면 많은 시간이 걸릴수도 있고 원하는 global minimum에 도달하지 못한채 학습이 끝나버릴 수 있습니다.
반대로 learning rate가 너무 높으면 계속 왔다갔다 진행하기 때문에 제대로 global minimum에 수렴하지 못할 수 있습니다.
self.w_ -= (self.lr / m) * gradient_vector
최종결과
Code
class LinearRegressionGD:
def __init__(self, lr=0.001, n_iter=30000):
self.lr = lr
self.n_iter = n_iter
def fit(self, x, y):
self.cost_ = []
self.w_ = np.zeros((x.shape[1], 1))
m = x.shape[0]
for i in range(self.n_iter):
y_pred = np.dot(x, self.w_)
residuals = y_pred - y
gradient_vector = np.dot(x.T, residuals)
self.w_ -= (self.lr / m) * gradient_vector
cost = np.sum(residuals ** 2) / (2 * m)
self.cost_.append(cost)
if i % 2000 == 0:
self.draw(x, y)
return self
def predict(self, x):
return np.dot(x, self.w_)
def draw(self, x, y):
y_predicted = np.dot(x, self.w_)
plt.scatter(x[:,1], y, s=8)
plt.xlabel('x')
plt.ylabel('y')
plt.plot(x[:,1], y_predicted, color='b')
plt.show()
728x90
'데이터분석' 카테고리의 다른 글
RMSE, Grid Search python 구현 (0) | 2020.11.12 |
---|---|
Logistic Regression (0) | 2020.11.12 |
seaborn 시각화 python (0) | 2020.11.12 |
데이터 전처리 python (0) | 2020.11.12 |
python 복사 단순 객체복사 vs shallow copy vs deep copy (0) | 2020.10.19 |
댓글