플랫폼 개발팀 기술 블로그

K-최근접 이웃법 (Nearest Neighbor) 본문

카테고리 없음

K-최근접 이웃법 (Nearest Neighbor)

@위너스 2019. 4. 4. 18:15

최근접 이웃법은 새로운 데이터를 입력받았을 때 가장 가까이 있는 것이 무엇이냐를 중심으로 새로운 데이터의 종류를 정해주는 알고리즘이다.

그림에서 보는것 처럼 기존 데이터가 파랑색과 주황색으로 데이터가 분류되었다고 한다면 물음표 데이터가 들어왔을때 이 데이터는 어떤 색상으로 분류가 되어야 할까?

 

최근접 이웃법 말그대로 가까운 것에 따른 분류이기 때문에 주황색으로 분류 할 것이다 . 정말 간단하고 직관적인 알고리즘 이다. 하지만 단순히 가장 가까이에 있는 것으로 분류를 하는것이 문제가 되는 경우도 있다.

 

다음의 경우를 살펴 보자.

 

knn

이번에는 문제가 약간 복잡해 진다. 이론대로 라면 가장 가까운 것은 파란색이기 때문에 파란색으로 분류를 해야 할것 같지만, 주변을 보면 대부분 주황색이 보인다. 왠지.. 파랑색으로 분류하면 안될 것 같다는 생각이 든다....

 

따라서 단순히 가까운것에 따른 분류에서 주변에 있는 몇개의 것들중에 많은 것으로 분류하게 되면 이 문제는 어느정도 해결이 된다.

 

물음표 주변(큰원)에는 4개의 데이터가 있는데 그중 세개가 주황색, 한개가 파랑색 이다. 가장 많은 색인 주황색으로 물음표는 분류가 된다.

 

그렇다면 주변의 범위 즉 주변 데이터의 갯수에 대한 기준이 있어야 할 것 같은데.. 위에서는 4개로 정하였다. 즉 K=4가 되는 최근접 이웃법이 된다. 그래서 KNN 알고리즘 이라고 하는 것이다.

knn

만약 K=1로 한다면 처음 정의했던 것처럼 물음표는 가장 가까운 한개의 요소만 바라보게 될 테니 파랑색으로 분류를 하게 될 것이다. 즉, K의 값에 따라 물음표가 어느 범주로 분류 될 것인지가 결정 된다.

 

그럼 과연 K는 몇이어야 좋은 것일까? 최선의 K값을 선택하는 것은 데이터마다 다르게 접근해야 한다.

실습

간단하게 mglearn 라이브러리에 있는 데이터셋을 이용하여 실습해 보자.

(mglearn : 그래프와 데이터셋을 손쉽게 다루기 위한 샘플데이터 라이브러리)

%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
import mglearn

mglearn.plots.plot_knn_classification(n_neighbors=3)

결과

knn

n_neighbors=3 즉 K=3으로 할 경우 테스트 데이터가 어느 분류로 될지를 보여주는 그래프이다.

앞에서 설명했듯이 가장 가까운 3개의 점들에 의해 분류가 결정이 된다.

여기서는 세개의 별이 테스트로 들어온 데이터고 결과적으로는 두개는 주황색, 한개는 파랑색으로 분류가 되었다.

mglearn.plots.plot_knn_classification(n_neighbors=3)

결과

knn

만약 n_neighbors=1 즉 K=1으로 할 경우 다른 결과를 얻게 된다.

mglearn의 make_forge 함수를 이용하여서 data와 target을 x와 y에 대입하여 해당 데이터로 모델을 만들고 검증해보자.

from sklearn.model_selection import train_test_split
x, y = mglearn.datasets.make_forge()

x_train, x_test, y_train, y_test = train_test_split(x, y, random_state=0)

mglearn.discrete_scatter(x[:, 0], x[:, 1], y)

결과

knn

from sklearn.neighbors import KNeighborsClassifier
knn_classifier = KNeighborsClassifier(n_neighbors=3)

knn_classifier.fit(x_train, y_train)

print("테스트 데이터 예측 : {}".format(knn_classifier.predict(x_test)))
print("테스트 데이터 정확도 : {}".format(knn_classifier.score(x_test, y_test)))
테스트 데이터 예측 : [1 0 1 0 1 0 0]
테스트 데이터 정확도 : 0.8571428571428571

정확도는 약 86%

이 말은 모델이 테스트 데이터셋에 있는 데이터들중 86%를 제대로 맞췄다는 의미이다.

실제 핵심 로직은 sklearn에서 제공해 주고 있기 때문에 코드는 굉장히 간단해 보일 것이다.

K-NN알고리즘은 지도학습의 분류 알고리즘의 하나로 로직이 간단하여 구현하기 쉽다. 하지만 학습 모델이 따로 없고, 전체 데이터를 스캔하여 데이터를 분류하기 때문에 데이터의 양이 많아지면 분류 속도가 현저하게 느려진다. 그래서 K-NN알고리즘을 게으른 알고리즘 이라고도 한다.

Comments