{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Simple demonstration of the random search process\n",
    "\n",
    "Candidate values of the ridge regression regularisation strength `alpha` are sampled from a gamma distribution. Each sampled value is scored with 10-fold cross-validation (negative mean squared error) on the scikit-learn diabetes dataset, and the scores are then tabulated and plotted against `alpha`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# all imports live here so a fresh kernel can run the notebook top to bottom\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from scipy import stats\n",
    "from sklearn import datasets, linear_model, model_selection"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load the data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "BvJ0QjgbuFtp"
   },
   "outputs": [],
   "source": [
    "# load the data\n",
    "diabetes = datasets.load_diabetes()\n",
    "\n",
    "# target\n",
    "y = diabetes.target\n",
    "\n",
    "# features\n",
    "X = diabetes.data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 34
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 910,
     "status": "ok",
     "timestamp": 1571316858888,
     "user": {
      "displayName": "Andrew Worsley",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AAuE7mAp-Td-yKvu76Tg0Swzal8U17btuwNIXFmWVwZo=s64",
      "userId": "11337101975325054847"
     },
     "user_tz": -660
    },
    "id": "XKHMBvvmudNU",
    "outputId": "82434ba9-2d90-4beb-f13e-4c068d096eed"
   },
   "outputs": [],
   "source": [
    "# the first patient has index 0\n",
    "print(y[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 54
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 905,
     "status": "ok",
     "timestamp": 1571316858890,
     "user": {
      "displayName": "Andrew Worsley",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AAuE7mAp-Td-yKvu76Tg0Swzal8U17btuwNIXFmWVwZo=s64",
      "userId": "11337101975325054847"
     },
     "user_tz": -660
    },
    "id": "hlKYUSJSuhQH",
    "outputId": "3f0acc84-f9f3-4808-b89f-8e36cd4c1276"
   },
   "outputs": [],
   "source": [
    "# let's look at the first patient's data, pairing each feature name with its value\n",
    "print(\n",
    "    dict(zip(diabetes.feature_names, X[0]))\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Distribution used to sample alpha"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 300
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 1388,
     "status": "ok",
     "timestamp": 1571316859384,
     "user": {
      "displayName": "Andrew Worsley",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AAuE7mAp-Td-yKvu76Tg0Swzal8U17btuwNIXFmWVwZo=s64",
      "userId": "11337101975325054847"
     },
     "user_tz": -660
    },
    "id": "GJtBIe3lu-df",
    "outputId": "3d8e6a86-28f8-4599-d82c-5ca85d0201a6"
   },
   "outputs": [],
   "source": [
    "# gamma distribution that the alpha values are drawn from; it is defined once\n",
    "# so the density plotted here and the values sampled below cannot drift apart\n",
    "alpha_dist = stats.gamma(a=1, loc=1, scale=2)\n",
    "\n",
    "# values of alpha\n",
    "x = np.linspace(1, 20, 100)\n",
    "\n",
    "# probability density at each value of alpha\n",
    "p_X = alpha_dist.pdf(x)\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(x, p_X)\n",
    "ax.set_xlabel('alpha')\n",
    "ax.set_ylabel('P(alpha)')\n",
    "ax.set_title('Gamma density used to sample alpha')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "fvtXzLMXvdUk"
   },
   "outputs": [],
   "source": [
    "# n sample values\n",
    "n_iter = 100\n",
    "\n",
    "# sample from the gamma distribution; the fixed random_state makes the draw repeatable\n",
    "samples = alpha_dist.rvs(size=n_iter, random_state=100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 300
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 1370,
     "status": "ok",
     "timestamp": 1571316859386,
     "user": {
      "displayName": "Andrew Worsley",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AAuE7mAp-Td-yKvu76Tg0Swzal8U17btuwNIXFmWVwZo=s64",
      "userId": "11337101975325054847"
     },
     "user_tz": -660
    },
    "id": "Eh1gcP2bvo9L",
    "outputId": "390e9fa3-78d1-4ed1-b9b0-a9792eac588b"
   },
   "outputs": [],
   "source": [
    "# visualise the sample distribution\n",
    "fig, ax = plt.subplots()\n",
    "ax.hist(samples)\n",
    "ax.set_xlabel('alpha')\n",
    "ax.set_ylabel('sample count')\n",
    "ax.set_title('Sampled alpha values')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Score each sampled alpha"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "Nou_IG1IwlAu"
   },
   "outputs": [],
   "source": [
    "# we will store one record per sample in a list; unlike a dictionary keyed by\n",
    "# the sample value, a list keeps every draw even if a value is repeated\n",
    "result = []\n",
    "\n",
    "# for each sample\n",
    "for sample in samples:\n",
    "\n",
    "    # initialize a ridge regression estimator with alpha set to the sample value\n",
    "    reg = linear_model.Ridge(alpha=sample)\n",
    "\n",
    "    # conduct a 10-fold cross validation scoring on negative mean squared error\n",
    "    cv = model_selection.cross_val_score(reg, X, y, cv=10, scoring='neg_mean_squared_error')\n",
    "\n",
    "    # retain the alpha together with its mean score across the folds\n",
    "    result.append({'alpha': sample, 'mean_neg_mean_squared_error': cv.mean()})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 121
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 2788,
     "status": "ok",
     "timestamp": 1571316860822,
     "user": {
      "displayName": "Andrew Worsley",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AAuE7mAp-Td-yKvu76Tg0Swzal8U17btuwNIXFmWVwZo=s64",
      "userId": "11337101975325054847"
     },
     "user_tz": -660
    },
    "id": "aqSFjq7ExSMK",
    "outputId": "515cb7ff-d8ea-41bc-dba7-ec4c79fcb843"
   },
   "outputs": [],
   "source": [
    "# convert the list of records to a pandas dataframe; naming the columns here\n",
    "# also gives a well-formed (empty) table when there are no records\n",
    "df_result = pd.DataFrame(result, columns=['alpha', 'mean_neg_mean_squared_error'])\n",
    "\n",
    "# the five samples with the highest mean negative MSE (i.e. the lowest error)\n",
    "df_result.sort_values('mean_neg_mean_squared_error', ascending=False).head(5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 300
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 950,
     "status": "ok",
     "timestamp": 1571316908254,
     "user": {
      "displayName": "Andrew Worsley",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AAuE7mAp-Td-yKvu76Tg0Swzal8U17btuwNIXFmWVwZo=s64",
      "userId": "11337101975325054847"
     },
     "user_tz": -660
    },
    "id": "9pCIKSN6yRzS",
    "outputId": "17a1a958-82a4-4f0c-bf64-0a735517333a"
   },
   "outputs": [],
   "source": [
    "# mean cross-validated score against the sampled alpha\n",
    "fig, ax = plt.subplots()\n",
    "ax.scatter(df_result['alpha'], df_result['mean_neg_mean_squared_error'])\n",
    "ax.set_xlabel('alpha')\n",
    "ax.set_ylabel('-MSE')\n",
    "ax.set_title('Cross-validated score for each sampled alpha')\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "name": "simple_demonstration_of_the_random_search_process.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}