{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Mushroom Classification"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The objective of this activity is to employ the grid and randomized search strategies to find an optimal model capable of discerning whether a particular mushroom species is poisonous or not given attributed relating to its appearance."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 1. Load the data into Python and call the object mushrooms."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Import pandas."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Read the data. Note the lack of header."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"mushrooms = pd.read_csv('./agaricus-lepiota.data', header=None)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"View the data."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"mushrooms"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 2.\tSeparate the target y and features X from the dataset. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"y_raw = mushrooms.iloc[:,0]\n",
"\n",
"X_raw = mushrooms.iloc[:,1:]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 3.\tRecode the target y such that poisonous mushrooms are represented as 1, edible mushrooms as 0."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"y = (y_raw == 'p') * 1"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"y"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 4.\tThe featureset X will need to have its columns transformed into a numpy array with a binary representation. This is known as one hot encoding."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Import preprocessing."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sklearn import preprocessing"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Create and fit the encoder then transform the data."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"encoder = preprocessing.OneHotEncoder()\n",
"\n",
"encoder.fit(X_raw)\n",
"\n",
"X = encoder.transform(X_raw).toarray()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"View the data."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"X"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 5.\tConduct both a grid and random search to find an optimal hyperparameterization for a random forest classifier. Use accuracy as your method of model evaluation. Which method of tuning is more effective?"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Import ensemble."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sklearn import ensemble"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Create the random forest classifer."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"rfc = ensemble.RandomForestClassifier(n_estimators=100, random_state=150)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Import model select, define the grid, set up the grid search, start the grid search and visualise the results."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sklearn import model_selection\n",
"\n",
"grid = {\n",
" 'criterion': ['gini', 'entropy'],\n",
" 'max_features': [2, 4, 6, 8, 10, 12, 14]\n",
"}\n",
"\n",
"gscv = model_selection.GridSearchCV(estimator=rfc, param_grid=grid, cv=5, scoring='accuracy')\n",
"\n",
"gscv.fit(X,y)\n",
"\n",
"results = pd.DataFrame(gscv.cv_results_)\n",
"\n",
"results.sort_values('rank_test_score', ascending=True).head(10)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 6.\tPlot mean test score vs hyperparameterization for the top 10 models found. Can you spot any obvious patterns?"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"(\n",
" results\n",
" .sort_values('rank_test_score', ascending=False)\n",
" .loc[:,['params','mean_test_score']]\n",
" .head(10).plot.barh(x='params', xlim=(0.8))\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Import the stats models, define the parameter dictionary an any distributions, conduct a randomized search and visualise the results."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from scipy import stats\n",
"\n",
"max_features = X.shape[1]\n",
"\n",
"param_dist = {\n",
" 'criterion': ['gini', 'entropy'],\n",
" 'max_features': stats.randint(low=1, high=max_features)\n",
"}\n",
"\n",
"rscv = model_selection.RandomizedSearchCV(estimator=rfc, param_distributions=param_dist, n_iter=50, cv=5, scoring='accuracy', random_state=100)\n",
"\n",
"rscv.fit(X,y)\n",
"\n",
"results = pd.DataFrame(rscv.cv_results_)\n",
"\n",
"results.sort_values('rank_test_score', ascending=True).head(10)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"results.loc[:,'params'] = results.loc[:,'params'].astype(str)\n",
"\n",
"(\n",
" results.sort_values('rank_test_score', ascending=False)\n",
" .loc[:,['params','mean_test_score']]\n",
" .drop_duplicates()\n",
" .head(10)\n",
" .plot.barh(x='params', xlim=(0.8))\n",
")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.6"
}
},
"nbformat": 4,
"nbformat_minor": 2
}