Files
mlessentials/Lab08/Examples/tuning_using_gridsearchcv.ipynb
T
Your Name 8ac470b565 notebooks
2021-02-05 20:44:51 +00:00

221 lines
5.1 KiB
Plaintext

{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "aBtvWJNQhBM5"
},
"outputs": [],
"source": [
"from sklearn import datasets, model_selection, neighbors\n",
"\n",
"# breast cancer dataset bundled with scikit-learn\n",
"cancer = datasets.load_breast_cancer()\n",
"\n",
"# feature matrix X and class labels y\n",
"X, y = cancer.data, cancer.target"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "p6fQASEChP03"
},
"outputs": [],
"source": [
"# initialize the k-nearest-neighbours estimator with default settings;\n",
"# n_neighbors and weights are chosen by the grid search below\n",
"# NOTE(review): X is passed unscaled and KNN distances are scale-sensitive;\n",
"# consider a Pipeline of StandardScaler + KNeighborsClassifier\n",
"knn = neighbors.KNeighborsClassifier()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "cADtTvf6hb3O"
},
"outputs": [],
"source": [
"# hyperparameter grid: candidate k values and neighbour-weighting schemes\n",
"# (4 values of k x 2 weightings = 8 candidate models)\n",
"grid = dict(\n",
"    n_neighbors=[1, 3, 5, 7],\n",
"    weights=['uniform', 'distance'],\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "ZZ_uK5A0ho67"
},
"outputs": [],
"source": [
"# set up the grid search over every combination in `grid`, scored on precision with 10-fold CV\n",
"# NOTE(review): 'precision' scores the positive label 1; in load_breast_cancer label 1 is\n",
"# presumably 'benign' (check cancer.target_names) - confirm that is the class of interest\n",
"gscv = model_selection.GridSearchCV(estimator=knn, param_grid=grid, scoring='precision', cv=10)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 263
},
"colab_type": "code",
"executionInfo": {
"elapsed": 1591,
"status": "ok",
"timestamp": 1571314594262,
"user": {
"displayName": "Andrew Worsley",
"photoUrl": "https://lh3.googleusercontent.com/a-/AAuE7mAp-Td-yKvu76Tg0Swzal8U17btuwNIXFmWVwZo=s64",
"userId": "11337101975325054847"
},
"user_tz": -660
},
"id": "tLqJ_3fBh2Bh",
"outputId": "bf075dfd-0a91-4171-951e-6735b6885c98"
},
"outputs": [],
"source": [
"# run the search: one fit per (parameter combination, fold), then the best\n",
"# combination is refit on all of X, y (refit=True is the GridSearchCV default)\n",
"gscv.fit(X=X, y=y)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 454
},
"colab_type": "code",
"executionInfo": {
"elapsed": 1583,
"status": "ok",
"timestamp": 1571314594263,
"user": {
"displayName": "Andrew Worsley",
"photoUrl": "https://lh3.googleusercontent.com/a-/AAuE7mAp-Td-yKvu76Tg0Swzal8U17btuwNIXFmWVwZo=s64",
"userId": "11337101975325054847"
},
"user_tz": -660
},
"id": "yPZkBBZGiWuC",
"outputId": "ae372db9-8dee-4a68-a564-48dc61d693b8"
},
"outputs": [],
"source": [
"# cv_results_ is a dict of per-candidate columns (one entry per parameter combination);\n",
"# list its keys rather than dumping every raw array - the next cell tabulates the scores\n",
"print(sorted(gscv.cv_results_))\n",
"\n",
"# best combination found and its mean cross-validated score\n",
"print('best params:', gscv.best_params_)\n",
"print('best score :', gscv.best_score_)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 121
},
"colab_type": "code",
"executionInfo": {
"elapsed": 1576,
"status": "ok",
"timestamp": 1571314594265,
"user": {
"displayName": "Andrew Worsley",
"photoUrl": "https://lh3.googleusercontent.com/a-/AAuE7mAp-Td-yKvu76Tg0Swzal8U17btuwNIXFmWVwZo=s64",
"userId": "11337101975325054847"
},
"user_tz": -660
},
"id": "H6n1RzMiifkZ",
"outputId": "cb139684-1c11-4e63-827a-0d8b2dd496e7"
},
"outputs": [],
"source": [
"import pandas as pd\n",
"\n",
"# convert the results dictionary to a dataframe (one row per hyperparameter combination)\n",
"results = pd.DataFrame(gscv.cv_results_)\n",
"\n",
"# select the hyperparameterizations tried and their mean test scores, best first, top 5;\n",
"# left as the last expression so Jupyter renders it as a rich table instead of print() text\n",
"(\n",
"    results.loc[:, ['params', 'mean_test_score']]\n",
"    .sort_values('mean_test_score', ascending=False)\n",
"    .head(5)\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 286
},
"colab_type": "code",
"executionInfo": {
"elapsed": 1323,
"status": "ok",
"timestamp": 1571314633729,
"user": {
"displayName": "Andrew Worsley",
"photoUrl": "https://lh3.googleusercontent.com/a-/AAuE7mAp-Td-yKvu76Tg0Swzal8U17btuwNIXFmWVwZo=s64",
"userId": "11337101975325054847"
},
"user_tz": -660
},
"id": "I92n0QuvpM03",
"outputId": "28021da8-996b-4b0b-c821-8ac571492b1b"
},
"outputs": [],
"source": [
"# visualize the mean cross-validated score of every hyperparameter combination\n",
"ax = results.loc[:, ['params', 'mean_test_score']].plot.barh(x='params', legend=False)\n",
"# label from the search's own scoring name so the figure stands alone;\n",
"# trailing ';' suppresses the repr of ax.set's return value\n",
"ax.set(\n",
"    title=f'Grid search: mean cross-validated {gscv.scoring}',\n",
"    xlabel=f'mean test {gscv.scoring}',\n",
"    ylabel='params',\n",
");"
]
}
],
"metadata": {
"colab": {
"collapsed_sections": [],
"name": "tuning_using_gridsearchcv.ipynb",
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.6"
}
},
"nbformat": 4,
"nbformat_minor": 1
}