{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Mushroom Classification" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The objective of this activity is to employ the grid and randomized search strategies to find an optimal model capable of discerning whether a particular mushroom species is poisonous or not given attributes relating to its appearance." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 1. Load the data into Python and call the object mushrooms." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Import pandas." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import pandas as pd" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Read the data. Note the lack of a header row." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "mushrooms = pd.read_csv('./agaricus-lepiota.data', header=None)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "View the data." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "mushrooms" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 2.\tSeparate the target y and features X from the dataset. " ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "y_raw = mushrooms.iloc[:,0]\n", "\n", "X_raw = mushrooms.iloc[:,1:]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 3.\tRecode the target y such that poisonous mushrooms are represented as 1, edible mushrooms as 0." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "y = (y_raw == 'p') * 1" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "y" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 4.\tThe feature set X will need to have its columns transformed into a NumPy array with a binary representation. This is known as ‘one-hot encoding.’" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Import preprocessing." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from sklearn import preprocessing" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Create and fit the encoder, then transform the data." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "encoder = preprocessing.OneHotEncoder()\n", "\n", "encoder.fit(X_raw)\n", "\n", "X = encoder.transform(X_raw).toarray()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "View the data." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "X" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 5.\tConduct both a grid and random search to find an optimal hyperparameterization for a random forest classifier. Use accuracy as your method of model evaluation. Which method of tuning is more effective?" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Import ensemble." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from sklearn import ensemble" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Create the random forest classifier." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "rfc = ensemble.RandomForestClassifier(n_estimators=100, random_state=150)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Import model_selection, define the grid, set up the grid search, run the grid search and view the results." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from sklearn import model_selection\n", "\n", "grid = {\n", " 'criterion': ['gini', 'entropy'],\n", " 'max_features': [2, 4, 6, 8, 10, 12, 14]\n", "}\n", "\n", "gscv = model_selection.GridSearchCV(estimator=rfc, param_grid=grid, cv=5, scoring='accuracy')\n", "\n", "gscv.fit(X,y)\n", "\n", "results = pd.DataFrame(gscv.cv_results_)\n", "\n", "results.sort_values('rank_test_score', ascending=True).head(10)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 6.\tPlot mean test score vs hyperparameterization for the top 10 models found. Can you spot any obvious patterns?" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Rank 1 is the best model, so sort ascending to keep the top 10.\n", "(\n", " results\n", " .sort_values('rank_test_score', ascending=True)\n", " .loc[:,['params','mean_test_score']]\n", " .head(10).plot.barh(x='params', xlim=(0.8))\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Import the stats module, define the parameter dictionary and any distributions, conduct a randomized search and visualise the results." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from scipy import stats\n", "\n", "max_features = X.shape[1]\n", "\n", "param_dist = {\n", " 'criterion': ['gini', 'entropy'],\n", " 'max_features': stats.randint(low=1, high=max_features)\n", "}\n", "\n", "rscv = model_selection.RandomizedSearchCV(estimator=rfc, param_distributions=param_dist, n_iter=50, cv=5, scoring='accuracy', random_state=100)\n", "\n", "rscv.fit(X,y)\n", "\n", "results = pd.DataFrame(rscv.cv_results_)\n", "\n", "results.sort_values('rank_test_score', ascending=True).head(10)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Convert the params dicts to strings so they can be de-duplicated and used as bar labels.\n", "results.loc[:,'params'] = results.loc[:,'params'].astype(str)\n", "\n", "# Rank 1 is the best model, so sort ascending to keep the top 10.\n", "(\n", " results.sort_values('rank_test_score', ascending=True)\n", " .loc[:,['params','mean_test_score']]\n", " .drop_duplicates()\n", " .head(10)\n", " .plot.barh(x='params', xlim=(0.8))\n", ")" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.6" } }, "nbformat": 4, "nbformat_minor": 2 }