Libraries and Dataset¶

In [68]:
import numpy as np
import pandas as pd
import seaborn  as sns
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder, OneHotEncoder, StandardScaler,MinMaxScaler
from sklearn.model_selection import cross_val_score, GridSearchCV
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix,classification_report, ConfusionMatrixDisplay
In [69]:
dataset = pd.read_csv('data/market_cluster.csv', encoding='latin1')
In [70]:
dataset.head()
Out[70]:
Order ID Customer Name Category Sub Category City Order Date Region Sales Discount Profit State profit_margin Cluster
0 OD1 Harish Oil & Masala Masalas Vellore 11-08-2017 North 1254 0.12 401.28 Tamil Nadu 0.32 Medium
1 OD2 Sudha Beverages Health Drinks Krishnagiri 11-08-2017 South 749 0.18 149.80 Tamil Nadu 0.20 Medium
2 OD3 Hussain Food Grains Atta & Flour Perambalur 06-12-2017 West 2360 0.21 165.20 Tamil Nadu 0.07 Low
3 OD4 Jackson Fruits & Veggies Fresh Vegetables Dharmapuri 10-11-2016 South 896 0.25 89.60 Tamil Nadu 0.10 Low
4 OD5 Ridhesh Food Grains Organic Staples Ooty 10-11-2016 South 2355 0.26 918.45 Tamil Nadu 0.39 High
In [71]:
dataset.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 9994 entries, 0 to 9993
Data columns (total 13 columns):
 #   Column         Non-Null Count  Dtype  
---  ------         --------------  -----  
 0   Order ID       9994 non-null   object 
 1   Customer Name  9994 non-null   object 
 2   Category       9994 non-null   object 
 3   Sub Category   9994 non-null   object 
 4   City           9994 non-null   object 
 5   Order Date     9994 non-null   object 
 6   Region         9994 non-null   object 
 7   Sales          9994 non-null   int64  
 8   Discount       9994 non-null   float64
 9   Profit         9994 non-null   float64
 10  State          9994 non-null   object 
 11  profit_margin  9994 non-null   float64
 12  Cluster        9994 non-null   object 
dtypes: float64(3), int64(1), object(9)
memory usage: 1015.1+ KB
In [72]:
dataset.describe()
Out[72]:
Sales Discount Profit profit_margin
count 9994.000000 9994.000000 9994.000000 9994.000000
mean 1496.596158 0.226817 374.937082 0.250228
std 577.559036 0.074636 239.932881 0.118919
min 500.000000 0.100000 25.250000 0.050000
25% 1000.000000 0.160000 180.022500 0.150000
50% 1498.000000 0.230000 320.780000 0.250000
75% 1994.750000 0.290000 525.627500 0.350000
max 2500.000000 0.350000 1120.950000 0.450000
In [73]:
dataset.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 9994 entries, 0 to 9993
Data columns (total 13 columns):
 #   Column         Non-Null Count  Dtype  
---  ------         --------------  -----  
 0   Order ID       9994 non-null   object 
 1   Customer Name  9994 non-null   object 
 2   Category       9994 non-null   object 
 3   Sub Category   9994 non-null   object 
 4   City           9994 non-null   object 
 5   Order Date     9994 non-null   object 
 6   Region         9994 non-null   object 
 7   Sales          9994 non-null   int64  
 8   Discount       9994 non-null   float64
 9   Profit         9994 non-null   float64
 10  State          9994 non-null   object 
 11  profit_margin  9994 non-null   float64
 12  Cluster        9994 non-null   object 
dtypes: float64(3), int64(1), object(9)
memory usage: 1015.1+ KB

Data Cleaning & Preprocessing¶

In [74]:
dataset.drop(['Order ID'], axis=1, inplace=True)
In [75]:
dataset.isna().sum()
Out[75]:
Customer Name    0
Category         0
Sub Category     0
City             0
Order Date       0
Region           0
Sales            0
Discount         0
Profit           0
State            0
profit_margin    0
Cluster          0
dtype: int64
In [76]:
dataset.dropna(inplace=True)
In [77]:
def remove_outliers(data: pd.DataFrame, column: str) -> pd.Series:
    q3, q1 = np.nanpercentile(data[column], [75, 25])
    iqr = q3 - q1
    upper_bound = q3 + 1.5 * iqr
    lower_bound = q1 - 1.5 * iqr
    data = data[(data[column] > lower_bound) & (data[column] < upper_bound)]

    return data

