Libraries and Dataset¶

In [1]:
import numpy as np
import pandas as pd 
import matplotlib.pyplot as plt
import seaborn  as sns
from sklearn.preprocessing import LabelEncoder, OneHotEncoder, StandardScaler,MinMaxScaler
In [2]:
# dataset = pd.read_csv('data/data_market.csv', encoding='latin1')
dataset = pd.read_csv('data/market_cluster.csv', encoding='latin1')
In [3]:
dataset.head()
Out[3]:
Order ID Customer Name Category Sub Category City Order Date Region Sales Discount Profit State profit_margin Cluster
0 OD1 Harish Oil & Masala Masalas Vellore 11-08-2017 North 1254 0.12 401.28 Tamil Nadu 0.32 Medium
1 OD2 Sudha Beverages Health Drinks Krishnagiri 11-08-2017 South 749 0.18 149.80 Tamil Nadu 0.20 Medium
2 OD3 Hussain Food Grains Atta & Flour Perambalur 06-12-2017 West 2360 0.21 165.20 Tamil Nadu 0.07 Low
3 OD4 Jackson Fruits & Veggies Fresh Vegetables Dharmapuri 10-11-2016 South 896 0.25 89.60 Tamil Nadu 0.10 Low
4 OD5 Ridhesh Food Grains Organic Staples Ooty 10-11-2016 South 2355 0.26 918.45 Tamil Nadu 0.39 High

Data Cleaning & Preprocessing¶

In [4]:
dataset.drop(['Order ID'], axis=1, inplace=True)
In [5]:
dataset.isna().sum()
Out[5]:
Customer Name    0
Category         0
Sub Category     0
City             0
Order Date       0
Region           0
Sales            0
Discount         0
Profit           0
State            0
profit_margin    0
Cluster          0
dtype: int64
In [6]:
def remove_outliers(data: pd.DataFrame, column: str) -> pd.Series:
    q3, q1 = np.nanpercentile(data[column], [75, 25])
    iqr = q3 - q1
    upper_bound = q3 + 1.5 * iqr
    lower_bound = q1 - 1.5 * iqr
    data = data[(data[column] > lower_bound) & (data[column] < upper_bound)]

    return data

dataset = remove_outliers(dataset, 'Discount')
dataset = remove_outliers(dataset, 'Sales')
dataset = remove_outliers(dataset, 'Profit')
In [7]:
dataset.dropna(inplace=True)
In [8]:
dataset.head()
Out[8]:
Customer Name Category Sub Category City Order Date Region Sales Discount Profit State profit_margin Cluster
0 Harish Oil & Masala Masalas Vellore 11-08-2017 North 1254 0.12 401.28 Tamil Nadu 0.32 Medium
1 Sudha Beverages Health Drinks Krishnagiri 11-08-2017 South 749 0.18 149.80 Tamil Nadu 0.20 Medium
2 Hussain Food Grains Atta & Flour Perambalur 06-12-2017 West 2360 0.21 165.20 Tamil Nadu 0.07 Low
3 Jackson Fruits & Veggies Fresh Vegetables Dharmapuri 10-11-2016 South 896 0.25 89.60 Tamil Nadu 0.10 Low
4 Ridhesh Food Grains Organic Staples Ooty 10-11-2016 South 2355 0.26 918.45 Tamil Nadu 0.39 High
In [9]:
sns.histplot(dataset['Cluster'])
Out[9]:
<AxesSubplot:xlabel='Cluster', ylabel='Count'>
In [10]:
encoder = LabelEncoder()
scaler = StandardScaler()
onehot = OneHotEncoder()
minmaxscaler = MinMaxScaler()
In [11]:
dataset["Order Date"] = pd.to_datetime(dataset["Order Date"])
dataset["Order Date"] = dataset["Order Date"].dt.month

dataset["Customer Name"] = encoder.fit_transform(dataset["Customer Name"])
dataset["Category"] = encoder.fit_transform(dataset["Category"])
dataset["City"] = encoder.fit_transform(dataset["City"])
dataset["Region"] = encoder.fit_transform(dataset["Region"])
dataset["State"] = encoder.fit_transform(dataset["State"])
dataset["Sub Category"] = encoder.fit_transform(dataset["Sub Category"])

# dataset["Customer Name"] = onehot.fit_transform(dataset["Customer Name"].values.reshape(-1, 1)).toarray()
# dataset["Category"] = onehot.fit_transform(dataset["Category"].values.reshape(-1, 1)).toarray()
# dataset["City"] = onehot.fit_transform(dataset["City"].values.reshape(-1, 1)).toarray()
# dataset["Region"] = onehot.fit_transform(dataset["Region"].values.reshape(-1, 1)).toarray()
# dataset["State"] = onehot.fit_transform(dataset["State"].values.reshape(-1, 1)).toarray()
# dataset["Sub Category"] = onehot.fit_transform(dataset["Sub Category"].values.reshape(-1, 1)).toarray()

dataset["Order Date"] = encoder.fit_transform(dataset["Order Date"])
In [12]:
dataset[["Sales", "Discount", "profit_margin","Profit"]] = scaler.fit_transform(dataset[["Sales", "Discount", "profit_margin","Profit"]])
In [13]:
class_to_numeric = {'Low': 0, 'Medium': 1, 'High': 2}
dataset['Cluster'] = [class_to_numeric[label] for label in dataset['Cluster']]
In [14]:
dataset.head()
Out[14]:
Customer Name Category Sub Category City Order Date Region Sales Discount Profit State profit_margin Cluster
0 12 5 14 21 10 2 -0.414559 -1.430908 0.124389 0 0.595874 1
1 37 1 13 8 10 3 -1.291968 -0.627370 -0.941183 0 -0.416872 1
2 14 3 0 13 5 4 1.507054 -0.225601 -0.875930 0 -1.514014 0
3 15 4 12 4 9 3 -1.036563 0.310092 -1.196262 0 -1.260827 0
4 28 3 18 12 9 3 1.498367 0.444015 2.315743 0 1.186643 2

Splitting data train & test¶

In [15]:
X = dataset.drop(['Cluster','Sub Category','State','profit_margin'],axis=1)
y = dataset['Cluster']
In [16]:
X.head()
Out[16]:
Customer Name Category City Order Date Region Sales Discount Profit
0 12 5 21 10 2 -0.414559 -1.430908 0.124389
1 37 1 8 10 3 -1.291968 -0.627370 -0.941183
2 14 3 13 5 4 1.507054 -0.225601 -0.875930
3 15 4 4 9 3 -1.036563 0.310092 -1.196262
4 28 3 12 9 3 1.498367 0.444015 2.315743
In [17]:
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=1)
In [18]:
X_train.shape , X_test.shape
Out[18]:
((7960, 8), (1991, 8))
In [19]:
y_train.shape , y_test.shape
Out[19]:
((7960,), (1991,))

