
Testing for Overfitting in Binary Classifiers

Overfitting (or overtraining) is a common problem for supervised learning models in which behavior learned from a training dataset does not generalize well to an unseen test dataset. The most common cause of overfitting is excess model complexity (in random forests, for example, using trees with too much depth). The good news is that overfitting is easily tested for and remedied. This paper describes an approach to testing for overfitting using the probability distributions of binary classifier output and the Kolmogorov-Smirnov test.

Some Context for Overfitting Tests

In order to fully understand the overfitting tests in this paper, I first re-establish a few fundamental principles of statistical learning that are relevant to this discussion. When one trains a supervised learning model and provides its output (weights, insights, etc.) as a product, one makes a few claims:

  • The model has learned correlations/relationships among the input features themselves as well as between the features and the target.
  • These learned correlations will generalize to unseen data, such that predictions can be made and understood as arising from the learned feature correlations.

When a model overfits the training data, these two claims cannot be made: we are left with a model that cannot be trusted to generalize to new data, and no claim can be made to understand why the model predicted what it did.

Kolmogorov-Smirnov Test Statistic

A straightforward way to test a binary classifier for overfitting is to plot the classifier output (a probability output from zero to one) for both the test and train sets (see Figure 1).

Figure 1: Example of an overfitting plot for a binary classifier which is not overfitting.
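
Concretely, such a plot is just two pairs of overlaid, normalized histograms. The sketch below is a minimal, self-contained version: the train_proba and test_proba arrays here are random placeholders standing in for the positive-class column of a real model's predict_proba() output (a fuller plotting function is given in the Conclusion & Code section).

import numpy as np
import matplotlib.pyplot as plt

# Placeholder classifier output; in practice these would come from
# model.predict_proba(X_train)[:, 1] and model.predict_proba(X_test)[:, 1].
rng = np.random.default_rng(0)
train_proba = rng.beta(2, 5, size=1000)
test_proba = rng.beta(2, 5, size=500)

# Overlay normalized histograms of the train and test output distributions.
bin_edges = np.linspace(0, 1, 31)
plt.hist(train_proba, bins=bin_edges, density=True, alpha=0.5, label='train')
plt.hist(test_proba, bins=bin_edges, density=True, alpha=0.5, label='test')
plt.xlabel('Classifier Output')
plt.ylabel('Normalized Density')
plt.legend(loc='best')
plt.show()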

If the two claims we made about our statistical learning model in section one are to hold, then we need to require that the test and train distributions of this plot are consistent with one another. The Kolmogorov-Smirnov (KS) test will do just that. For a robust mathematical description one can read here. The KS test can be framed as a non-parametric hypothesis test of agreement between two histograms. In this case the KS statistic is defined as

D_{n,m} = \sup_x \left| F_{1,n}(x) - F_{2,m}(x) \right|

where F_{1,n} and F_{2,m} are the empirical cumulative distribution functions of the two sample histograms (of sizes n and m, respectively) and sup is the supremum function.

When the KS statistic takes low values (near zero), the p-value becomes large and we cannot reject the null hypothesis that the two distributions come from a common underlying distribution (exactly what we need in order to establish our two assertions from section one).
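
To make the interpretation concrete, the sketch below (an illustration on synthetic samples, not the data used in this paper) runs scipy.stats.ks_2samp on two samples drawn from the same distribution and then on two samples drawn from shifted distributions.

import numpy as np
from scipy import stats

rng = np.random.default_rng(42)

# Samples from a common distribution: expect a small KS statistic and a
# large p-value, so the common-distribution null is not rejected.
same_a = rng.normal(0.5, 0.1, size=1000)
same_b = rng.normal(0.5, 0.1, size=1000)
print(stats.ks_2samp(same_a, same_b))

# Samples from shifted distributions: expect a larger KS statistic and a
# small p-value, so the null is rejected.
shifted_a = rng.normal(0.5, 0.1, size=1000)
shifted_b = rng.normal(0.65, 0.1, size=1000)
print(stats.ks_2samp(shifted_a, shifted_b))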

Example

In order to drive home the power of the KS test, I have included a test case: a random forest that is purposely overfitted by making it overly complex (tree depth = 30) and another random forest with a tree depth of 10. The probability outputs of these two models are shown below in Figures 2 and 3, respectively.

Figure 2: Overfitted model, random forest with tree depth = 30
Figure 3: Random forest with tree depth = 10

It is clear from visual inspection of these plots that one of these models is producing output on the test set that does not represent the correlations learned on the training set. The KS test statistic encapsulates what we can verify visually. One very important side note here is that the accuracy/precision/recall for the overfitted model can be HIGHER than that of the non-overfitted model. This is because all of those metrics rely on a probability threshold choice that does not take into account whether the test and train distributions are consistent; the sketch below illustrates the comparison.
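
The following sketch reproduces the spirit of this comparison on a stand-in binary dataset (sklearn's breast cancer data; the data behind Figures 2 and 3 is not specified here). It trains the deliberately over-complex depth-30 forest and the depth-10 forest, then prints the test accuracy at the default 0.5 threshold alongside the train-vs-test KS result for the positive-class output. Whatever the exact numbers, the point is that the threshold-based metric and the KS p-value answer different questions.

from scipy import stats
from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split

# Stand-in binary dataset.
X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.33, random_state=42)

for depth in (30, 10):
    clf = RandomForestClassifier(n_estimators=100, max_depth=depth,
                                 random_state=42)
    clf.fit(X_train, y_train)
    train_proba = clf.predict_proba(X_train)[:, 1]
    test_proba = clf.predict_proba(X_test)[:, 1]
    # Threshold-based metric at the default 0.5 cut.
    acc = accuracy_score(y_test, (test_proba > 0.5).astype(int))
    # KS test between the train and test positive-class output distributions.
    ks = stats.ks_2samp(train_proba[y_train == 1], test_proba[y_test == 1])
    print(f"depth={depth}: test accuracy={acc:.3f}, "
          f"KS statistic={ks.statistic:.3f}, p-value={ks.pvalue:.3f}")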

Conclusion & Code

The take-home: make these plots and check your binary classification model for overfitting; otherwise you risk shipping a model that will not generalize to new data and should not be put into a production pipeline.

Below is the function that we use to create a KS plot, along with an example of how to use it on sample data provided by sklearn.

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy import stats

def make_ks_plot(y_train, train_proba, y_test, test_proba, bins=30, fig_sz=(10, 8)):
    '''
    OUTPUT: KS test/train overtraining plot for classifier output
    INPUTS:
        y_train - Series or array with the training-set labels
        train_proba - np.ndarray of positive-class probabilities from sklearn predict_proba()[:, 1]. Same length as y_train.
        y_test - Series or array with the test-set labels
        test_proba - np.ndarray of positive-class probabilities from sklearn predict_proba()[:, 1]. Same length as y_test.
        bins - number of bins for the histograms. Default 30.
        fig_sz - matplotlib figure size tuple. Default (10, 8).
    '''
    train = pd.DataFrame(y_train, columns=["label"])
    test = pd.DataFrame(y_test, columns=["label"])
    train["probability"] = train_proba
    test["probability"] = test_proba
    # Split the classifier output by true class, for train and test:
    # decisions = [train +, train -, test +, test -]
    decisions = []
    for df in [train, test]:
        d1 = df['probability'][df["label"] == 1]
        d2 = df['probability'][df["label"] == 0]
        decisions += [d1, d2]
    low = min(np.min(d) for d in decisions)
    high = max(np.max(d) for d in decisions)
    low_high = (low,high)
    fig = plt.figure(figsize=fig_sz)
    train_pos = plt.hist(decisions[0],
         color='r', alpha=0.5, range=low_high, bins=bins,
         histtype='stepfilled', density=True,
         label='+ (train)')
    train_neg = plt.hist(decisions[1],
         color='b', alpha=0.5, range=low_high, bins=bins,
         histtype='stepfilled', density=True,
         label='- (train)')
    # The test-set distributions are drawn as points with Poisson error bars.
    hist, bin_edges = np.histogram(decisions[2],
                                   bins=bins, range=low_high, density=True)
    scale = len(decisions[2]) / sum(hist)
    err = np.sqrt(hist * scale) / scale
    center = (bin_edges[:-1] + bin_edges[1:]) / 2
    test_pos = plt.errorbar(center, hist, yerr=err, fmt='o', c='r', label='+ (test)')
    hist, _ = np.histogram(decisions[3],
                           bins=bin_edges, range=low_high, density=True)
    scale = len(decisions[3]) / sum(hist)
    err = np.sqrt(hist * scale) / scale
    test_neg = plt.errorbar(center, hist, yerr=err, fmt='o', c='b', label='- (test)')
    # KS test comparing the train and test output distributions
    # for the positive class.
    ks = stats.ks_2samp(decisions[0], decisions[2])
    plt.xlabel("Classifier Output", fontsize=12)
    plt.ylabel("Arbitrary Normalized Units", fontsize=12)
    plt.xlim(0, 1)
    plt.plot([], [], ' ', label=f'KS Statistic (p-value): {ks[0]:.2f} ({ks[1]:.2f})')
    plt.legend(loc='best', fontsize=12)
    plt.show()
    plt.close()
# While not a beautiful example of the power of KS,
# here is an example of the KS viz in action
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_iris
data = load_iris()
X = data.data
# The KS plot assumes a binary classifier, so collapse the three iris
# classes into a binary target (class 2 vs. the rest).
y = (data.target == 2).astype(int)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.33, random_state=42)
clf = RandomForestClassifier(n_estimators=100, max_depth=8,
                             random_state=42)
clf.fit(X_train, y_train)
# Pass the positive-class probabilities, not the hard class predictions.
train_proba = clf.predict_proba(X_train)[:, 1]
test_proba = clf.predict_proba(X_test)[:, 1]
make_ks_plot(y_train, train_proba, y_test, test_proba)
