{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Tuning a k-NN classifier with `GridSearchCV`\n",
    "\n",
    "Exhaustively search over the number of neighbours `k` and the neighbour weighting scheme of a `KNeighborsClassifier` on scikit-learn's breast cancer dataset, scoring every candidate by 10-fold cross-validated precision."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "from sklearn import datasets, model_selection, neighbors"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "aBtvWJNQhBM5"
   },
   "outputs": [],
   "source": [
    "# load the breast cancer dataset bundled with scikit-learn\n",
    "cancer = datasets.load_breast_cancer()\n",
    "\n",
    "# target\n",
    "y = cancer.target\n",
    "\n",
    "# features\n",
    "X = cancer.data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define the model and the search space\n",
    "\n",
    "Notes:\n",
    "\n",
    "- `scoring='precision'` measures precision for the positive label `1`; check `cancer.target_names` to see which class that is.\n",
    "- k-NN is distance-based and the features below are passed in without scaling, so features with larger numeric ranges will dominate the distance. Consider wrapping the estimator in a `Pipeline` with a `StandardScaler`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "p6fQASEChP03"
   },
   "outputs": [],
   "source": [
    "# initialize the estimator; the grid search overrides n_neighbors and weights\n",
    "knn = neighbors.KNeighborsClassifier()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "cADtTvf6hb3O"
   },
   "outputs": [],
   "source": [
    "# grid contains k (n_neighbors) and the weight function\n",
    "grid = {\n",
    "    'n_neighbors': [1, 3, 5, 7],\n",
    "    'weights': ['uniform', 'distance']\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "ZZ_uK5A0ho67"
   },
   "outputs": [],
   "source": [
    "# set up the grid search, scoring on precision with number of folds = 10\n",
    "gscv = model_selection.GridSearchCV(estimator=knn, param_grid=grid, scoring='precision', cv=10)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Run the search"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "tLqJ_3fBh2Bh"
   },
   "outputs": [],
   "source": [
    "# start the search: every combination in `grid` is fitted and scored on each of the 10 folds\n",
    "gscv.fit(X, y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "yPZkBBZGiWuC"
   },
   "outputs": [],
   "source": [
    "# best hyperparameterization and its mean cross-validated precision\n",
    "# (the full per-candidate table is built from cv_results_ below)\n",
    "print('best params:', gscv.best_params_)\n",
    "print('best mean CV precision:', gscv.best_score_)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "H6n1RzMiifkZ"
   },
   "outputs": [],
   "source": [
    "# convert the results dictionary to a dataframe\n",
    "results = pd.DataFrame(gscv.cv_results_)\n",
    "\n",
    "# keep the hyperparameterizations tried and their scores, ordered best first\n",
    "summary = (\n",
    "    results.loc[:, ['params', 'mean_test_score', 'std_test_score']]\n",
    "    .sort_values('mean_test_score', ascending=False)\n",
    ")\n",
    "\n",
    "# show the top 5 models\n",
    "summary.head(5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "I92n0QuvpM03"
   },
   "outputs": [],
   "source": [
    "# visualize the result: render params dicts as text labels and reverse the order\n",
    "# so the best model is drawn at the top of the horizontal bar chart\n",
    "plot_data = summary.assign(params=summary['params'].astype(str)).iloc[::-1]\n",
    "ax = plot_data.plot.barh(x='params', y='mean_test_score', legend=False)\n",
    "ax.set_xlabel('mean 10-fold CV precision')\n",
    "ax.set_ylabel('hyperparameters')\n",
    "ax.set_title('k-NN grid search results');"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "name": "tuning_using_gridsearchcv.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}