.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/plot_k_neighbors_classification.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_auto_examples_plot_k_neighbors_classification.py>`
        to download the full example code or to run this example in your browser via Binder

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_plot_k_neighbors_classification.py:


K-nearest neighbors classification
==================================

Shows the usage of the k-nearest neighbors classifier.

.. GENERATED FROM PYTHON SOURCE LINES 7-18

.. code-block:: Python


    # Author: Pablo Marcos Manchón
    # License: MIT

    import matplotlib.pyplot as plt
    import numpy as np
    from sklearn.model_selection import GridSearchCV, train_test_split

    import skfda
    from skfda.ml.classification import KNeighborsClassifier

.. GENERATED FROM PYTHON SOURCE LINES 19-29

In this example, we show the usage of the k-nearest neighbors classifier in
its functional version, which extends the multivariate one by using
functional metrics.

First, we fetch a functional dataset: the Berkeley Growth Study. This dataset
contains the heights of 39 boys and 54 girls, measured between ages 1 and 18.
We will try to predict the sex of each child from their growth curve.

The following figure shows the growth curves grouped by sex.

.. GENERATED FROM PYTHON SOURCE LINES 30-41

.. code-block:: Python


    X, y = skfda.datasets.fetch_growth(return_X_y=True, as_frame=True)
    X = X.iloc[:, 0].values  # FDataGrid with the growth curves
    y = y.values  # categorical array with the sex of each child

    # Plot samples grouped by sex
    X.plot(group=y.codes, group_names=y.categories)

    # Keep only the integer codes as class labels
    y = y.codes

.. image-sg:: /auto_examples/images/sphx_glr_plot_k_neighbors_classification_001.png
   :alt: Berkeley Growth Study
   :srcset: /auto_examples/images/sphx_glr_plot_k_neighbors_classification_001.png
   :class: sphx-glr-single-img
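The code above passes ``y.codes`` as the group of each curve and
``y.categories`` as the group names. The following minimal pandas sketch, on
toy labels rather than the growth data, shows how a categorical array maps
each value to an integer code. The explicit category order is an assumption
made here so that the codes match the encoding used in this example (zeros
for male).

.. code-block:: Python

    import pandas as pd

    # Toy labels (not the growth data). The category order is given explicitly,
    # so "male" gets code 0 and "female" gets code 1 (an assumption for this sketch).
    sex = pd.Categorical(["male", "female", "male"], categories=["male", "female"])
    print(sex.codes)  # [0 1 0]
    print(list(sex.categories))  # ['male', 'female'], indexed by code

    # Without an explicit order, pandas sorts the categories,
    # so "female" would become code 0 instead.
    inferred = pd.Categorical(["male", "female", "male"])
    print(list(inferred.categories))  # ['female', 'male']

Since the codes only index into the categories, always check the category
order before interpreting 0 and 1 as specific classes.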
.. GENERATED FROM PYTHON SOURCE LINES 42-44

The class labels are stored as an array of integer codes: zeros represent
male samples and ones represent female samples.

.. GENERATED FROM PYTHON SOURCE LINES 45-48

.. code-block:: Python


    print(y)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]

.. GENERATED FROM PYTHON SOURCE LINES 49-55

We can split the dataset using the sklearn function
:func:`~sklearn.model_selection.train_test_split`. The function returns two
:class:`~skfda.representation.grid.FDataGrid` objects, ``X_train`` and
``X_test``, with the corresponding partitions, and two arrays, ``y_train``
and ``y_test``, with their class labels. Here, 25% of the samples are held
out for testing, and ``stratify=y`` keeps the proportion of boys and girls
approximately equal in both partitions.

.. GENERATED FROM PYTHON SOURCE LINES 56-66

.. code-block:: Python


    X_train, X_test, y_train, y_test = train_test_split(
        X,
        y,
        test_size=0.25,
        stratify=y,
        random_state=0,
    )

.. GENERATED FROM PYTHON SOURCE LINES 67-74

We will fit the classifier
:class:`~skfda.ml.classification.KNeighborsClassifier` with the training
partition. This classifier works like the sklearn multivariate classifier
:class:`~sklearn.neighbors.KNeighborsClassifier`, but its input is a
:class:`~skfda.representation.grid.FDataGrid` with functional observations
instead of an array with multivariate data.

.. GENERATED FROM PYTHON SOURCE LINES 75-79

.. code-block:: Python


    knn = KNeighborsClassifier(n_neighbors=5)
    knn.fit(X_train, y_train)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    KNeighborsClassifier()
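Conceptually, the functional classifier keeps the k-nearest neighbors rule
and only changes how the distance between observations is computed. The
following NumPy sketch illustrates that idea under simplifying assumptions:
all curves are sampled on the same grid, the :math:`\mathbb{L}^2` distance is
approximated with the trapezoidal rule, and every neighbor has the same
weight. The helper names (``l2_distances``, ``knn_predict_proba`` and
``knn_predict``) and the toy data are hypothetical; this is not how
scikit-fda implements the classifier.

.. code-block:: Python

    import numpy as np


    def l2_distances(a, b, grid):
        """Pairwise L2 distances between curves sampled on a common grid.

        a has shape (n_a, n_points) and b has shape (n_b, n_points). The
        integral of the squared difference is approximated with the
        trapezoidal rule.
        """
        sq = (a[:, None, :] - b[None, :, :]) ** 2  # (n_a, n_b, n_points)
        widths = np.diff(grid)  # spacing between consecutive grid points
        integral = np.sum(0.5 * (sq[..., :-1] + sq[..., 1:]) * widths, axis=-1)
        return np.sqrt(integral)


    def knn_predict_proba(train, labels, test, grid, k=5):
        """Fraction of the k nearest training curves in each class (uniform weights)."""
        classes = np.unique(labels)  # sorted class labels, one column each
        dist = l2_distances(test, train, grid)
        nearest = np.argsort(dist, axis=1)[:, :k]  # indices of the k closest curves
        neighbor_labels = labels[nearest]  # shape (n_test, k)
        return np.stack(
            [(neighbor_labels == c).mean(axis=1) for c in classes], axis=1
        )


    def knn_predict(train, labels, test, grid, k=5):
        """Majority vote among the k nearest neighbors (ties go to the smallest label)."""
        classes = np.unique(labels)
        proba = knn_predict_proba(train, labels, test, grid, k)
        return classes[np.argmax(proba, axis=1)]


    grid = np.linspace(0.0, 1.0, 11)
    # Hypothetical toy data: constant curves around 0 (class 0) and around 1 (class 1)
    train = np.array([np.full(11, v) for v in (0.0, 0.1, 0.2, 1.0, 1.1, 0.9)])
    labels = np.array([0, 0, 0, 1, 1, 1])
    test = np.array([np.full(11, 0.05), np.full(11, 0.95)])

    print(knn_predict(train, labels, test, grid, k=3))  # [0 1]
    # With k=5, three of the five neighbors of the first curve are in class 0
    print(knn_predict_proba(train, labels, test[:1], grid, k=5))

With uniform weights, the predicted probabilities are simply vote fractions,
so with :math:`k=5` they can only take the values 0, 0.2, 0.4, 0.6, 0.8 and 1.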
.. GENERATED FROM PYTHON SOURCE LINES 80-91

Once it is fitted, we can predict labels for the test samples. To predict the
label of a test sample, the classifier finds its k nearest neighbors in the
training set and assigns the class shared by most of them. In this case, we
have set the number of neighbors to 5 (:math:`k=5`). By default, it uses the
:math:`\mathbb{L}^2` distance between functions to determine the
neighborhood of a sample. However, it can be used with any of the functional
metrics described in :doc:`/modules/misc/metrics`.

.. GENERATED FROM PYTHON SOURCE LINES 92-96

.. code-block:: Python


    pred = knn.predict(X_test)
    print(pred)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    [0 0 1 0 1 1 1 0 0 0 0 1 1 0 0 0 0 1 1 1 1 1 1 1]

.. GENERATED FROM PYTHON SOURCE LINES 97-100

The :meth:`~skfda.ml.classification.KNeighborsClassifier.score` method
computes the mean accuracy on the test data. In this case, we obtain an
accuracy of about 96% (23 of the 24 test samples are classified correctly).

.. GENERATED FROM PYTHON SOURCE LINES 101-105

.. code-block:: Python


    score = knn.score(X_test, y_test)
    print(score)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    0.9583333333333334

.. GENERATED FROM PYTHON SOURCE LINES 106-110

We can also estimate the probability of membership in each class using
:meth:`~skfda.ml.classification.KNeighborsClassifier.predict_proba`, which
returns, for each test sample, an array with the probabilities of the
classes, sorted by class label. With the default uniform weights, each
probability is the fraction of the k neighbors that belong to that class; for
instance, ``[0.6 0.4]`` means that 3 of the 5 neighbors are male.

.. GENERATED FROM PYTHON SOURCE LINES 111-116

.. code-block:: Python


    probs = knn.predict_proba(X_test[:5])  # Class probabilities of the first 5 test samples
    print(probs)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    [[1.  0. ]
     [0.6 0.4]
     [0.  1. ]
     [1.  0. ]
     [0.  1. ]]

.. GENERATED FROM PYTHON SOURCE LINES 117-122

We can use the sklearn :class:`~sklearn.model_selection.GridSearchCV` to
perform a grid search that selects the best hyperparameters using
cross-validation.
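To make the selection procedure concrete, here is a minimal NumPy sketch of
what a grid search over the number of neighbors does: for each candidate k,
it averages the validation accuracy over several folds and keeps the best k.
It assumes binary 0/1 labels, odd values of k, and curves sampled on a
common, equally spaced grid, where the Euclidean distance between sampled
values is proportional to a rectangle-rule approximation of the
:math:`\mathbb{L}^2` distance and therefore selects the same neighbors.
Unlike :class:`~sklearn.model_selection.GridSearchCV`, which uses stratified
folds by default for classifiers, it uses plain shuffled folds. The function
names and the toy data are hypothetical.

.. code-block:: Python

    import numpy as np


    def knn_accuracy(x_tr, y_tr, x_va, y_va, k):
        """Accuracy of a k-NN majority vote (binary 0/1 labels, odd k assumed)."""
        # Euclidean distance between sampled curves (proportional to the
        # rectangle-rule L2 distance on an equally spaced grid)
        d = np.linalg.norm(x_va[:, None, :] - x_tr[None, :, :], axis=-1)
        nearest = np.argsort(d, axis=1)[:, :k]
        votes = y_tr[nearest]  # labels of the k nearest curves, shape (n_va, k)
        pred = (votes.mean(axis=1) > 0.5).astype(int)  # majority vote
        return np.mean(pred == y_va)


    def grid_search_k(x, y, ks, n_folds=5, seed=0):
        """Mean k-fold cross-validation accuracy for each candidate k."""
        rng = np.random.default_rng(seed)
        folds = np.array_split(rng.permutation(len(y)), n_folds)
        scores = {}
        for k in ks:
            accs = []
            for i, va in enumerate(folds):
                tr = np.concatenate([f for j, f in enumerate(folds) if j != i])
                accs.append(knn_accuracy(x[tr], y[tr], x[va], y[va], k))
            scores[k] = float(np.mean(accs))
        best_k = max(scores, key=scores.get)  # first (smallest) k among ties
        return best_k, scores


    rng = np.random.default_rng(1)
    t = np.linspace(0.0, 1.0, 20)
    # Hypothetical toy data: two groups of noisy curves with different slopes
    x = np.vstack([
        t + 0.1 * rng.standard_normal((30, 20)),
        2 * t + 0.1 * rng.standard_normal((30, 20)),
    ])
    y = np.repeat([0, 1], 30)

    best_k, scores = grid_search_k(x, y, ks=range(1, 18, 2))
    print(best_k, scores[best_k])

As in the example, the candidate values are odd so that a two-class vote
cannot end in a tie.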
In this case, we vary the number of neighbors between 1 and 17, using only
odd values to prevent ties between the two classes.

.. GENERATED FROM PYTHON SOURCE LINES 123-139

.. code-block:: Python


    # Only odd numbers, to prevent ties
    param_grid = {"n_neighbors": range(1, 18, 2)}

    knn = KNeighborsClassifier()

    # Perform grid search with cross-validation
    gscv = GridSearchCV(knn, param_grid, cv=5)
    gscv.fit(X_train, y_train)

    print("Best params:", gscv.best_params_)
    print("Best cross-validation score:", gscv.best_score_)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Best params: {'n_neighbors': 11}
    Best cross-validation score: 0.9571428571428573

.. GENERATED FROM PYTHON SOURCE LINES 140-142

We obtain the highest mean accuracy using 11 neighbors. The following figure
shows the mean cross-validation score as a function of the number of
neighbors.

.. GENERATED FROM PYTHON SOURCE LINES 143-152

.. code-block:: Python


    fig = plt.figure()
    ax = fig.add_subplot(1, 1, 1)

    # One bar per candidate number of neighbors
    ax.bar(param_grid["n_neighbors"], gscv.cv_results_["mean_test_score"])
    ax.set_xticks(param_grid["n_neighbors"])
    ax.set_xlabel("Number of neighbors")
    ax.set_ylabel("Cross-validation score")
    ax.set_ylim((0.9, 1))

.. image-sg:: /auto_examples/images/sphx_glr_plot_k_neighbors_classification_002.png
   :alt: plot k neighbors classification
   :srcset: /auto_examples/images/sphx_glr_plot_k_neighbors_classification_002.png
   :class: sphx-glr-single-img

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    (0.9, 1.0)

.. GENERATED FROM PYTHON SOURCE LINES 153-159

By default (``refit=True``), after performing the cross-validation,
:class:`~sklearn.model_selection.GridSearchCV` refits the classifier with the
best hyperparameters on the whole training data provided in the call to
:meth:`~sklearn.model_selection.GridSearchCV.fit`. Therefore, to check the
accuracy on the test data of the classifier with the selected number of
neighbors (11), we can simply call the
:meth:`~sklearn.model_selection.GridSearchCV.score` method.

.. GENERATED FROM PYTHON SOURCE LINES 160-164

.. code-block:: Python


    score = gscv.score(X_test, y_test)
    print(score)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    1.0

.. GENERATED FROM PYTHON SOURCE LINES 165-167

This classifier can also be used with multivariate functional data, such as
surfaces or curves in :math:`\mathbb{R}^N`, provided that the chosen metric
supports them.

.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 0.621 seconds)

.. _sphx_glr_download_auto_examples_plot_k_neighbors_classification.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: binder-badge

      .. image:: images/binder_badge_logo.svg
        :target: https://mybinder.org/v2/gh/GAA-UAM/scikit-fda/develop?filepath=examples/plot_k_neighbors_classification.py
        :alt: Launch binder
        :width: 150 px

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_k_neighbors_classification.ipynb <plot_k_neighbors_classification.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_k_neighbors_classification.py <plot_k_neighbors_classification.py>`

.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_