Sunday, June 7, 2009

Python nearest neighbors binary classifier

Posted by Danny Tarlow
Binary Classification 101: Round 2

I wrote last time about how easy it is to build a logistic regression classifier in Python using numpy and scipy.

Maintaining the "Machine Learning 101" theme, here's another dead-simple classifier. Remember, we're still talking about binary classification problems, where you have some examples that are "on" and other examples that are "off." You get as input a training set that has some examples of each class, along with a label saying whether each example is "on" or "off." The goal is to learn a model from the training data so that you can predict the labels of new examples that you haven't seen before and whose labels you don't know.

For example, suppose that you have data describing a bunch of buildings and earthquakes (e.g., the year the building was constructed, the type of material used, the strength of the earthquake), and you know whether each building collapsed ("on") or not ("off") in each past earthquake. Using this data, you'd like to predict whether a given building will collapse in a hypothetical future earthquake.

Classification with Nearest Neighbors
This time we're going to build a nearest neighbors classifier. The idea is even simpler than logistic regression: when you want to predict the label of a point, look at the k most similar points in the training set (its k nearest neighbors), and report their average label as your prediction. For example, suppose we have four examples in a 2D training set:
point     label
[0, 0]    0
[0, 5]    1
[10, 0]   0
[10, 5]   1
Now suppose we want to predict the label of a point at [0, 2], and we're using only the first nearest neighbor. In this case, we'd note that [0, 0] is the closest point, and it has the label 0. So our prediction for [0, 2] with k=1 is a 100% vote for label 0.

Now suppose we choose k=3, which means our predictions are using the three nearest neighbors. In this case, we'd find that [0, 0] (label 0), [0, 5] (label 1), and [10, 0] (label 0) are the closest points. If we want to give a probability as our answer, then a reasonable estimate would be that there's a 2/3 chance of the label being 0, and a 1/3 chance of the label being 1. If we just want to make a prediction of "off" (0) or "on" (1), we'd probably guess "off."
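To make the arithmetic concrete, here is a minimal brute-force sketch in numpy that reproduces the k=1 and k=3 votes for the query [0, 2]. It is not the KDTree-based code later in the post; it simply computes the Euclidean distance to every training point and averages the labels of the k closest. The helper name knn_vote is hypothetical.

```python
import numpy as np

# The four-point 2D training set from the example, labels in {0, 1}
X_train = np.array([[0, 0], [0, 5], [10, 0], [10, 5]], dtype=float)
y_train = np.array([0, 1, 0, 1])


def knn_vote(query, k):
    """Fraction of the k nearest training points that have label 1.

    Brute force: Euclidean distance to every training point, then
    average the labels of the k closest ones.
    """
    dists = np.linalg.norm(X_train - query, axis=1)
    nearest = np.argsort(dists)[:k]
    return y_train[nearest].mean()


query = np.array([0, 2], dtype=float)
print(knn_vote(query, k=1))  # 0.0: every vote goes to label 0
print(knn_vote(query, k=3))  # about 0.333: a 1/3 chance of label 1
```

The distances from [0, 2] are 2, 3, about 10.2 and about 10.4, so the three nearest points are [0, 0], [0, 5] and [10, 0], exactly as in the text.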

If you look back at the logistic regression post, you'll notice that there's nothing preventing us from running the same set of experiments as last time using a nearest neighbors classifier instead of the logistic regression. It's just training set in, classifier out. Once you have your classifier, you can ask it to make predictions about new points.

The one change is that instead of having a regularization parameter, we have a parameter k, which tells us how many neighbors to look at when making a prediction. Here's what happens when you run it for a few different settings of k:

It's kind of interesting, huh? Setting k to 1 causes some overfitting that looks quite similar to unregularized logistic regression. Choosing k to be larger has a somewhat similar effect to increasing the regularization strength.
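A quick way to see why k=1 overfits: every training point is its own nearest neighbor, so with k=1 the classifier reproduces the training labels perfectly, noise included. The sketch below assumes hypothetical 2D data (not the post's SyntheticClassifierData) whose labels are +1 when x0 + x1 > 0, with roughly 20% of labels flipped, and reports training-set accuracy for the same values of k used in the figure.

```python
import numpy as np
from scipy.spatial import KDTree

rng = np.random.default_rng(0)

# Hypothetical 2D data: the "true" label is +1 when x0 + x1 > 0,
# then roughly 20% of the labels are flipped to simulate noise.
X = rng.normal(size=(200, 2))
y = np.where(X[:, 0] + X[:, 1] > 0, 1, -1)
y[rng.random(200) < 0.2] *= -1

tree = KDTree(X)


def predict(points, k):
    """Fraction of each point's k nearest training neighbors labeled +1."""
    _, idx = tree.query(points, k=k)
    idx = np.asarray(idx).reshape(len(points), -1)  # k=1 returns a flat array
    return (0.5 + 0.5 * y[idx]).mean(axis=1)


# Training-set accuracy for each k. With k=1 every training point is its
# own nearest neighbor (distance 0), so even the flipped labels are
# "predicted" perfectly: that is the overfitting.
train_acc = {k: np.mean((predict(X, k) > 0.5) == (y == 1)) for k in [1, 3, 5, 10]}
for k, acc in train_acc.items():
    print("k=%d  training accuracy %.2f" % (k, acc))
```

Larger k lets the majority of a point's neighbors outvote a noisy label, which is why it behaves somewhat like stronger regularization.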

The Code
And here's the code that implements the classifier and produces the figure. For those of you who thought this was all super obvious, take a look at how easy scipy's KDTree class makes this. I almost feel bad taking any credit at all for writing this code.

Also, if you're George, hopefully you'll be happy to see that I took some of your design suggestions into account when deciding how to structure the interface.

from scipy.spatial import KDTree
import numpy as np

# SyntheticClassifierData comes from the logistic regression post
# (synthetic_classifier_data.py).
from synthetic_classifier_data import SyntheticClassifierData


class NearestNeighborBinaryClassifier(object):
    def __init__(self, x_train, y_train):
        self.set_data(x_train, y_train)

    def set_data(self, x_train, y_train):
        """y_train entries should be either -1 or 1."""
        self.x_train = x_train
        self.y_train = y_train
        self.n = y_train.shape[0]
        # Build the kd-tree once; every prediction reuses it.
        self.kd_tree = KDTree(self.x_train)

    def predict(self, X, k=3):
        """For each row of X, return the fraction of its k nearest
        training neighbors that have label +1 (a value in [0, 1])."""
        predictions = np.zeros(X.shape[0])
        for i in range(X.shape[0]):
            # Let the kd-tree do all the real work
            d_i, n_i = self.kd_tree.query(X[i, :], k=k)
            # Map labels from {-1, 1} to {0, 1} and average them
            predictions[i] = np.mean(.5 + .5 * self.y_train[n_i])
        return predictions


if __name__ == "__main__":
    import matplotlib.pyplot as plt

    # Create 20 dimensional data set with 25 points
    data = SyntheticClassifierData(25, 20)
    nn = NearestNeighborBinaryClassifier(data.X_train, data.Y_train)

    # Run for a variety of settings of k
    ks = [1, 3, 5, 10]
    for j, k in enumerate(ks):
        # Predict the label of each (training and test) point given its
        # k nearest neighbors in the training set
        hat_y_train = nn.predict(data.X_train, k=k)
        hat_y_test = nn.predict(data.X_test, k=k)

        # Left column: true training labels (blue o) vs. predictions (red x)
        plt.subplot(len(ks), 2, 2*j + 1)
        plt.plot(np.arange(data.X_train.shape[0]), .5 + .5 * data.Y_train, 'bo')
        plt.plot(np.arange(data.X_train.shape[0]), hat_y_train, 'rx')
        plt.ylim([-.1, 1.1])
        plt.ylabel("K=%s" % k)
        if j == 0:
            plt.title("Training set reconstructions")

        # Right column: true test labels (yellow o) vs. predictions (red x)
        plt.subplot(len(ks), 2, 2*j + 2)
        plt.plot(np.arange(data.X_test.shape[0]), .5 + .5 * data.Y_test, 'yo')
        plt.plot(np.arange(data.X_test.shape[0]), hat_y_test, 'rx')
        plt.ylim([-.1, 1.1])
        if j == 0:
            plt.title("Test set predictions")

    plt.show()
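If you'd rather avoid the Python-level loop in predict, scipy's KDTree.query also accepts a whole (m, d) array of query points and returns an (m, k) array of neighbor indices (squeezed to shape (m,) when k=1). The sketch below is a hypothetical standalone helper, predict_batched, that computes the same quantity as the predict method above; it assumes y_train is a 1-D array of -1/1 labels.

```python
import numpy as np
from scipy.spatial import KDTree


def predict_batched(kd_tree, y_train, X, k=3):
    """Batched equivalent of NearestNeighborBinaryClassifier.predict.

    kd_tree: a scipy.spatial.KDTree built on the training inputs.
    y_train: 1-D array of training labels in {-1, 1}.
    X: (m, d) array of query points.
    Returns an (m,) array: fraction of the k nearest neighbors labeled +1.
    """
    _, idx = kd_tree.query(X, k=k)
    # With k=1 scipy returns shape (m,); make it (m, 1) so the mean works
    idx = np.asarray(idx).reshape(X.shape[0], -1)
    return (0.5 + 0.5 * y_train[idx]).mean(axis=1)


# Usage on the four-point example, with labels converted to {-1, 1}
x_train = np.array([[0, 0], [0, 5], [10, 0], [10, 5]], dtype=float)
y_train = np.array([-1, 1, -1, 1])
tree = KDTree(x_train)
result = predict_batched(tree, y_train, np.array([[0.0, 2.0]]), k=3)
print(result)
```

One query call per batch keeps the per-point loop inside compiled code, which matters once the test set gets large.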



George said...

Wow, I had no idea scipy had a kd tree in it. Awesome! Too bad kd trees aren't really useful in high dimensions. And yes I am quite happy you used a few of my design suggestions.

My own nearest neighbor code isn't too useful at the moment because it doesn't do prediction; it just finds nearest neighbors. I would have to make a KNN classifier module that used my core nearest neighbor search algorithms. In low dimensions my code still doesn't really shine, but I have my own approximate nearest neighbor techniques that seem to work quite well in 100+ dimensions. Their primary downsides are quadratic (O(N^2 d)) preprocessing time and the uncontrolled approximations they make, but one of them achieves O(d + log N) query time and the other achieves O(d) query time, where d is the dimensionality of the space and N is the number of points in the training set. They basically work by building a graph on the training set and doing best-first search on the graph to find candidate nearest neighbors for a query point, expanding only a constant number of nodes during the search. I have another algorithm I want to try that would get preprocessing down to O(N d) while keeping my best query time, although that sweeps some things under the rug: some constant factors may have to get worse for it to reach reasonable performance.
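George doesn't spell out his algorithms, so the following is only a generic illustration of the idea he describes, not his method: build a neighbor graph over the training set by brute force (the O(N^2 d) preprocessing), then answer a query with best-first search over that graph, expanding at most a fixed number of nodes. The function names and the parameters m and max_expansions are hypothetical. Because the search budget is bounded, the answer is approximate.

```python
import heapq

import numpy as np


def build_knn_graph(X, m=5):
    """Connect each training point to its m nearest other points.

    Brute-force pairwise distances, so O(N^2 d) preprocessing.
    """
    sq = np.sum(X ** 2, axis=1)
    d2 = sq[:, None] + sq[None, :] - 2 * X @ X.T
    np.fill_diagonal(d2, np.inf)  # exclude self-edges
    return np.argsort(d2, axis=1)[:, :m]


def graph_search(X, graph, q, start=0, max_expansions=20):
    """Best-first search from `start`, expanding at most max_expansions
    nodes. Returns the index of the closest training point seen, which
    is an approximate (not guaranteed exact) nearest neighbor of q."""
    def dist(i):
        return float(np.linalg.norm(X[i] - q))

    visited = {start}
    best_d, best_i = dist(start), start
    frontier = [(best_d, start)]
    for _ in range(max_expansions):
        if not frontier:
            break
        _, i = heapq.heappop(frontier)  # expand the closest unexpanded node
        for j in graph[i]:
            j = int(j)
            if j in visited:
                continue
            visited.add(j)
            dj = dist(j)
            if dj < best_d:
                best_d, best_i = dj, j
            heapq.heappush(frontier, (dj, j))
    return best_i


# Usage: ten points on a line, each linked to its two nearest neighbors
X = np.arange(10, dtype=float).reshape(-1, 1)
G = build_knn_graph(X, m=2)
print(graph_search(X, G, np.array([7.2]), start=0))                    # enough budget
print(graph_search(X, G, np.array([7.2]), start=0, max_expansions=2))  # tiny budget
```

With a generous budget the search walks along the line to point 7; with only two expansions it stops early at point 3, which is the kind of uncontrolled approximation George mentions.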

KNN has always intrigued me because of how intuitive it is, and because the naive algorithm uses essentially no training time but substantial test time. This is in sharp contrast to most of the machine learning algorithms I work with, which have really long training times and very fast test times.

Danny Tarlow said...

If you're working with data that has high dimension but not necessarily high intrinsic dimension (e.g., it lies on some relatively low dimensional manifold), I find this to be a nice line of attack (I saw the poster a while back but have never done anything with it myself):

From the intro:
The curse of dimensionality has traditionally been the bane of nonparametric statistics, as reflected for instance in convergence rates that are exponentially slow in dimension. An exciting way out of this impasse is the recent realization by the machine learning and statistics communities that in many real world problems the high dimensionality of the data is only superficial and does not represent the true complexity of the problem. In such cases data of low intrinsic dimension is embedded in a space of high extrinsic dimension.

For example, consider the representation of human motion generated by a motion capture system. Such systems typically track marks located on a tight-fitting body suit. The number of markers, say N, is set sufficiently large in order to get dense coverage of the body. A posture is represented by a (3N)-dimensional vector that gives the 3D location of each of the N marks. However, despite this seeming high dimensionality, the number of degrees of freedom is relatively small, corresponding to the dozen-or-so joint angles in the body.

Will Dwinnell said...

Yes, k-NN can be very useful. I'd make the following points:

1. Choosing an appropriate value for k is critical. Fortunately, this is easily accomplished by checking a variety of values and selecting the one which performs best on validation data.

2. Implementation can be sticky. Many systems easily handle equations and logic for things like discriminant classifiers, logistic regression or induced decision trees, but k-NN requires that the complete historical data set be joined to the query data.

3. Poor selection of independent variables can be problematic. While most other machine learning techniques can (to some extent) ignore redundant and useless candidate independent variables, k-NN cannot.
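Will's first point can be sketched as a simple held-out validation loop: fit on one part of the data, score each candidate k on the rest, and keep the best. This assumes labels in {-1, 1} as in the post's classifier; the function name choose_k, the candidate list and the synthetic data are hypothetical, and cross validation would average this score over several splits instead of one.

```python
import numpy as np
from scipy.spatial import KDTree


def choose_k(x_train, y_train, x_val, y_val, ks=(1, 3, 5, 10)):
    """Return (best_k, accuracies): the k with the highest accuracy on
    held-out validation data (ties go to the smallest k)."""
    tree = KDTree(x_train)  # build once, reuse for every k
    acc = {}
    for k in ks:
        _, idx = tree.query(x_val, k=k)
        idx = np.asarray(idx).reshape(len(x_val), -1)  # k=1 gives a flat array
        frac_pos = (0.5 + 0.5 * y_train[idx]).mean(axis=1)
        pred = np.where(frac_pos > 0.5, 1, -1)
        acc[k] = np.mean(pred == y_val)
    best = max(acc, key=acc.get)
    return best, acc


# Usage on hypothetical noisy data: first 200 points train, last 100 validate
rng = np.random.default_rng(1)
X = rng.normal(size=(300, 2))
y = np.where(X[:, 0] > 0, 1, -1)
y[rng.random(300) < 0.15] *= -1  # flip about 15% of the labels
best_k, acc = choose_k(X[:200], y[:200], X[200:], y[200:])
print(best_k, acc)
```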

Danny Tarlow said...

All are good points, Will. Thanks.

I will definitely address #1 at some point in the near future -- talking about cross validation seems pretty important to round these two posts out.

#3 could be interesting too -- perhaps boosting or L1 regularization would be worth mentioning at some point.

Joseph Turian said...

FLANN (Fast Library for Approximate Nearest Neighbors) by Marius Muja at UBC has Python bindings. It can automatically pick hyperparameters and choose between KD-tree and K-means in a data-dependent way.

Connelly Barnes said...

Mount and Arya's ANN [1] is another popular library for doing kNN queries in high dimensions. A common simple approach, if we are finding the nearest feature under L2 distance, is to use PCA to reduce the dimensionality of the features (sometimes one keeps enough components to capture X% of the variance, e.g. 95%, which chooses the dimension parameter automatically), so that at most (100 - X)% error is introduced in the kNN query. Likewise, an error parameter epsilon can be introduced in the tree search.
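A minimal sketch of the PCA approach Connelly describes, using numpy's SVD and scipy's KDTree. The helper name pca_projection and the synthetic data (100-dimensional points lying near a 3-dimensional subspace) are hypothetical. The number of components is chosen automatically as the smallest count whose cumulative explained variance reaches the target fraction. Projecting onto an orthonormal basis can only shrink distances, so neighbors found in the reduced space are approximate. The eps argument of KDTree.query is scipy's version of the tree-search error parameter: each returned neighbor is within a factor (1 + eps) of the true k-th nearest distance.

```python
import numpy as np
from scipy.spatial import KDTree


def pca_projection(X, keep=0.95):
    """Return (mean, components), keeping the fewest principal components
    whose cumulative explained variance is at least `keep`."""
    mean = X.mean(axis=0)
    # Rows of Vt are principal directions; s**2 is proportional to variance
    _, s, Vt = np.linalg.svd(X - mean, full_matrices=False)
    explained = np.cumsum(s ** 2) / np.sum(s ** 2)
    n_comp = int(np.searchsorted(explained, keep) + 1)
    return mean, Vt[:n_comp]


# Hypothetical data: 500 points in 100 dimensions near a 3-D subspace
rng = np.random.default_rng(0)
latent = rng.normal(size=(500, 3))
mixing = rng.normal(size=(3, 100))
X = latent @ mixing + 0.01 * rng.normal(size=(500, 100))

mean, W = pca_projection(X, keep=0.95)
Z = (X - mean) @ W.T  # reduced training features, shape (500, n_comp)
tree = KDTree(Z)

q = X[0] + 0.01 * rng.normal(size=100)  # a query near training point 0
z_q = (q - mean) @ W.T
_, nearest = tree.query(z_q, k=1)               # exact search in reduced space
_, approx = tree.query(z_q, k=1, eps=0.5)       # approximate tree search
print(W.shape[0], nearest, approx)
```

Here the data really has only three degrees of freedom, so PCA keeps at most three components and the kd-tree works in 3 dimensions instead of 100.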