Decision Tree with gini index¶

In [20]:
from sklearn.tree import DecisionTreeClassifier
from sklearn import tree
from sklearn.model_selection import GridSearchCV, cross_val_score, train_test_split, RandomizedSearchCV
In [21]:
from sklearn.metrics import accuracy_score
In [22]:
clf_gini = DecisionTreeClassifier(criterion='gini', max_depth=3)

# clf_gini = GridSearchCV(estimator=df, param_grid=param_grid, cv=5, scoring='accuracy')
# clf_gini.fit(X_train, y_train)
plt.figure(figsize=(12,8))

tree.plot_tree(clf_gini.fit(X_train, y_train)) 
Out[22]:
[Text(0.5, 0.875, 'x[7] <= -0.542\ngini = 0.666\nsamples = 7960\nvalue = [2723, 2720, 2517]'),
 Text(0.25, 0.625, 'x[5] <= -0.792\ngini = 0.426\nsamples = 2968\nvalue = [2143, 662, 163]'),
 Text(0.125, 0.375, 'x[7] <= -1.087\ngini = 0.586\nsamples = 1528\nvalue = [748, 617, 163]'),
 Text(0.0625, 0.125, 'gini = 0.083\nsamples = 603\nvalue = [577, 26, 0]'),
 Text(0.1875, 0.125, 'gini = 0.527\nsamples = 925\nvalue = [171, 591, 163]'),
 Text(0.375, 0.375, 'x[7] <= -0.648\ngini = 0.061\nsamples = 1440\nvalue = [1395, 45, 0]'),
 Text(0.3125, 0.125, 'gini = 0.013\nsamples = 1238\nvalue = [1230, 8, 0]'),
 Text(0.4375, 0.125, 'gini = 0.299\nsamples = 202\nvalue = [165, 37, 0]'),
 Text(0.75, 0.625, 'x[7] <= 1.391\ngini = 0.594\nsamples = 4992\nvalue = [580, 2058, 2354]'),
 Text(0.625, 0.375, 'x[5] <= 0.377\ngini = 0.605\nsamples = 4081\nvalue = [580, 1987, 1514]'),
 Text(0.5625, 0.125, 'gini = 0.516\nsamples = 2392\nvalue = [74, 965, 1353]'),
 Text(0.6875, 0.125, 'gini = 0.535\nsamples = 1689\nvalue = [506, 1022, 161]'),
 Text(0.875, 0.375, 'x[5] <= 1.465\ngini = 0.144\nsamples = 911\nvalue = [0, 71, 840]'),
 Text(0.8125, 0.125, 'gini = 0.037\nsamples = 699\nvalue = [0, 13, 686]'),
 Text(0.9375, 0.125, 'gini = 0.397\nsamples = 212\nvalue = [0, 58, 154]')]
In [23]:
clf = DecisionTreeClassifier(random_state=42)
params = { 
    "max_depth": range(5,41,5), 
    "criterion": ["gini", "entropy"]
}
In [24]:
model_dt =  GridSearchCV(estimator=clf, #
    param_grid=params,
    cv=2, 
    n_jobs=-1,
    verbose=1
)
In [25]:
model_dt.fit(X_train, y_train) # model fit
Fitting 2 folds for each of 16 candidates, totalling 32 fits
Out[25]:
GridSearchCV(cv=2, estimator=DecisionTreeClassifier(random_state=42), n_jobs=-1,
             param_grid={'criterion': ['gini', 'entropy'],
                         'max_depth': range(5, 41, 5)},
             verbose=1)
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
GridSearchCV(cv=2, estimator=DecisionTreeClassifier(random_state=42), n_jobs=-1,
             param_grid={'criterion': ['gini', 'entropy'],
                         'max_depth': range(5, 41, 5)},
             verbose=1)
DecisionTreeClassifier(random_state=42)
DecisionTreeClassifier(random_state=42)
In [26]:
cv_results_dt = pd.DataFrame(model_dt.cv_results_)
cv_results_dt.sort_values("rank_test_score").head(1)
Out[26]:
mean_fit_time std_fit_time mean_score_time std_score_time param_criterion param_max_depth params split0_test_score split1_test_score mean_test_score std_test_score rank_test_score
9 0.029328 0.003273 0.005254 0.002253 entropy 10 {'criterion': 'entropy', 'max_depth': 10} 0.972613 0.973869 0.973241 0.000628 1
In [27]:
best_params = model_dt.best_params_
best_maxDepth = best_params['max_depth']
best_Criterion = best_params['criterion']

print(f'The best Max Depth : {best_maxDepth}')
print(f'The best Criterion : {best_Criterion}')
The best Max Depth : 10
The best Criterion : entropy
In [28]:
#Using the Decision Tree Classifier with splitting criterion as Gini impurity, the maximum depth of the tree is 3.
clf_gini = DecisionTreeClassifier(criterion=best_Criterion, max_depth=best_maxDepth, random_state=0)


# fit the model
clf_gini.fit(X_train, y_train)
#Plot the tree
plt.figure(figsize=(12,8))