dataset = remove_outliers(dataset, 'Discount')
dataset = remove_outliers(dataset, 'Sales')
dataset = remove_outliers(dataset, 'Profit')
In [78]:
dataset.head()
Out[78]:
Customer Name Category Sub Category City Order Date Region Sales Discount Profit State profit_margin Cluster
0 Harish Oil & Masala Masalas Vellore 11-08-2017 North 1254 0.12 401.28 Tamil Nadu 0.32 Medium
1 Sudha Beverages Health Drinks Krishnagiri 11-08-2017 South 749 0.18 149.80 Tamil Nadu 0.20 Medium
2 Hussain Food Grains Atta & Flour Perambalur 06-12-2017 West 2360 0.21 165.20 Tamil Nadu 0.07 Low
3 Jackson Fruits & Veggies Fresh Vegetables Dharmapuri 10-11-2016 South 896 0.25 89.60 Tamil Nadu 0.10 Low
4 Ridhesh Food Grains Organic Staples Ooty 10-11-2016 South 2355 0.26 918.45 Tamil Nadu 0.39 High
In [79]:
sns.histplot(dataset['Cluster'])
Out[79]:
<AxesSubplot:xlabel='Cluster', ylabel='Count'>
In [80]:
encoder = LabelEncoder()
scaler = StandardScaler()
onehot = OneHotEncoder()
minmaxscaler = MinMaxScaler()
In [81]:
dataset["Order Date"] = pd.to_datetime(dataset["Order Date"])
dataset["Order Date"] = dataset["Order Date"].dt.month
dataset["Customer Name"] = encoder.fit_transform(dataset["Customer Name"])
dataset["Category"] = encoder.fit_transform(dataset["Category"])
dataset["City"] = encoder.fit_transform(dataset["City"])
dataset["Region"] = encoder.fit_transform(dataset["Region"])
dataset["State"] = encoder.fit_transform(dataset["State"])
dataset["Sub Category"] = encoder.fit_transform(dataset["Sub Category"])

# dataset["Customer Name"] = onehot.fit_transform(dataset["Customer Name"].values.reshape(-1, 1)).toarray()
# dataset["Category"] = onehot.fit_transform(dataset["Category"].values.reshape(-1, 1)).toarray()
# dataset["City"] = onehot.fit_transform(dataset["City"].values.reshape(-1, 1)).toarray()
# dataset["Region"] = onehot.fit_transform(dataset["Region"].values.reshape(-1, 1)).toarray()
# dataset["State"] = onehot.fit_transform(dataset["State"].values.reshape(-1, 1)).toarray()
# dataset["Sub Category"] = onehot.fit_transform(dataset["Sub Category"].values.reshape(-1, 1)).toarray()

dataset["Order Date"] = encoder.fit_transform(dataset["Order Date"])
In [82]:
dataset[["Sales", "Discount", "profit_margin"]] = scaler.fit_transform(dataset[["Sales", "Discount", "profit_margin"]])
dataset["Profit"] = minmaxscaler.fit_transform(dataset["Profit"].values.reshape(-1, 1))
In [83]:
class_to_numeric = {'Low': 0, 'Medium': 1, 'High': 2}
dataset['Cluster'] = [class_to_numeric[label] for label in dataset['Cluster']]
In [84]:
dataset.head()
Out[84]:
Customer Name Category Sub Category City Order Date Region Sales Discount Profit State profit_margin Cluster
0 12 5 14 21 10 2 -0.414559 -1.430908 0.369225 0 0.595874 1
1 37 1 13 8 10 3 -1.291968 -0.627370 0.122296 0 -0.416872 1
2 14 3 0 13 5 4 1.507054 -0.225601 0.137417 0 -1.514014 0
3 15 4 12 4 9 3 -1.036563 0.310092 0.063185 0 -1.260827 0
4 28 3 18 12 9 3 1.498367 0.444015 0.877036 0 1.186643 2

Split Data¶

In [85]:
X = dataset.drop(['Cluster','Sub Category','State','profit_margin'],axis=1)
y = dataset['Cluster']
In [86]:
df_new_ma = pd.DataFrame(minmaxscaler.fit_transform(X), columns=X.columns) 
In [87]:
df_new_ma.head()
Out[87]:
Customer Name Category City Order Date Region Sales Discount Profit
0 0.244898 0.833333 0.913043 0.909091 0.50 0.3770 0.08 0.369225
1 0.755102 0.166667 0.347826 0.909091 0.75 0.1245 0.32 0.122296
2 0.285714 0.500000 0.565217 0.454545 1.00 0.9300 0.44 0.137417
3 0.306122 0.666667 0.173913 0.818182 0.75 0.1980 0.60 0.063185
4 0.571429 0.500000 0.521739 0.818182 0.75 0.9275 0.64 0.877036
In [88]:
subset_of_data = X.sample(frac=0.05, random_state=42)  
plt.figure(figsize=(14,7))
plt.subplot(1,2,2)
plt.title("Scatterplot Scaling", fontsize=18)
sns.scatterplot(data = subset_of_data, color="red")
plt.tight_layout()
plt.show()
In [89]:
plt.figure(figsize=(14,7))
plt.subplot(1,2,2)
plt.title("Data Scaling", fontsize=18)
sns.kdeplot(data = X, color="red")
plt.tight_layout()
plt.show()
In [90]:
X.head()
Out[90]:
Customer Name Category City Order Date Region Sales Discount Profit
0 12 5 21 10 2 -0.414559 -1.430908 0.369225
1 37 1 8 10 3 -1.291968 -0.627370 0.122296
2 14 3 13 5 4 1.507054 -0.225601 0.137417
3 15 4 4 9 3 -1.036563 0.310092 0.063185
4 28 3 12 9 3 1.498367 0.444015 0.877036
In [91]:
X_train, X_temp, y_train, y_temp = train_test_split(X, y, test_size=0.2, random_state=42)
X_val, X_test, y_val, y_test = train_test_split(X_temp, y_temp, test_size=0.5, random_state=42)
In [92]:
heatcol = X.corr()
sns.heatmap(heatcol,cmap="BrBG",annot=True)
Out[92]:
<AxesSubplot:>
In [93]:
print("Dimension of Train set",X_train.shape)
print("Dimension of Val set",X_val.shape)
print("Dimension of Test set",X_test.shape,"\n")