tree.plot_tree(clf_gini.fit(X_train, y_train)) 
Out[28]:
[Text(0.5950120192307692, 0.9545454545454546, 'x[7] <= -0.12\nentropy = 1.584\nsamples = 7960\nvalue = [2723, 2720, 2517]'),
 Text(0.3144230769230769, 0.8636363636363636, 'x[5] <= -0.792\nentropy = 1.31\nsamples = 4255\nvalue = [2579, 1189, 487]'),
 Text(0.17307692307692307, 0.7727272727272727, 'x[7] <= -0.857\nentropy = 1.56\nsamples = 1968\nvalue = [748, 733, 487]'),
 Text(0.05480769230769231, 0.6818181818181818, 'x[7] <= -1.113\nentropy = 0.848\nsamples = 1012\nvalue = [734, 278, 0]'),
 Text(0.023076923076923078, 0.5909090909090909, 'x[7] <= -1.174\nentropy = 0.162\nsamples = 548\nvalue = [535, 13, 0]'),
 Text(0.015384615384615385, 0.5, 'entropy = 0.0\nsamples = 455\nvalue = [455, 0, 0]'),
 Text(0.03076923076923077, 0.5, 'x[5] <= -1.627\nentropy = 0.584\nsamples = 93\nvalue = [80, 13, 0]'),
 Text(0.023076923076923078, 0.4090909090909091, 'x[7] <= -1.163\nentropy = 0.567\nsamples = 15\nvalue = [2, 13, 0]'),
 Text(0.015384615384615385, 0.3181818181818182, 'x[5] <= -1.691\nentropy = 1.0\nsamples = 4\nvalue = [2, 2, 0]'),
 Text(0.007692307692307693, 0.22727272727272727, 'entropy = 0.0\nsamples = 2\nvalue = [0, 2, 0]'),
 Text(0.023076923076923078, 0.22727272727272727, 'entropy = 0.0\nsamples = 2\nvalue = [2, 0, 0]'),
 Text(0.03076923076923077, 0.3181818181818182, 'entropy = 0.0\nsamples = 11\nvalue = [0, 11, 0]'),
 Text(0.038461538461538464, 0.4090909090909091, 'entropy = 0.0\nsamples = 78\nvalue = [78, 0, 0]'),
 Text(0.08653846153846154, 0.5909090909090909, 'x[5] <= -1.388\nentropy = 0.985\nsamples = 464\nvalue = [199, 265, 0]'),
 Text(0.06153846153846154, 0.5, 'x[7] <= -1.049\nentropy = 0.319\nsamples = 207\nvalue = [12, 195, 0]'),
 Text(0.05384615384615385, 0.4090909090909091, 'x[5] <= -1.527\nentropy = 0.772\nsamples = 53\nvalue = [12, 41, 0]'),
 Text(0.046153846153846156, 0.3181818181818182, 'entropy = 0.0\nsamples = 33\nvalue = [0, 33, 0]'),
 Text(0.06153846153846154, 0.3181818181818182, 'x[5] <= -1.458\nentropy = 0.971\nsamples = 20\nvalue = [12, 8, 0]'),
 Text(0.05384615384615385, 0.22727272727272727, 'x[7] <= -1.089\nentropy = 0.918\nsamples = 12\nvalue = [4, 8, 0]'),
 Text(0.046153846153846156, 0.13636363636363635, 'entropy = 0.0\nsamples = 4\nvalue = [4, 0, 0]'),
 Text(0.06153846153846154, 0.13636363636363635, 'entropy = 0.0\nsamples = 8\nvalue = [0, 8, 0]'),
 Text(0.06923076923076923, 0.22727272727272727, 'entropy = 0.0\nsamples = 8\nvalue = [8, 0, 0]'),
 Text(0.06923076923076923, 0.4090909090909091, 'entropy = 0.0\nsamples = 154\nvalue = [0, 154, 0]'),
 Text(0.11153846153846154, 0.5, 'x[7] <= -0.989\nentropy = 0.845\nsamples = 257\nvalue = [187, 70, 0]'),
 Text(0.08461538461538462, 0.4090909090909091, 'x[0] <= 46.5\nentropy = 0.074\nsamples = 111\nvalue = [110, 1, 0]'),
 Text(0.07692307692307693, 0.3181818181818182, 'entropy = 0.0\nsamples = 106\nvalue = [106, 0, 0]'),
 Text(0.09230769230769231, 0.3181818181818182, 'x[1] <= 0.5\nentropy = 0.722\nsamples = 5\nvalue = [4, 1, 0]'),
 Text(0.08461538461538462, 0.22727272727272727, 'entropy = 0.0\nsamples = 1\nvalue = [0, 1, 0]'),
 Text(0.1, 0.22727272727272727, 'entropy = 0.0\nsamples = 4\nvalue = [4, 0, 0]'),
 Text(0.13846153846153847, 0.4090909090909091, 'x[5] <= -1.086\nentropy = 0.998\nsamples = 146\nvalue = [77, 69, 0]'),
 Text(0.12307692307692308, 0.3181818181818182, 'x[7] <= -0.935\nentropy = 0.529\nsamples = 75\nvalue = [9, 66, 0]'),
 Text(0.11538461538461539, 0.22727272727272727, 'x[5] <= -1.255\nentropy = 0.966\nsamples = 23\nvalue = [9, 14, 0]'),
 Text(0.1076923076923077, 0.13636363636363635, 'entropy = 0.0\nsamples = 13\nvalue = [0, 13, 0]'),
 Text(0.12307692307692308, 0.13636363636363635, 'x[7] <= -0.941\nentropy = 0.469\nsamples = 10\nvalue = [9, 1, 0]'),
 Text(0.11538461538461539, 0.045454545454545456, 'entropy = 0.0\nsamples = 8\nvalue = [8, 0, 0]'),
 Text(0.13076923076923078, 0.045454545454545456, 'entropy = 1.0\nsamples = 2\nvalue = [1, 1, 0]'),
 Text(0.13076923076923078, 0.22727272727272727, 'entropy = 0.0\nsamples = 52\nvalue = [0, 52, 0]'),
 Text(0.15384615384615385, 0.3181818181818182, 'x[7] <= -0.868\nentropy = 0.253\nsamples = 71\nvalue = [68, 3, 0]'),
 Text(0.14615384615384616, 0.22727272727272727, 'entropy = 0.0\nsamples = 62\nvalue = [62, 0, 0]'),
 Text(0.16153846153846155, 0.22727272727272727, 'x[5] <= -1.016\nentropy = 0.918\nsamples = 9\nvalue = [6, 3, 0]'),
 Text(0.15384615384615385, 0.13636363636363635, 'entropy = 0.0\nsamples = 3\nvalue = [0, 3, 0]'),
 Text(0.16923076923076924, 0.13636363636363635, 'entropy = 0.0\nsamples = 6\nvalue = [6, 0, 0]'),
 Text(0.29134615384615387, 0.6818181818181818, 'x[7] <= -0.694\nentropy = 1.095\nsamples = 956\nvalue = [14, 455, 487]'),
 Text(0.2423076923076923, 0.5909090909090909, 'x[5] <= -1.53\nentropy = 0.954\nsamples = 274\nvalue = [14, 211, 49]'),
 Text(0.2230769230769231, 0.5, 'x[7] <= -0.773\nentropy = 0.893\nsamples = 71\nvalue = [0, 22, 49]'),
 Text(0.2076923076923077, 0.4090909090909091, 'x[5] <= -1.633\nentropy = 0.98\nsamples = 36\nvalue = [0, 21, 15]'),
 Text(0.2, 0.3181818181818182, 'x[7] <= -0.837\nentropy = 0.742\nsamples = 19\nvalue = [0, 4, 15]'),
 Text(0.19230769230769232, 0.22727272727272727, 'x[5] <= -1.678\nentropy = 0.985\nsamples = 7\nvalue = [0, 4, 3]'),
 Text(0.18461538461538463, 0.13636363636363635, 'entropy = 0.0\nsamples = 3\nvalue = [0, 0, 3]'),
 Text(0.2, 0.13636363636363635, 'entropy = 0.0\nsamples = 4\nvalue = [0, 4, 0]'),
 Text(0.2076923076923077, 0.22727272727272727, 'entropy = 0.0\nsamples = 12\nvalue = [0, 0, 12]'),
 Text(0.2153846153846154, 0.3181818181818182, 'entropy = 0.0\nsamples = 17\nvalue = [0, 17, 0]'),
 Text(0.23846153846153847, 0.4090909090909091, 'x[0] <= 2.5\nentropy = 0.187\nsamples = 35\nvalue = [0, 1, 34]'),
 Text(0.23076923076923078, 0.3181818181818182, 'x[7] <= -0.737\nentropy = 1.0\nsamples = 2\nvalue = [0, 1, 1]'),
 Text(0.2230769230769231, 0.22727272727272727, 'entropy = 0.0\nsamples = 1\nvalue = [0, 1, 0]'),
 Text(0.23846153846153847, 0.22727272727272727, 'entropy = 0.0\nsamples = 1\nvalue = [0, 0, 1]'),
 Text(0.24615384615384617, 0.3181818181818182, 'entropy = 0.0\nsamples = 33\nvalue = [0, 0, 33]'),
 Text(0.26153846153846155, 0.5, 'x[5] <= -0.946\nentropy = 0.362\nsamples = 203\nvalue = [14, 189, 0]'),
 Text(0.25384615384615383, 0.4090909090909091, 'entropy = 0.0\nsamples = 162\nvalue = [0, 162, 0]'),
 Text(0.2692307692307692, 0.4090909090909091, 'x[7] <= -0.787\nentropy = 0.926\nsamples = 41\nvalue = [14, 27, 0]'),
 Text(0.26153846153846155, 0.3181818181818182, 'x[5] <= -0.912\nentropy = 0.672\nsamples = 17\nvalue = [14, 3, 0]'),
 Text(0.25384615384615383, 0.22727272727272727, 'x[1] <= 4.5\nentropy = 0.811\nsamples = 4\nvalue = [1, 3, 0]'),
 Text(0.24615384615384617, 0.13636363636363635, 'entropy = 0.0\nsamples = 3\nvalue = [0, 3, 0]'),
 Text(0.26153846153846155, 0.13636363636363635, 'entropy = 0.0\nsamples = 1\nvalue = [1, 0, 0]'),
 Text(0.2692307692307692, 0.22727272727272727, 'entropy = 0.0\nsamples = 13\nvalue = [13, 0, 0]'),
 Text(0.27692307692307694, 0.3181818181818182, 'entropy = 0.0\nsamples = 24\nvalue = [0, 24, 0]'),
 Text(0.3403846153846154, 0.5909090909090909, 'x[5] <= -1.199\nentropy = 0.941\nsamples = 682\nvalue = [0, 244, 438]'),
 Text(0.3153846153846154, 0.5, 'x[7] <= -0.514\nentropy = 0.512\nsamples = 377\nvalue = [0, 43, 334]'),
 Text(0.3076923076923077, 0.4090909090909091, 'x[5] <= -1.435\nentropy = 0.778\nsamples = 187\nvalue = [0, 43, 144]'),
 Text(0.2923076923076923, 0.3181818181818182, 'x[7] <= -0.686\nentropy = 0.067\nsamples = 126\nvalue = [0, 1, 125]'),
 Text(0.2846153846153846, 0.22727272727272727, 'x[7] <= -0.687\nentropy = 0.469\nsamples = 10\nvalue = [0, 1, 9]'),
 Text(0.27692307692307694, 0.13636363636363635, 'entropy = 0.0\nsamples = 9\nvalue = [0, 0, 9]'),
 Text(0.2923076923076923, 0.13636363636363635, 'entropy = 0.0\nsamples = 1\nvalue = [0, 1, 0]'),
 Text(0.3, 0.22727272727272727, 'entropy = 0.0\nsamples = 116\nvalue = [0, 0, 116]'),
 Text(0.3230769230769231, 0.3181818181818182, 'x[5] <= -1.301\nentropy = 0.895\nsamples = 61\nvalue = [0, 42, 19]'),
 Text(0.3153846153846154, 0.22727272727272727, 'x[7] <= -0.585\nentropy = 0.983\nsamples = 33\nvalue = [0, 14, 19]'),
 Text(0.3076923076923077, 0.13636363636363635, 'x[0] <= 37.5\nentropy = 0.544\nsamples = 16\nvalue = [0, 14, 2]'),
 Text(0.3, 0.045454545454545456, 'entropy = 0.0\nsamples = 11\nvalue = [0, 11, 0]'),
 Text(0.3153846153846154, 0.045454545454545456, 'entropy = 0.971\nsamples = 5\nvalue = [0, 3, 2]'),
 Text(0.3230769230769231, 0.13636363636363635, 'entropy = 0.0\nsamples = 17\nvalue = [0, 0, 17]'),
 Text(0.33076923076923076, 0.22727272727272727, 'entropy = 0.0\nsamples = 28\nvalue = [0, 28, 0]'),
 Text(0.3230769230769231, 0.4090909090909091, 'entropy = 0.0\nsamples = 190\nvalue = [0, 0, 190]'),
 Text(0.36538461538461536, 0.5, 'x[7] <= -0.357\nentropy = 0.926\nsamples = 305\nvalue = [0, 201, 104]'),
 Text(0.34615384615384615, 0.4090909090909091, 'x[7] <= -0.427\nentropy = 0.186\nsamples = 176\nvalue = [0, 171, 5]'),
 Text(0.3384615384615385, 0.3181818181818182, 'entropy = 0.0\nsamples = 142\nvalue = [0, 142, 0]'),
 Text(0.35384615384615387, 0.3181818181818182, 'x[5] <= -1.119\nentropy = 0.602\nsamples = 34\nvalue = [0, 29, 5]'),
 Text(0.34615384615384615, 0.22727272727272727, 'entropy = 0.0\nsamples = 5\nvalue = [0, 0, 5]'),
 Text(0.36153846153846153, 0.22727272727272727, 'entropy = 0.0\nsamples = 29\nvalue = [0, 29, 0]'),
 Text(0.38461538461538464, 0.4090909090909091, 'x[5] <= -0.977\nentropy = 0.782\nsamples = 129\nvalue = [0, 30, 99]'),
 Text(0.3769230769230769, 0.3181818181818182, 'entropy = 0.0\nsamples = 70\nvalue = [0, 0, 70]'),
 Text(0.3923076923076923, 0.3181818181818182, 'x[7] <= -0.234\nentropy = 1.0\nsamples = 59\nvalue = [0, 30, 29]'),
 Text(0.3769230769230769, 0.22727272727272727, 'x[6] <= -1.632\nentropy = 0.235\nsamples = 26\nvalue = [0, 25, 1]'),
 Text(0.36923076923076925, 0.13636363636363635, 'entropy = 0.0\nsamples = 1\nvalue = [0, 0, 1]'),
 Text(0.38461538461538464, 0.13636363636363635, 'entropy = 0.0\nsamples = 25\nvalue = [0, 25, 0]'),
 Text(0.4076923076923077, 0.22727272727272727, 'x[7] <= -0.176\nentropy = 0.614\nsamples = 33\nvalue = [0, 5, 28]'),
 Text(0.4, 0.13636363636363635, 'x[5] <= -0.875\nentropy = 0.918\nsamples = 15\nvalue = [0, 5, 10]'),
 Text(0.3923076923076923, 0.045454545454545456, 'entropy = 0.0\nsamples = 10\nvalue = [0, 0, 10]'),
 Text(0.4076923076923077, 0.045454545454545456, 'entropy = 0.0\nsamples = 5\nvalue = [0, 5, 0]'),
 Text(0.4153846153846154, 0.13636363636363635, 'entropy = 0.0\nsamples = 18\nvalue = [0, 0, 18]'),
 Text(0.45576923076923076, 0.7727272727272727, 'x[7] <= -0.648\nentropy = 0.721\nsamples = 2287\nvalue = [1831, 456, 0]'),
 Text(0.4307692307692308, 0.6818181818181818, 'x[5] <= -0.669\nentropy = 0.056\nsamples = 1238\nvalue = [1230, 8, 0]'),
 Text(0.4230769230769231, 0.5909090909090909, 'x[7] <= -0.725\nentropy = 0.389\nsamples = 105\nvalue = [97, 8, 0]'),
 Text(0.4153846153846154, 0.5, 'entropy = 0.0\nsamples = 97\nvalue = [97, 0, 0]'),
 Text(0.4307692307692308, 0.5, 'entropy = 0.0\nsamples = 8\nvalue = [0, 8, 0]'),
 Text(0.43846153846153846, 0.5909090909090909, 'entropy = 0.0\nsamples = 1133\nvalue = [1133, 0, 0]'),
 Text(0.4807692307692308, 0.6818181818181818, 'x[5] <= 0.396\nentropy = 0.985\nsamples = 1049\nvalue = [601, 448, 0]'),
 Text(0.46153846153846156, 0.5909090909090909, 'x[7] <= -0.44\nentropy = 0.841\nsamples = 601\nvalue = [162, 439, 0]'),
 Text(0.4461538461538462, 0.5, 'x[5] <= -0.163\nentropy = 0.987\nsamples = 235\nvalue = [133, 102, 0]'),
 Text(0.43846153846153846, 0.4090909090909091, 'x[5] <= -0.441\nentropy = 0.793\nsamples = 134\nvalue = [32, 102, 0]'),
 Text(0.4307692307692308, 0.3181818181818182, 'entropy = 0.0\nsamples = 71\nvalue = [0, 71, 0]'),
 Text(0.4461538461538462, 0.3181818181818182, 'x[7] <= -0.519\nentropy = 1.0\nsamples = 63\nvalue = [32, 31, 0]'),
 Text(0.43846153846153846, 0.22727272727272727, 'x[5] <= -0.35\nentropy = 0.571\nsamples = 37\nvalue = [32, 5, 0]'),
 Text(0.4307692307692308, 0.13636363636363635, 'x[7] <= -0.586\nentropy = 0.98\nsamples = 12\nvalue = [7, 5, 0]'),
 Text(0.4230769230769231, 0.045454545454545456, 'entropy = 0.0\nsamples = 7\nvalue = [7, 0, 0]'),
 Text(0.43846153846153846, 0.045454545454545456, 'entropy = 0.0\nsamples = 5\nvalue = [0, 5, 0]'),
 Text(0.4461538461538462, 0.13636363636363635, 'entropy = 0.0\nsamples = 25\nvalue = [25, 0, 0]'),
 Text(0.45384615384615384, 0.22727272727272727, 'entropy = 0.0\nsamples = 26\nvalue = [0, 26, 0]'),
 Text(0.45384615384615384, 0.4090909090909091, 'entropy = 0.0\nsamples = 101\nvalue = [101, 0, 0]'),
 Text(0.47692307692307695, 0.5, 'x[5] <= 0.001\nentropy = 0.399\nsamples = 366\nvalue = [29, 337, 0]'),
 Text(0.46923076923076923, 0.4090909090909091, 'entropy = 0.0\nsamples = 261\nvalue = [0, 261, 0]'),
 Text(0.4846153846153846, 0.4090909090909091, 'x[7] <= -0.355\nentropy = 0.85\nsamples = 105\nvalue = [29, 76, 0]'),
 Text(0.47692307692307695, 0.3181818181818182, 'entropy = 0.0\nsamples = 19\nvalue = [19, 0, 0]'),
 Text(0.49230769230769234, 0.3181818181818182, 'x[7] <= -0.265\nentropy = 0.519\nsamples = 86\nvalue = [10, 76, 0]'),
 Text(0.4846153846153846, 0.22727272727272727, 'x[5] <= 0.234\nentropy = 0.885\nsamples = 33\nvalue = [10, 23, 0]'),
 Text(0.47692307692307695, 0.13636363636363635, 'x[7] <= -0.343\nentropy = 0.25\nsamples = 24\nvalue = [1, 23, 0]'),
 Text(0.46923076923076923, 0.045454545454545456, 'entropy = 0.918\nsamples = 3\nvalue = [1, 2, 0]'),
 Text(0.4846153846153846, 0.045454545454545456, 'entropy = 0.0\nsamples = 21\nvalue = [0, 21, 0]'),
 Text(0.49230769230769234, 0.13636363636363635, 'entropy = 0.0\nsamples = 9\nvalue = [9, 0, 0]'),
 Text(0.5, 0.22727272727272727, 'entropy = 0.0\nsamples = 53\nvalue = [0, 53, 0]'),
 Text(0.5, 0.5909090909090909, 'x[7] <= -0.179\nentropy = 0.142\nsamples = 448\nvalue = [439, 9, 0]'),
 Text(0.49230769230769234, 0.5, 'entropy = 0.0\nsamples = 397\nvalue = [397, 0, 0]'),
 Text(0.5076923076923077, 0.5, 'x[5] <= 0.575\nentropy = 0.672\nsamples = 51\nvalue = [42, 9, 0]'),
 Text(0.5, 0.4090909090909091, 'entropy = 0.0\nsamples = 9\nvalue = [0, 9, 0]'),
 Text(0.5153846153846153, 0.4090909090909091, 'entropy = 0.0\nsamples = 42\nvalue = [42, 0, 0]'),
 Text(0.8756009615384616, 0.8636363636363636, 'x[7] <= 1.454\nentropy = 1.185\nsamples = 3705\nvalue = [144, 1531, 2030]'),
 Text(0.7819711538461539, 0.7727272727272727, 'x[5] <= 0.719\nentropy = 1.232\nsamples = 2873\nvalue = [144, 1483, 1246]'),
 Text(0.6793269230769231, 0.6818181818181818, 'x[7] <= 0.633\nentropy = 0.951\nsamples = 1885\nvalue = [0, 699, 1186]'),
 Text(0.6048076923076923, 0.5909090909090909, 'x[5] <= -0.24\nentropy = 1.0\nsamples = 1248\nvalue = [0, 639, 609]'),
 Text(0.5557692307692308, 0.5, 'x[5] <= -0.608\nentropy = 0.642\nsamples = 644\nvalue = [0, 105, 539]'),
 Text(0.5307692307692308, 0.4090909090909091, 'x[7] <= -0.066\nentropy = 0.094\nsamples = 331\nvalue = [0, 4, 327]'),
 Text(0.5230769230769231, 0.3181818181818182, 'x[5] <= -0.734\nentropy = 0.469\nsamples = 40\nvalue = [0, 4, 36]'),
 Text(0.5153846153846153, 0.22727272727272727, 'entropy = 0.0\nsamples = 36\nvalue = [0, 0, 36]'),
 Text(0.5307692307692308, 0.22727272727272727, 'entropy = 0.0\nsamples = 4\nvalue = [0, 4, 0]'),
 Text(0.5384615384615384, 0.3181818181818182, 'entropy = 0.0\nsamples = 291\nvalue = [0, 0, 291]'),
 Text(0.5807692307692308, 0.4090909090909091, 'x[7] <= 0.159\nentropy = 0.907\nsamples = 313\nvalue = [0, 101, 212]'),
 Text(0.5615384615384615, 0.3181818181818182, 'x[5] <= -0.51\nentropy = 0.639\nsamples = 111\nvalue = [0, 93, 18]'),
 Text(0.5461538461538461, 0.22727272727272727, 'x[7] <= 0.04\nentropy = 0.999\nsamples = 33\nvalue = [0, 17, 16]'),
 Text(0.5384615384615384, 0.13636363636363635, 'entropy = 0.0\nsamples = 16\nvalue = [0, 16, 0]'),
 Text(0.5538461538461539, 0.13636363636363635, 'x[7] <= 0.046\nentropy = 0.323\nsamples = 17\nvalue = [0, 1, 16]'),
 Text(0.5461538461538461, 0.045454545454545456, 'entropy = 1.0\nsamples = 2\nvalue = [0, 1, 1]'),
 Text(0.5615384615384615, 0.045454545454545456, 'entropy = 0.0\nsamples = 15\nvalue = [0, 0, 15]'),
 Text(0.5769230769230769, 0.22727272727272727, 'x[7] <= 0.135\nentropy = 0.172\nsamples = 78\nvalue = [0, 76, 2]'),
 Text(0.5692307692307692, 0.13636363636363635, 'entropy = 0.0\nsamples = 68\nvalue = [0, 68, 0]'),
 Text(0.5846153846153846, 0.13636363636363635, 'x[5] <= -0.422\nentropy = 0.722\nsamples = 10\nvalue = [0, 8, 2]'),
 Text(0.5769230769230769, 0.045454545454545456, 'entropy = 0.0\nsamples = 2\nvalue = [0, 0, 2]'),
 Text(0.5923076923076923, 0.045454545454545456, 'entropy = 0.0\nsamples = 8\nvalue = [0, 8, 0]'),
 Text(0.6, 0.3181818181818182, 'x[5] <= -0.308\nentropy = 0.24\nsamples = 202\nvalue = [0, 8, 194]'),
 Text(0.5923076923076923, 0.22727272727272727, 'entropy = 0.0\nsamples = 167\nvalue = [0, 0, 167]'),
 Text(0.6076923076923076, 0.22727272727272727, 'x[7] <= 0.261\nentropy = 0.776\nsamples = 35\nvalue = [0, 8, 27]'),
 Text(0.6, 0.13636363636363635, 'entropy = 0.0\nsamples = 8\nvalue = [0, 8, 0]'),
 Text(0.6153846153846154, 0.13636363636363635, 'entropy = 0.0\nsamples = 27\nvalue = [0, 0, 27]'),
 Text(0.6538461538461539, 0.5, 'x[7] <= 0.37\nentropy = 0.517\nsamples = 604\nvalue = [0, 534, 70]'),
 Text(0.6230769230769231, 0.4090909090909091, 'x[7] <= 0.351\nentropy = 0.026\nsamples = 378\nvalue = [0, 377, 1]'),
 Text(0.6153846153846154, 0.3181818181818182, 'entropy = 0.0\nsamples = 359\nvalue = [0, 359, 0]'),
 Text(0.6307692307692307, 0.3181818181818182, 'x[7] <= 0.353\nentropy = 0.297\nsamples = 19\nvalue = [0, 18, 1]'),
 Text(0.6230769230769231, 0.22727272727272727, 'entropy = 0.0\nsamples = 1\nvalue = [0, 0, 1]'),
 Text(0.6384615384615384, 0.22727272727272727, 'entropy = 0.0\nsamples = 18\nvalue = [0, 18, 0]'),
 Text(0.6846153846153846, 0.4090909090909091, 'x[5] <= -0.005\nentropy = 0.888\nsamples = 226\nvalue = [0, 157, 69]'),
 Text(0.6615384615384615, 0.3181818181818182, 'x[5] <= -0.025\nentropy = 0.124\nsamples = 59\nvalue = [0, 1, 58]'),
 Text(0.6538461538461539, 0.22727272727272727, 'entropy = 0.0\nsamples = 55\nvalue = [0, 0, 55]'),
 Text(0.6692307692307692, 0.22727272727272727, 'x[2] <= 16.0\nentropy = 0.811\nsamples = 4\nvalue = [0, 1, 3]'),
 Text(0.6615384615384615, 0.13636363636363635, 'entropy = 0.0\nsamples = 3\nvalue = [0, 0, 3]'),
 Text(0.676923076923077, 0.13636363636363635, 'entropy = 0.0\nsamples = 1\nvalue = [0, 1, 0]'),
 Text(0.7076923076923077, 0.3181818181818182, 'x[5] <= 0.156\nentropy = 0.35\nsamples = 167\nvalue = [0, 156, 11]'),
 Text(0.7, 0.22727272727272727, 'x[7] <= 0.562\nentropy = 0.858\nsamples = 39\nvalue = [0, 28, 11]'),
 Text(0.6923076923076923, 0.13636363636363635, 'entropy = 0.0\nsamples = 28\nvalue = [0, 28, 0]'),
 Text(0.7076923076923077, 0.13636363636363635, 'entropy = 0.0\nsamples = 11\nvalue = [0, 0, 11]'),
 Text(0.7153846153846154, 0.22727272727272727, 'entropy = 0.0\nsamples = 128\nvalue = [0, 128, 0]'),
 Text(0.7538461538461538, 0.5909090909090909, 'x[5] <= 0.324\nentropy = 0.45\nsamples = 637\nvalue = [0, 60, 577]'),
 Text(0.7384615384615385, 0.5, 'x[7] <= 0.651\nentropy = 0.045\nsamples = 409\nvalue = [0, 2, 407]'),
 Text(0.7307692307692307, 0.4090909090909091, 'x[5] <= 0.215\nentropy = 0.503\nsamples = 18\nvalue = [0, 2, 16]'),
 Text(0.7230769230769231, 0.3181818181818182, 'entropy = 0.0\nsamples = 16\nvalue = [0, 0, 16]'),
 Text(0.7384615384615385, 0.3181818181818182, 'entropy = 0.0\nsamples = 2\nvalue = [0, 2, 0]'),
 Text(0.7461538461538462, 0.4090909090909091, 'entropy = 0.0\nsamples = 391\nvalue = [0, 0, 391]'),
 Text(0.7692307692307693, 0.5, 'x[7] <= 0.985\nentropy = 0.818\nsamples = 228\nvalue = [0, 58, 170]'),
 Text(0.7615384615384615, 0.4090909090909091, 'x[5] <= 0.571\nentropy = 0.849\nsamples = 80\nvalue = [0, 58, 22]'),
 Text(0.7538461538461538, 0.3181818181818182, 'x[7] <= 0.781\nentropy = 1.0\nsamples = 44\nvalue = [0, 22, 22]'),
 Text(0.7461538461538462, 0.22727272727272727, 'entropy = 0.0\nsamples = 18\nvalue = [0, 18, 0]'),
 Text(0.7615384615384615, 0.22727272727272727, 'x[5] <= 0.453\nentropy = 0.619\nsamples = 26\nvalue = [0, 4, 22]'),
 Text(0.7538461538461538, 0.13636363636363635, 'entropy = 0.0\nsamples = 15\nvalue = [0, 0, 15]'),
 Text(0.7692307692307693, 0.13636363636363635, 'x[7] <= 0.9\nentropy = 0.946\nsamples = 11\nvalue = [0, 4, 7]'),
 Text(0.7615384615384615, 0.045454545454545456, 'entropy = 0.722\nsamples = 5\nvalue = [0, 4, 1]'),
 Text(0.7769230769230769, 0.045454545454545456, 'entropy = 0.0\nsamples = 6\nvalue = [0, 0, 6]'),
 Text(0.7692307692307693, 0.3181818181818182, 'entropy = 0.0\nsamples = 36\nvalue = [0, 36, 0]'),
 Text(0.7769230769230769, 0.4090909090909091, 'entropy = 0.0\nsamples = 148\nvalue = [0, 0, 148]'),
 Text(0.8846153846153846, 0.6818181818181818, 'x[7] <= 0.317\nentropy = 0.915\nsamples = 988\nvalue = [144, 784, 60]'),
 Text(0.8538461538461538, 0.5909090909090909, 'x[5] <= 1.376\nentropy = 0.99\nsamples = 258\nvalue = [144, 114, 0]'),
 Text(0.8384615384615385, 0.5, 'x[7] <= 0.147\nentropy = 0.901\nsamples = 161\nvalue = [51, 110, 0]'),
 Text(0.8307692307692308, 0.4090909090909091, 'x[5] <= 1.102\nentropy = 0.991\nsamples = 92\nvalue = [51, 41, 0]'),
 Text(0.823076923076923, 0.3181818181818182, 'x[7] <= 0.038\nentropy = 0.924\nsamples = 62\nvalue = [21, 41, 0]'),
 Text(0.8153846153846154, 0.22727272727272727, 'x[5] <= 0.853\nentropy = 0.928\nsamples = 32\nvalue = [21, 11, 0]'),
 Text(0.8, 0.13636363636363635, 'x[7] <= -0.07\nentropy = 0.65\nsamples = 12\nvalue = [2, 10, 0]'),
 Text(0.7923076923076923, 0.045454545454545456, 'entropy = 0.0\nsamples = 2\nvalue = [2, 0, 0]'),
 Text(0.8076923076923077, 0.045454545454545456, 'entropy = 0.0\nsamples = 10\nvalue = [0, 10, 0]'),
 Text(0.8307692307692308, 0.13636363636363635, 'x[7] <= 0.025\nentropy = 0.286\nsamples = 20\nvalue = [19, 1, 0]'),
 Text(0.823076923076923, 0.045454545454545456, 'entropy = 0.0\nsamples = 18\nvalue = [18, 0, 0]'),
 Text(0.8384615384615385, 0.045454545454545456, 'entropy = 1.0\nsamples = 2\nvalue = [1, 1, 0]'),
 Text(0.8307692307692308, 0.22727272727272727, 'entropy = 0.0\nsamples = 30\nvalue = [0, 30, 0]'),
 Text(0.8384615384615385, 0.3181818181818182, 'entropy = 0.0\nsamples = 30\nvalue = [30, 0, 0]'),
 Text(0.8461538461538461, 0.4090909090909091, 'entropy = 0.0\nsamples = 69\nvalue = [0, 69, 0]'),
 Text(0.8692307692307693, 0.5, 'x[7] <= 0.275\nentropy = 0.248\nsamples = 97\nvalue = [93, 4, 0]'),
 Text(0.8615384615384616, 0.4090909090909091, 'entropy = 0.0\nsamples = 87\nvalue = [87, 0, 0]'),
 Text(0.8769230769230769, 0.4090909090909091, 'x[5] <= 1.558\nentropy = 0.971\nsamples = 10\nvalue = [6, 4, 0]'),
 Text(0.8692307692307693, 0.3181818181818182, 'entropy = 0.0\nsamples = 4\nvalue = [0, 4, 0]'),
 Text(0.8846153846153846, 0.3181818181818182, 'entropy = 0.0\nsamples = 6\nvalue = [6, 0, 0]'),
 Text(0.9153846153846154, 0.5909090909090909, 'x[7] <= 1.107\nentropy = 0.41\nsamples = 730\nvalue = [0, 670, 60]'),
 Text(0.9076923076923077, 0.5, 'entropy = 0.0\nsamples = 504\nvalue = [0, 504, 0]'),
 Text(0.9230769230769231, 0.5, 'x[5] <= 0.949\nentropy = 0.835\nsamples = 226\nvalue = [0, 166, 60]'),
 Text(0.9076923076923077, 0.4090909090909091, 'x[7] <= 1.162\nentropy = 0.149\nsamples = 47\nvalue = [0, 1, 46]'),
 Text(0.9, 0.3181818181818182, 'x[5] <= 0.85\nentropy = 0.722\nsamples = 5\nvalue = [0, 1, 4]'),
 Text(0.8923076923076924, 0.22727272727272727, 'entropy = 0.0\nsamples = 4\nvalue = [0, 0, 4]'),
 Text(0.9076923076923077, 0.22727272727272727, 'entropy = 0.0\nsamples = 1\nvalue = [0, 1, 0]'),
 Text(0.9153846153846154, 0.3181818181818182, 'entropy = 0.0\nsamples = 42\nvalue = [0, 0, 42]'),
 Text(0.9384615384615385, 0.4090909090909091, 'x[5] <= 1.172\nentropy = 0.396\nsamples = 179\nvalue = [0, 165, 14]'),
 Text(0.9307692307692308, 0.3181818181818182, 'x[7] <= 1.307\nentropy = 0.902\nsamples = 44\nvalue = [0, 30, 14]'),
 Text(0.9230769230769231, 0.22727272727272727, 'x[6] <= -1.364\nentropy = 0.206\nsamples = 31\nvalue = [0, 30, 1]'),
 Text(0.9153846153846154, 0.13636363636363635, 'x[4] <= 2.5\nentropy = 0.918\nsamples = 3\nvalue = [0, 2, 1]'),
 Text(0.9076923076923077, 0.045454545454545456, 'entropy = 0.0\nsamples = 2\nvalue = [0, 2, 0]'),
 Text(0.9230769230769231, 0.045454545454545456, 'entropy = 0.0\nsamples = 1\nvalue = [0, 0, 1]'),
 Text(0.9307692307692308, 0.13636363636363635, 'entropy = 0.0\nsamples = 28\nvalue = [0, 28, 0]'),
 Text(0.9384615384615385, 0.22727272727272727, 'entropy = 0.0\nsamples = 13\nvalue = [0, 0, 13]'),
 Text(0.9461538461538461, 0.3181818181818182, 'entropy = 0.0\nsamples = 135\nvalue = [0, 135, 0]'),
 Text(0.9692307692307692, 0.7727272727272727, 'x[5] <= 1.431\nentropy = 0.318\nsamples = 832\nvalue = [0, 48, 784]'),
 Text(0.9538461538461539, 0.6818181818181818, 'x[7] <= 1.502\nentropy = 0.018\nsamples = 610\nvalue = [0, 1, 609]'),
 Text(0.9461538461538461, 0.5909090909090909, 'x[5] <= 1.272\nentropy = 0.159\nsamples = 43\nvalue = [0, 1, 42]'),
 Text(0.9384615384615385, 0.5, 'entropy = 0.0\nsamples = 42\nvalue = [0, 0, 42]'),
 Text(0.9538461538461539, 0.5, 'entropy = 0.0\nsamples = 1\nvalue = [0, 1, 0]'),
 Text(0.9615384615384616, 0.5909090909090909, 'entropy = 0.0\nsamples = 567\nvalue = [0, 0, 567]'),
 Text(0.9846153846153847, 0.6818181818181818, 'x[7] <= 1.775\nentropy = 0.745\nsamples = 222\nvalue = [0, 47, 175]'),
 Text(0.9769230769230769, 0.5909090909090909, 'x[7] <= 1.682\nentropy = 0.556\nsamples = 54\nvalue = [0, 47, 7]'),
 Text(0.9692307692307692, 0.5, 'entropy = 0.0\nsamples = 34\nvalue = [0, 34, 0]'),
 Text(0.9846153846153847, 0.5, 'x[5] <= 1.582\nentropy = 0.934\nsamples = 20\nvalue = [0, 13, 7]'),
 Text(0.9769230769230769, 0.4090909090909091, 'entropy = 0.0\nsamples = 7\nvalue = [0, 0, 7]'),
 Text(0.9923076923076923, 0.4090909090909091, 'entropy = 0.0\nsamples = 13\nvalue = [0, 13, 0]'),
 Text(0.9923076923076923, 0.5909090909090909, 'entropy = 0.0\nsamples = 168\nvalue = [0, 0, 168]')]