num_cols = X_train._get_numeric_data().columns
print("Number of numeric features:",num_cols.size)
Dimension of Train set (7960, 8)
Dimension of Val set (995, 8)
Dimension of Test set (996, 8) 

Number of numeric features: 8

ANN¶

In [94]:
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense,Flatten,Dropout
from tensorflow.keras.utils import to_categorical
from scikeras.wrappers import KerasClassifier
from sklearn.model_selection import GridSearchCV
In [95]:
y_train = to_categorical(y_train, num_classes=3)
y_val = to_categorical(y_val, num_classes=3)
In [96]:
def create_model(units=64,optimizer='adam',loss='categorical_crossentropy'):
    model = Sequential()
    model.add(Flatten(input_shape=(X_train.shape[1], 1)))
    model.add(Dense(units, activation='relu'))
    model.add(Dense(units=3, activation='sigmoid'))
    model.compile(loss=loss, optimizer=optimizer, metrics=['categorical_accuracy'])
    return model
In [97]:
model = KerasClassifier(build_fn=create_model, units=32, epochs=100, batch_size=32, verbose=0)

param_grid = {
    'optimizer': ['adam', 'sgd', 'rmsprop'],
    'units': [32, 64, 128],
    'loss' : ['categorical_crossentropy','binary_crossentropy','hinge','kullback_leibler_divergence']
}
grid = GridSearchCV(estimator=model, param_grid=param_grid, cv=2, scoring='accuracy')
grid_result = grid.fit(X_train, y_train)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\sklearn\model_selection\_validation.py:425: FitFailedWarning: 
54 fits failed out of a total of 72.
The score on these train-test partitions for these parameters will be set to nan.
If these failures are not expected, you can try to debug them by setting error_score='raise'.