In [29]:
import joblib
joblib.dump(clf_gini, 'model/decisiontree_model.pkl')
Out[29]:
['model/decisiontree_model.pkl']
In [30]:
#Predict the values 
y_pred_gini = clf_gini.predict(X_test)
#Predict the value using X train for accuracy comparision 
y_pred_train_gini = clf_gini.predict(X_train)
y_pred_train_gini
#Determine the accuracy score
print('Model accuracy score with criterion gini index: {0:0.4f}'. format(accuracy_score(y_test, y_pred_gini)))
#Accuracy Score for training set
print('Training-set accuracy score: {0:0.4f}'. format(accuracy_score(y_train, y_pred_train_gini)))
Model accuracy score with criterion gini index: 0.9769
Training-set accuracy score: 0.9991
In [31]:
from sklearn.metrics import confusion_matrix,classification_report, ConfusionMatrixDisplay
from sklearn.metrics import  f1_score
In [32]:
cm = confusion_matrix(y_test, y_pred_gini)

print('Confusion matrix\n\n', cm)
Confusion matrix

 [[682   6   0]
 [ 15 642  16]
 [  0   9 621]]
In [33]:
cm = confusion_matrix(y_test, y_pred_gini)
disp = ConfusionMatrixDisplay(confusion_matrix=cm)
disp.plot(cmap='Blues', values_format='d')
plt.title("Confusion Matrix: Decision Tree")
plt.show()

print(classification_report(y_test, y_pred_gini))
              precision    recall  f1-score   support

           0       0.98      0.99      0.98       688
           1       0.98      0.95      0.97       673
           2       0.97      0.99      0.98       630

    accuracy                           0.98      1991
   macro avg       0.98      0.98      0.98      1991
weighted avg       0.98      0.98      0.98      1991