Below are more details about the failures:
--------------------------------------------------------------------------------
18 fits failed with the following error:
Traceback (most recent call last):
  File "c:\Users\Asus\anaconda3\lib\site-packages\sklearn\model_selection\_validation.py", line 729, in _fit_and_score
    estimator.fit(X_train, y_train, **fit_params)
  File "c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py", line 1491, in fit
    super().fit(X=X, y=y, sample_weight=sample_weight, **kwargs)
  File "c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py", line 760, in fit
    self._fit(
  File "c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py", line 926, in _fit
    self._check_model_compatibility(y)
  File "c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py", line 569, in _check_model_compatibility
    raise ValueError(
ValueError: loss=binary_crossentropy but model compiled with categorical_crossentropy. Data may not match loss function!

--------------------------------------------------------------------------------
18 fits failed with the following error:
Traceback (most recent call last):
  File "c:\Users\Asus\anaconda3\lib\site-packages\sklearn\model_selection\_validation.py", line 729, in _fit_and_score
    estimator.fit(X_train, y_train, **fit_params)
  File "c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py", line 1491, in fit
    super().fit(X=X, y=y, sample_weight=sample_weight, **kwargs)
  File "c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py", line 760, in fit
    self._fit(
  File "c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py", line 926, in _fit
    self._check_model_compatibility(y)
  File "c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py", line 569, in _check_model_compatibility
    raise ValueError(
ValueError: loss=hinge but model compiled with categorical_crossentropy. Data may not match loss function!

--------------------------------------------------------------------------------
18 fits failed with the following error:
Traceback (most recent call last):
  File "c:\Users\Asus\anaconda3\lib\site-packages\sklearn\model_selection\_validation.py", line 729, in _fit_and_score
    estimator.fit(X_train, y_train, **fit_params)
  File "c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py", line 1491, in fit
    super().fit(X=X, y=y, sample_weight=sample_weight, **kwargs)
  File "c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py", line 760, in fit
    self._fit(
  File "c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py", line 926, in _fit
    self._check_model_compatibility(y)
  File "c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py", line 569, in _check_model_compatibility
    raise ValueError(
ValueError: loss=kullback_leibler_divergence but model compiled with categorical_crossentropy. Data may not match loss function!

  warnings.warn(some_fits_failed_message, FitFailedWarning)
c:\Users\Asus\anaconda3\lib\site-packages\sklearn\model_selection\_search.py:979: UserWarning: One or more of the test scores are non-finite: [0.975      0.96331658 0.96105528 0.96218593 0.97072864 0.96055276
 0.96683417 0.97487437 0.96494975        nan        nan        nan
        nan        nan        nan        nan        nan        nan
        nan        nan        nan        nan        nan        nan
        nan        nan        nan        nan        nan        nan
        nan        nan        nan        nan        nan        nan]
  warnings.warn(
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
In [98]:
print("Best: %f using %s" % (grid_result.best_score_, grid_result.best_params_))
Best: 0.975000 using {'loss': 'categorical_crossentropy', 'optimizer': 'adam', 'units': 32}
In [99]:
optimizer = grid_result.best_params_['optimizer']
units = grid_result.best_params_['units']
loss = grid_result.best_params_['loss']
In [100]:
model = Sequential()
model.add(Flatten(input_shape=(X_train.shape[1], 1)))

model.add(Dense(units, activation='relu'))
model.add(Dense(units=3, activation='sigmoid'))

model.compile(loss=loss, optimizer=optimizer, metrics=['mean_absolute_error'])

history = model.fit(X_train, y_train, epochs=100, batch_size=32, validation_data=(X_val, y_val))

accuracy = model.evaluate(X_val, y_val)
print("Accuracy: ", accuracy)
Epoch 1/100
249/249 [==============================] - 2s 4ms/step - loss: 1.1910 - mean_absolute_error: 0.6411 - val_loss: 1.1244 - val_mean_absolute_error: 0.6395
Epoch 2/100
249/249 [==============================] - 1s 3ms/step - loss: 1.0659 - mean_absolute_error: 0.6380 - val_loss: 1.0774 - val_mean_absolute_error: 0.6335
Epoch 3/100
249/249 [==============================] - 1s 3ms/step - loss: 0.9923 - mean_absolute_error: 0.6326 - val_loss: 0.9391 - val_mean_absolute_error: 0.6281
Epoch 4/100
249/249 [==============================] - 1s 3ms/step - loss: 0.8947 - mean_absolute_error: 0.6261 - val_loss: 0.8354 - val_mean_absolute_error: 0.6206
Epoch 5/100
249/249 [==============================] - 1s 4ms/step - loss: 0.7923 - mean_absolute_error: 0.6163 - val_loss: 0.7401 - val_mean_absolute_error: 0.6092
Epoch 6/100
249/249 [==============================] - 1s 3ms/step - loss: 0.7080 - mean_absolute_error: 0.6052 - val_loss: 0.6866 - val_mean_absolute_error: 0.5997
Epoch 7/100
249/249 [==============================] - 1s 4ms/step - loss: 0.6435 - mean_absolute_error: 0.5925 - val_loss: 0.6067 - val_mean_absolute_error: 0.5837
Epoch 8/100
249/249 [==============================] - 1s 3ms/step - loss: 0.5873 - mean_absolute_error: 0.5801 - val_loss: 0.5607 - val_mean_absolute_error: 0.5696
Epoch 9/100
249/249 [==============================] - 1s 2ms/step - loss: 0.5372 - mean_absolute_error: 0.5675 - val_loss: 0.5247 - val_mean_absolute_error: 0.5587
Epoch 10/100
249/249 [==============================] - 1s 2ms/step - loss: 0.4932 - mean_absolute_error: 0.5509 - val_loss: 0.4803 - val_mean_absolute_error: 0.5343
Epoch 11/100
249/249 [==============================] - 0s 2ms/step - loss: 0.4585 - mean_absolute_error: 0.5335 - val_loss: 0.4467 - val_mean_absolute_error: 0.5177
Epoch 12/100
249/249 [==============================] - 0s 2ms/step - loss: 0.4255 - mean_absolute_error: 0.5139 - val_loss: 0.4227 - val_mean_absolute_error: 0.4969
Epoch 13/100
249/249 [==============================] - 1s 2ms/step - loss: 0.3996 - mean_absolute_error: 0.4973 - val_loss: 0.4061 - val_mean_absolute_error: 0.4880
Epoch 14/100
249/249 [==============================] - 1s 2ms/step - loss: 0.3785 - mean_absolute_error: 0.4817 - val_loss: 0.3621 - val_mean_absolute_error: 0.4704
Epoch 15/100
249/249 [==============================] - 0s 2ms/step - loss: 0.3542 - mean_absolute_error: 0.4687 - val_loss: 0.3536 - val_mean_absolute_error: 0.4550
Epoch 16/100
249/249 [==============================] - 0s 2ms/step - loss: 0.3334 - mean_absolute_error: 0.4537 - val_loss: 0.3299 - val_mean_absolute_error: 0.4426
Epoch 17/100
249/249 [==============================] - 0s 2ms/step - loss: 0.3154 - mean_absolute_error: 0.4416 - val_loss: 0.3357 - val_mean_absolute_error: 0.4417
Epoch 18/100
249/249 [==============================] - 0s 2ms/step - loss: 0.3021 - mean_absolute_error: 0.4304 - val_loss: 0.2929 - val_mean_absolute_error: 0.4229
Epoch 19/100
249/249 [==============================] - 1s 2ms/step - loss: 0.2862 - mean_absolute_error: 0.4196 - val_loss: 0.2929 - val_mean_absolute_error: 0.4146
Epoch 20/100
249/249 [==============================] - 0s 2ms/step - loss: 0.2757 - mean_absolute_error: 0.4095 - val_loss: 0.2736 - val_mean_absolute_error: 0.4053
Epoch 21/100
249/249 [==============================] - 1s 2ms/step - loss: 0.2629 - mean_absolute_error: 0.3994 - val_loss: 0.2760 - val_mean_absolute_error: 0.3953
Epoch 22/100
249/249 [==============================] - 0s 2ms/step - loss: 0.2531 - mean_absolute_error: 0.3916 - val_loss: 0.2753 - val_mean_absolute_error: 0.3941
Epoch 23/100
249/249 [==============================] - 0s 2ms/step - loss: 0.2426 - mean_absolute_error: 0.3848 - val_loss: 0.2376 - val_mean_absolute_error: 0.3791
Epoch 24/100
249/249 [==============================] - 0s 2ms/step - loss: 0.2357 - mean_absolute_error: 0.3771 - val_loss: 0.2394 - val_mean_absolute_error: 0.3702
Epoch 25/100
249/249 [==============================] - 0s 2ms/step - loss: 0.2279 - mean_absolute_error: 0.3707 - val_loss: 0.2248 - val_mean_absolute_error: 0.3674
Epoch 26/100
249/249 [==============================] - 0s 2ms/step - loss: 0.2193 - mean_absolute_error: 0.3640 - val_loss: 0.2354 - val_mean_absolute_error: 0.3640
Epoch 27/100
249/249 [==============================] - 1s 2ms/step - loss: 0.2089 - mean_absolute_error: 0.3579 - val_loss: 0.2131 - val_mean_absolute_error: 0.3570
Epoch 28/100
249/249 [==============================] - 0s 2ms/step - loss: 0.2069 - mean_absolute_error: 0.3544 - val_loss: 0.2008 - val_mean_absolute_error: 0.3480
Epoch 29/100
249/249 [==============================] - 0s 2ms/step - loss: 0.2021 - mean_absolute_error: 0.3487 - val_loss: 0.1961 - val_mean_absolute_error: 0.3450
Epoch 30/100
249/249 [==============================] - 0s 2ms/step - loss: 0.1916 - mean_absolute_error: 0.3440 - val_loss: 0.1964 - val_mean_absolute_error: 0.3376
Epoch 31/100
249/249 [==============================] - 0s 2ms/step - loss: 0.1885 - mean_absolute_error: 0.3389 - val_loss: 0.1858 - val_mean_absolute_error: 0.3388
Epoch 32/100
249/249 [==============================] - 0s 2ms/step - loss: 0.1851 - mean_absolute_error: 0.3363 - val_loss: 0.1802 - val_mean_absolute_error: 0.3323
Epoch 33/100
249/249 [==============================] - 0s 2ms/step - loss: 0.1796 - mean_absolute_error: 0.3333 - val_loss: 0.1808 - val_mean_absolute_error: 0.3329
Epoch 34/100
249/249 [==============================] - 0s 2ms/step - loss: 0.1722 - mean_absolute_error: 0.3303 - val_loss: 0.1789 - val_mean_absolute_error: 0.3264
Epoch 35/100
249/249 [==============================] - 1s 2ms/step - loss: 0.1708 - mean_absolute_error: 0.3281 - val_loss: 0.1660 - val_mean_absolute_error: 0.3260
Epoch 36/100
249/249 [==============================] - 0s 2ms/step - loss: 0.1629 - mean_absolute_error: 0.3245 - val_loss: 0.1664 - val_mean_absolute_error: 0.3233
Epoch 37/100
249/249 [==============================] - 0s 2ms/step - loss: 0.1624 - mean_absolute_error: 0.3226 - val_loss: 0.1608 - val_mean_absolute_error: 0.3180
Epoch 38/100
249/249 [==============================] - 1s 3ms/step - loss: 0.1589 - mean_absolute_error: 0.3204 - val_loss: 0.1551 - val_mean_absolute_error: 0.3176
Epoch 39/100
249/249 [==============================] - 0s 2ms/step - loss: 0.1577 - mean_absolute_error: 0.3183 - val_loss: 0.1614 - val_mean_absolute_error: 0.3207
Epoch 40/100
249/249 [==============================] - 0s 2ms/step - loss: 0.1519 - mean_absolute_error: 0.3168 - val_loss: 0.1600 - val_mean_absolute_error: 0.3190
Epoch 41/100
249/249 [==============================] - 1s 2ms/step - loss: 0.1492 - mean_absolute_error: 0.3146 - val_loss: 0.1483 - val_mean_absolute_error: 0.3111
Epoch 42/100
249/249 [==============================] - 0s 2ms/step - loss: 0.1460 - mean_absolute_error: 0.3126 - val_loss: 0.1433 - val_mean_absolute_error: 0.3117
Epoch 43/100
249/249 [==============================] - 0s 2ms/step - loss: 0.1449 - mean_absolute_error: 0.3104 - val_loss: 0.1429 - val_mean_absolute_error: 0.3103
Epoch 44/100
249/249 [==============================] - 0s 2ms/step - loss: 0.1395 - mean_absolute_error: 0.3088 - val_loss: 0.1551 - val_mean_absolute_error: 0.3091
Epoch 45/100
249/249 [==============================] - 0s 2ms/step - loss: 0.1398 - mean_absolute_error: 0.3070 - val_loss: 0.1453 - val_mean_absolute_error: 0.3085
Epoch 46/100
249/249 [==============================] - 0s 2ms/step - loss: 0.1347 - mean_absolute_error: 0.3065 - val_loss: 0.1335 - val_mean_absolute_error: 0.3057
Epoch 47/100
249/249 [==============================] - 0s 2ms/step - loss: 0.1339 - mean_absolute_error: 0.3058 - val_loss: 0.1429 - val_mean_absolute_error: 0.3065
Epoch 48/100
249/249 [==============================] - 1s 2ms/step - loss: 0.1293 - mean_absolute_error: 0.3045 - val_loss: 0.1582 - val_mean_absolute_error: 0.3063
Epoch 49/100
249/249 [==============================] - 0s 2ms/step - loss: 0.1302 - mean_absolute_error: 0.3038 - val_loss: 0.1316 - val_mean_absolute_error: 0.2995
Epoch 50/100
249/249 [==============================] - 0s 2ms/step - loss: 0.1270 - mean_absolute_error: 0.3020 - val_loss: 0.1248 - val_mean_absolute_error: 0.3014
Epoch 51/100
249/249 [==============================] - 1s 2ms/step - loss: 0.1254 - mean_absolute_error: 0.3007 - val_loss: 0.1246 - val_mean_absolute_error: 0.2976
Epoch 52/100
249/249 [==============================] - 0s 2ms/step - loss: 0.1237 - mean_absolute_error: 0.3001 - val_loss: 0.1268 - val_mean_absolute_error: 0.2993
Epoch 53/100
249/249 [==============================] - 0s 2ms/step - loss: 0.1224 - mean_absolute_error: 0.2992 - val_loss: 0.1205 - val_mean_absolute_error: 0.2961
Epoch 54/100
249/249 [==============================] - 1s 2ms/step - loss: 0.1204 - mean_absolute_error: 0.2985 - val_loss: 0.1195 - val_mean_absolute_error: 0.2967
Epoch 55/100
249/249 [==============================] - 1s 2ms/step - loss: 0.1161 - mean_absolute_error: 0.2968 - val_loss: 0.1323 - val_mean_absolute_error: 0.2987
Epoch 56/100
249/249 [==============================] - 0s 2ms/step - loss: 0.1159 - mean_absolute_error: 0.2967 - val_loss: 0.1243 - val_mean_absolute_error: 0.2918
Epoch 57/100
249/249 [==============================] - 0s 2ms/step - loss: 0.1170 - mean_absolute_error: 0.2957 - val_loss: 0.1168 - val_mean_absolute_error: 0.2954
Epoch 58/100
249/249 [==============================] - 0s 2ms/step - loss: 0.1118 - mean_absolute_error: 0.2945 - val_loss: 0.1122 - val_mean_absolute_error: 0.2922
Epoch 59/100
249/249 [==============================] - 0s 2ms/step - loss: 0.1104 - mean_absolute_error: 0.2941 - val_loss: 0.1136 - val_mean_absolute_error: 0.2926
Epoch 60/100
249/249 [==============================] - 0s 2ms/step - loss: 0.1114 - mean_absolute_error: 0.2934 - val_loss: 0.1122 - val_mean_absolute_error: 0.2906
Epoch 61/100
249/249 [==============================] - 1s 2ms/step - loss: 0.1084 - mean_absolute_error: 0.2917 - val_loss: 0.1300 - val_mean_absolute_error: 0.2949
Epoch 62/100
249/249 [==============================] - 0s 2ms/step - loss: 0.1058 - mean_absolute_error: 0.2919 - val_loss: 0.1068 - val_mean_absolute_error: 0.2885
Epoch 63/100
249/249 [==============================] - 0s 2ms/step - loss: 0.1048 - mean_absolute_error: 0.2905 - val_loss: 0.1115 - val_mean_absolute_error: 0.2907
Epoch 64/100
249/249 [==============================] - 1s 2ms/step - loss: 0.1066 - mean_absolute_error: 0.2890 - val_loss: 0.1111 - val_mean_absolute_error: 0.2885
Epoch 65/100
249/249 [==============================] - 0s 2ms/step - loss: 0.1070 - mean_absolute_error: 0.2881 - val_loss: 0.1144 - val_mean_absolute_error: 0.2862
Epoch 66/100
249/249 [==============================] - 0s 2ms/step - loss: 0.1058 - mean_absolute_error: 0.2870 - val_loss: 0.1062 - val_mean_absolute_error: 0.2849
Epoch 67/100
249/249 [==============================] - 0s 2ms/step - loss: 0.1033 - mean_absolute_error: 0.2869 - val_loss: 0.1075 - val_mean_absolute_error: 0.2859
Epoch 68/100
249/249 [==============================] - 0s 2ms/step - loss: 0.1009 - mean_absolute_error: 0.2852 - val_loss: 0.0992 - val_mean_absolute_error: 0.2831
Epoch 69/100
249/249 [==============================] - 0s 2ms/step - loss: 0.0994 - mean_absolute_error: 0.2851 - val_loss: 0.1087 - val_mean_absolute_error: 0.2850
Epoch 70/100
249/249 [==============================] - 0s 2ms/step - loss: 0.0962 - mean_absolute_error: 0.2848 - val_loss: 0.0973 - val_mean_absolute_error: 0.2813
Epoch 71/100
249/249 [==============================] - 1s 3ms/step - loss: 0.0970 - mean_absolute_error: 0.2841 - val_loss: 0.1170 - val_mean_absolute_error: 0.2797
Epoch 72/100
249/249 [==============================] - 1s 2ms/step - loss: 0.0955 - mean_absolute_error: 0.2836 - val_loss: 0.1019 - val_mean_absolute_error: 0.2807
Epoch 73/100
249/249 [==============================] - 0s 2ms/step - loss: 0.0954 - mean_absolute_error: 0.2828 - val_loss: 0.1007 - val_mean_absolute_error: 0.2790
Epoch 74/100
249/249 [==============================] - 0s 2ms/step - loss: 0.0931 - mean_absolute_error: 0.2827 - val_loss: 0.0917 - val_mean_absolute_error: 0.2806
Epoch 75/100
249/249 [==============================] - 0s 2ms/step - loss: 0.0924 - mean_absolute_error: 0.2824 - val_loss: 0.0965 - val_mean_absolute_error: 0.2787
Epoch 76/100
249/249 [==============================] - 1s 2ms/step - loss: 0.0928 - mean_absolute_error: 0.2820 - val_loss: 0.1156 - val_mean_absolute_error: 0.2824
Epoch 77/100
249/249 [==============================] - 0s 2ms/step - loss: 0.0905 - mean_absolute_error: 0.2808 - val_loss: 0.1023 - val_mean_absolute_error: 0.2814
Epoch 78/100
249/249 [==============================] - 1s 2ms/step - loss: 0.0889 - mean_absolute_error: 0.2812 - val_loss: 0.0889 - val_mean_absolute_error: 0.2774
Epoch 79/100
249/249 [==============================] - 0s 2ms/step - loss: 0.0880 - mean_absolute_error: 0.2802 - val_loss: 0.0892 - val_mean_absolute_error: 0.2780
Epoch 80/100
249/249 [==============================] - 0s 2ms/step - loss: 0.0898 - mean_absolute_error: 0.2798 - val_loss: 0.1034 - val_mean_absolute_error: 0.2787
Epoch 81/100
249/249 [==============================] - 0s 2ms/step - loss: 0.0915 - mean_absolute_error: 0.2800 - val_loss: 0.0916 - val_mean_absolute_error: 0.2780
Epoch 82/100
249/249 [==============================] - 1s 2ms/step - loss: 0.0889 - mean_absolute_error: 0.2797 - val_loss: 0.0969 - val_mean_absolute_error: 0.2779
Epoch 83/100
249/249 [==============================] - 0s 2ms/step - loss: 0.0841 - mean_absolute_error: 0.2784 - val_loss: 0.0923 - val_mean_absolute_error: 0.2779
Epoch 84/100
249/249 [==============================] - 0s 2ms/step - loss: 0.0842 - mean_absolute_error: 0.2782 - val_loss: 0.0923 - val_mean_absolute_error: 0.2762
Epoch 85/100
249/249 [==============================] - 1s 2ms/step - loss: 0.0851 - mean_absolute_error: 0.2783 - val_loss: 0.0852 - val_mean_absolute_error: 0.2763
Epoch 86/100
249/249 [==============================] - 0s 2ms/step - loss: 0.0833 - mean_absolute_error: 0.2779 - val_loss: 0.0912 - val_mean_absolute_error: 0.2762
Epoch 87/100
249/249 [==============================] - 0s 2ms/step - loss: 0.0815 - mean_absolute_error: 0.2769 - val_loss: 0.0947 - val_mean_absolute_error: 0.2749
Epoch 88/100
249/249 [==============================] - 0s 2ms/step - loss: 0.0846 - mean_absolute_error: 0.2762 - val_loss: 0.0921 - val_mean_absolute_error: 0.2733
Epoch 89/100
249/249 [==============================] - 0s 2ms/step - loss: 0.0798 - mean_absolute_error: 0.2758 - val_loss: 0.0836 - val_mean_absolute_error: 0.2742
Epoch 90/100
249/249 [==============================] - 0s 2ms/step - loss: 0.0797 - mean_absolute_error: 0.2752 - val_loss: 0.0820 - val_mean_absolute_error: 0.2739
Epoch 91/100
249/249 [==============================] - 0s 2ms/step - loss: 0.0796 - mean_absolute_error: 0.2753 - val_loss: 0.0805 - val_mean_absolute_error: 0.2736
Epoch 92/100
249/249 [==============================] - 0s 2ms/step - loss: 0.0802 - mean_absolute_error: 0.2752 - val_loss: 0.0806 - val_mean_absolute_error: 0.2716
Epoch 93/100
249/249 [==============================] - 0s 2ms/step - loss: 0.0802 - mean_absolute_error: 0.2744 - val_loss: 0.0779 - val_mean_absolute_error: 0.2721
Epoch 94/100
249/249 [==============================] - 0s 2ms/step - loss: 0.0818 - mean_absolute_error: 0.2745 - val_loss: 0.0827 - val_mean_absolute_error: 0.2721
Epoch 95/100
249/249 [==============================] - 1s 2ms/step - loss: 0.0843 - mean_absolute_error: 0.2748 - val_loss: 0.0791 - val_mean_absolute_error: 0.2720
Epoch 96/100
249/249 [==============================] - 0s 2ms/step - loss: 0.0771 - mean_absolute_error: 0.2740 - val_loss: 0.0778 - val_mean_absolute_error: 0.2706
Epoch 97/100
249/249 [==============================] - 1s 2ms/step - loss: 0.0810 - mean_absolute_error: 0.2735 - val_loss: 0.0923 - val_mean_absolute_error: 0.2728
Epoch 98/100
249/249 [==============================] - 0s 2ms/step - loss: 0.0744 - mean_absolute_error: 0.2732 - val_loss: 0.0820 - val_mean_absolute_error: 0.2712
Epoch 99/100
249/249 [==============================] - 0s 2ms/step - loss: 0.0750 - mean_absolute_error: 0.2728 - val_loss: 0.0775 - val_mean_absolute_error: 0.2699
Epoch 100/100
249/249 [==============================] - 0s 2ms/step - loss: 0.0776 - mean_absolute_error: 0.2729 - val_loss: 0.0769 - val_mean_absolute_error: 0.2707
32/32 [==============================] - 0s 1ms/step - loss: 0.0769 - mean_absolute_error: 0.2707
Accuracy:  [0.07687386870384216, 0.2706981301307678]
In [101]:
from tensorflow.keras.utils import plot_model
model.summary()
plot_model(model, to_file='mlp-mnist.png', show_shapes=True)
Model: "sequential_147"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 flatten_147 (Flatten)       (None, 8)                 0         
                                                                 
 dense_294 (Dense)           (None, 32)                288       
                                                                 
 dense_295 (Dense)           (None, 3)                 99        
                                                                 
=================================================================
Total params: 387 (1.51 KB)
Trainable params: 387 (1.51 KB)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________
You must install pydot (`pip install pydot`) and install graphviz (see instructions at https://graphviz.gitlab.io/download/) for plot_model to work.
In [102]:
import joblib
joblib.dump(model, 'model/ann_model.pkl')
INFO:tensorflow:Assets written to: C:\Users\Asus\AppData\Local\Temp\tmpfg28j20q\assets
INFO:tensorflow:Assets written to: C:\Users\Asus\AppData\Local\Temp\tmpfg28j20q\assets
Out[102]:
['model/ann_model.pkl']
In [103]:
predictions = model.predict(X_test)
y_pred = np.argmax(predictions, axis=1)
32/32 [==============================] - 0s 1ms/step
In [104]:
plt.figure(figsize=(12, 5))
plt.subplot(1, 2, 1)
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.title('Loss over epochs')
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.legend(['Train', 'Validation'], loc='best')

plt.subplot(1, 2, 2)
plt.plot(history.history['mean_absolute_error'])
plt.plot(history.history['val_mean_absolute_error'])
plt.title('Accuracy over epochs')
plt.ylabel('Accuracy')
plt.xlabel('Epoch')
plt.legend(['Train', 'Validation'], loc='best')

plt.tight_layout()
plt.show()
In [105]:
cm = confusion_matrix(y_test, y_pred)
disp = ConfusionMatrixDisplay(confusion_matrix=cm)
disp.plot(cmap='Blues', values_format='d')
plt.title("Confusion Matrix: RNN")
plt.show()

print(classification_report(y_test, y_pred))
              precision    recall  f1-score   support

           0       0.99      1.00      0.99       327
           1       0.99      0.98      0.99       352
           2       0.99      0.99      0.99       317

    accuracy                           0.99       996
   macro avg       0.99      0.99      0.99       996
weighted avg       0.99      0.99      0.99       996