Libraries and Dataset¶

In [1]:
import numpy as np
import pandas as pd
import seaborn  as sns
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder, OneHotEncoder, StandardScaler,MinMaxScaler
from sklearn.model_selection import cross_val_score, GridSearchCV
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix,classification_report, ConfusionMatrixDisplay
In [2]:
dataset = pd.read_csv('data/market_cluster.csv', encoding='latin1')
In [3]:
dataset.head()
Out[3]:
Order ID Customer Name Category Sub Category City Order Date Region Sales Discount Profit State profit_margin Cluster
0 OD1 Harish Oil & Masala Masalas Vellore 11-08-2017 North 1254 0.12 401.28 Tamil Nadu 0.32 Medium
1 OD2 Sudha Beverages Health Drinks Krishnagiri 11-08-2017 South 749 0.18 149.80 Tamil Nadu 0.20 Medium
2 OD3 Hussain Food Grains Atta & Flour Perambalur 06-12-2017 West 2360 0.21 165.20 Tamil Nadu 0.07 Low
3 OD4 Jackson Fruits & Veggies Fresh Vegetables Dharmapuri 10-11-2016 South 896 0.25 89.60 Tamil Nadu 0.10 Low
4 OD5 Ridhesh Food Grains Organic Staples Ooty 10-11-2016 South 2355 0.26 918.45 Tamil Nadu 0.39 High
In [4]:
dataset.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 9994 entries, 0 to 9993
Data columns (total 13 columns):
 #   Column         Non-Null Count  Dtype  
---  ------         --------------  -----  
 0   Order ID       9994 non-null   object 
 1   Customer Name  9994 non-null   object 
 2   Category       9994 non-null   object 
 3   Sub Category   9994 non-null   object 
 4   City           9994 non-null   object 
 5   Order Date     9994 non-null   object 
 6   Region         9994 non-null   object 
 7   Sales          9994 non-null   int64  
 8   Discount       9994 non-null   float64
 9   Profit         9994 non-null   float64
 10  State          9994 non-null   object 
 11  profit_margin  9994 non-null   float64
 12  Cluster        9994 non-null   object 
dtypes: float64(3), int64(1), object(9)
memory usage: 1015.1+ KB
In [5]:
dataset.describe()
Out[5]:
Sales Discount Profit profit_margin
count 9994.000000 9994.000000 9994.000000 9994.000000
mean 1496.596158 0.226817 374.937082 0.250228
std 577.559036 0.074636 239.932881 0.118919
min 500.000000 0.100000 25.250000 0.050000
25% 1000.000000 0.160000 180.022500 0.150000
50% 1498.000000 0.230000 320.780000 0.250000
75% 1994.750000 0.290000 525.627500 0.350000
max 2500.000000 0.350000 1120.950000 0.450000
In [6]:
dataset.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 9994 entries, 0 to 9993
Data columns (total 13 columns):
 #   Column         Non-Null Count  Dtype  
---  ------         --------------  -----  
 0   Order ID       9994 non-null   object 
 1   Customer Name  9994 non-null   object 
 2   Category       9994 non-null   object 
 3   Sub Category   9994 non-null   object 
 4   City           9994 non-null   object 
 5   Order Date     9994 non-null   object 
 6   Region         9994 non-null   object 
 7   Sales          9994 non-null   int64  
 8   Discount       9994 non-null   float64
 9   Profit         9994 non-null   float64
 10  State          9994 non-null   object 
 11  profit_margin  9994 non-null   float64
 12  Cluster        9994 non-null   object 
dtypes: float64(3), int64(1), object(9)
memory usage: 1015.1+ KB

Data Cleaning & Preprocessing¶

In [7]:
dataset.drop(['Order ID'], axis=1, inplace=True)
In [8]:
dataset.isna().sum()
Out[8]:
Customer Name    0
Category         0
Sub Category     0
City             0
Order Date       0
Region           0
Sales            0
Discount         0
Profit           0
State            0
profit_margin    0
Cluster          0
dtype: int64
In [9]:
dataset.dropna(inplace=True)
In [10]:
def remove_outliers(data: pd.DataFrame, column: str) -> pd.Series:
    q3, q1 = np.nanpercentile(data[column], [75, 25])
    iqr = q3 - q1
    upper_bound = q3 + 1.5 * iqr
    lower_bound = q1 - 1.5 * iqr
    data = data[(data[column] > lower_bound) & (data[column] < upper_bound)]

    return data

dataset = remove_outliers(dataset, 'Discount')
dataset = remove_outliers(dataset, 'Sales')
dataset = remove_outliers(dataset, 'Profit')
In [11]:
dataset.head()
Out[11]:
Customer Name Category Sub Category City Order Date Region Sales Discount Profit State profit_margin Cluster
0 Harish Oil & Masala Masalas Vellore 11-08-2017 North 1254 0.12 401.28 Tamil Nadu 0.32 Medium
1 Sudha Beverages Health Drinks Krishnagiri 11-08-2017 South 749 0.18 149.80 Tamil Nadu 0.20 Medium
2 Hussain Food Grains Atta & Flour Perambalur 06-12-2017 West 2360 0.21 165.20 Tamil Nadu 0.07 Low
3 Jackson Fruits & Veggies Fresh Vegetables Dharmapuri 10-11-2016 South 896 0.25 89.60 Tamil Nadu 0.10 Low
4 Ridhesh Food Grains Organic Staples Ooty 10-11-2016 South 2355 0.26 918.45 Tamil Nadu 0.39 High
In [12]:
sns.histplot(dataset['Cluster'])
Out[12]:
<AxesSubplot:xlabel='Cluster', ylabel='Count'>
In [13]:
encoder = LabelEncoder()
scaler = StandardScaler()
onehot = OneHotEncoder()
minmaxscaler = MinMaxScaler()
In [14]:
dataset["Order Date"] = pd.to_datetime(dataset["Order Date"])
dataset["Order Date"] = dataset["Order Date"].dt.month

# dataset["Customer Name"] = onehot.fit_transform(dataset["Customer Name"].values.reshape(-1, 1)).toarray()
# dataset["Category"] = onehot.fit_transform(dataset["Category"].values.reshape(-1, 1)).toarray()
# dataset["City"] = onehot.fit_transform(dataset["City"].values.reshape(-1, 1)).toarray()
# dataset["Region"] = onehot.fit_transform(dataset["Region"].values.reshape(-1, 1)).toarray()
# dataset["State"] = onehot.fit_transform(dataset["State"].values.reshape(-1, 1)).toarray()
# dataset["Sub Category"] = onehot.fit_transform(dataset["Sub Category"].values.reshape(-1, 1)).toarray()

dataset["Customer Name"] = encoder.fit_transform(dataset["Customer Name"])
dataset["Category"] = encoder.fit_transform(dataset["Category"])
dataset["City"] = encoder.fit_transform(dataset["City"])
dataset["Region"] = encoder.fit_transform(dataset["Region"])
dataset["State"] = encoder.fit_transform(dataset["State"])
dataset["Sub Category"] = encoder.fit_transform(dataset["Sub Category"])

dataset["Order Date"] = encoder.fit_transform(dataset["Order Date"])
In [15]:
dataset[["Sales", "Discount", "profit_margin","Profit"]] = scaler.fit_transform(dataset[["Sales", "Discount", "profit_margin","Profit"]])
In [16]:
class_to_numeric = {'Low': 0, 'Medium': 1, 'High': 2}
dataset['Cluster'] = [class_to_numeric[label] for label in dataset['Cluster']]
In [17]:
dataset.head()
Out[17]:
Customer Name Category Sub Category City Order Date Region Sales Discount Profit State profit_margin Cluster
0 12 5 14 21 10 2 -0.414559 -1.430908 0.124389 0 0.595874 1
1 37 1 13 8 10 3 -1.291968 -0.627370 -0.941183 0 -0.416872 1
2 14 3 0 13 5 4 1.507054 -0.225601 -0.875930 0 -1.514014 0
3 15 4 12 4 9 3 -1.036563 0.310092 -1.196262 0 -1.260827 0
4 28 3 18 12 9 3 1.498367 0.444015 2.315743 0 1.186643 2

Split Data¶

In [18]:
X = dataset.drop(['Cluster','Sub Category','State','profit_margin'],axis=1)
y = dataset['Cluster']
In [19]:
X.head()
Out[19]:
Customer Name Category City Order Date Region Sales Discount Profit
0 12 5 21 10 2 -0.414559 -1.430908 0.124389
1 37 1 8 10 3 -1.291968 -0.627370 -0.941183
2 14 3 13 5 4 1.507054 -0.225601 -0.875930
3 15 4 4 9 3 -1.036563 0.310092 -1.196262
4 28 3 12 9 3 1.498367 0.444015 2.315743
In [20]:
X_train, X_temp, y_train, y_temp = train_test_split(X, y, test_size=0.2, random_state=42)
X_val, X_test, y_val, y_test = train_test_split(X_temp, y_temp, test_size=0.5, random_state=42)
In [21]:
heatcol = X.corr()
sns.heatmap(heatcol,cmap="BrBG",annot=True)
Out[21]:
<AxesSubplot:>
In [22]:
print("Dimension of Train set",X_train.shape)
print("Dimension of Val set",X_val.shape)
print("Dimension of Test set",X_test.shape,"\n")

num_cols = X_train._get_numeric_data().columns
print("Number of numeric features:",num_cols.size)
Dimension of Train set (7960, 8)
Dimension of Val set (995, 8)
Dimension of Test set (996, 8) 

Number of numeric features: 8

RNN¶

In [23]:
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense,SimpleRNN,Dropout
from tensorflow.keras.utils import to_categorical, plot_model
from scikeras.wrappers import KerasClassifier, KerasRegressor
from sklearn.model_selection import GridSearchCV
In [24]:
y_train = to_categorical(y_train, num_classes=3)
y_val = to_categorical(y_val, num_classes=3)
In [25]:
def create_model(units=64,optimizer='adam',loss='categorical_crossentropy'):
    model = Sequential()
    model.add(SimpleRNN(units, return_sequences=False, input_shape=(X_train.shape[1], 1)))
    model.add(Dense(units, activation='relu'))
    model.add(Dense(units=3, activation='sigmoid'))
    model.compile(loss=loss, optimizer=optimizer, metrics=['categorical_accuracy'])
    
    return model
In [26]:
# model = KerasClassifier(build_fn=create_model, epochs=100, batch_size=32, verbose=0, validation_data=(X_val, y_val))
# model = KerasClassifier(build_fn=create_model, units=32, epochs=100, batch_size=32, verbose=0)
model = KerasClassifier(build_fn=create_model, units=32, epochs=100, batch_size=32, verbose=0)


param_grid = {
    'optimizer': ['adam', 'sgd', 'rmsprop'],
    'units': [32, 64, 128],
    'loss' : ['categorical_crossentropy','binary_crossentropy','hinge','kullback_leibler_divergence']
}
In [27]:
grid = GridSearchCV(estimator=model, param_grid=param_grid, cv=2, scoring='accuracy')
grid_result = grid.fit(X_train, y_train)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\sklearn\model_selection\_validation.py:425: FitFailedWarning: 
54 fits failed out of a total of 72.
The score on these train-test partitions for these parameters will be set to nan.
If these failures are not expected, you can try to debug them by setting error_score='raise'.

Below are more details about the failures:
--------------------------------------------------------------------------------
18 fits failed with the following error:
Traceback (most recent call last):
  File "c:\Users\Asus\anaconda3\lib\site-packages\sklearn\model_selection\_validation.py", line 729, in _fit_and_score
    estimator.fit(X_train, y_train, **fit_params)
  File "c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py", line 1491, in fit
    super().fit(X=X, y=y, sample_weight=sample_weight, **kwargs)
  File "c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py", line 760, in fit
    self._fit(
  File "c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py", line 926, in _fit
    self._check_model_compatibility(y)
  File "c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py", line 569, in _check_model_compatibility
    raise ValueError(
ValueError: loss=binary_crossentropy but model compiled with categorical_crossentropy. Data may not match loss function!

--------------------------------------------------------------------------------
18 fits failed with the following error:
Traceback (most recent call last):
  File "c:\Users\Asus\anaconda3\lib\site-packages\sklearn\model_selection\_validation.py", line 729, in _fit_and_score
    estimator.fit(X_train, y_train, **fit_params)
  File "c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py", line 1491, in fit
    super().fit(X=X, y=y, sample_weight=sample_weight, **kwargs)
  File "c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py", line 760, in fit
    self._fit(
  File "c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py", line 926, in _fit
    self._check_model_compatibility(y)
  File "c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py", line 569, in _check_model_compatibility
    raise ValueError(
ValueError: loss=hinge but model compiled with categorical_crossentropy. Data may not match loss function!

--------------------------------------------------------------------------------
18 fits failed with the following error:
Traceback (most recent call last):
  File "c:\Users\Asus\anaconda3\lib\site-packages\sklearn\model_selection\_validation.py", line 729, in _fit_and_score
    estimator.fit(X_train, y_train, **fit_params)
  File "c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py", line 1491, in fit
    super().fit(X=X, y=y, sample_weight=sample_weight, **kwargs)
  File "c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py", line 760, in fit
    self._fit(
  File "c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py", line 926, in _fit
    self._check_model_compatibility(y)
  File "c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py", line 569, in _check_model_compatibility
    raise ValueError(
ValueError: loss=kullback_leibler_divergence but model compiled with categorical_crossentropy. Data may not match loss function!

  warnings.warn(some_fits_failed_message, FitFailedWarning)
c:\Users\Asus\anaconda3\lib\site-packages\sklearn\model_selection\_search.py:979: UserWarning: One or more of the test scores are non-finite: [0.9821608  0.97135678 0.95565327 0.98329146 0.96733668 0.94987437
 0.97286432 0.9678392  0.95062814        nan        nan        nan
        nan        nan        nan        nan        nan        nan
        nan        nan        nan        nan        nan        nan
        nan        nan        nan        nan        nan        nan
        nan        nan        nan        nan        nan        nan]
  warnings.warn(
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
In [28]:
print("Best: %f using %s" % (grid_result.best_score_, grid_result.best_params_))

# best_model = grid_result.best_estimator_
# accuracy = best_model.score(X_val, y_val)
# print("Validation Accuracy: %.2f%%" % (accuracy * 100))
Best: 0.983291 using {'loss': 'categorical_crossentropy', 'optimizer': 'sgd', 'units': 32}
In [29]:
optimizer = grid_result.best_params_['optimizer']
units = grid_result.best_params_['units']
loss = grid_result.best_params_['loss']
In [30]:
model = Sequential()
model.add(SimpleRNN(64, return_sequences = False, input_shape=(X_train.shape[1], 1)))

model.add(Dense(units, activation='relu'))
model.add(Dense(units=3, activation='sigmoid'))

model.compile(loss=loss, optimizer=optimizer, metrics=['mean_absolute_error'])

history = model.fit(X_train, y_train, epochs=100, batch_size=32, validation_data=(X_val, y_val))

accuracy = model.evaluate(X_val, y_val)
print("Accuracy: ", accuracy)
Epoch 1/100
249/249 [==============================] - 8s 13ms/step - loss: 0.8234 - mean_absolute_error: 0.4445 - val_loss: 0.6038 - val_mean_absolute_error: 0.3896
Epoch 2/100
249/249 [==============================] - 2s 8ms/step - loss: 0.4730 - mean_absolute_error: 0.3500 - val_loss: 0.3911 - val_mean_absolute_error: 0.3204
Epoch 3/100
249/249 [==============================] - 2s 9ms/step - loss: 0.3228 - mean_absolute_error: 0.2984 - val_loss: 0.2747 - val_mean_absolute_error: 0.2817
Epoch 4/100
249/249 [==============================] - 2s 9ms/step - loss: 0.2502 - mean_absolute_error: 0.2735 - val_loss: 0.2468 - val_mean_absolute_error: 0.2685
Epoch 5/100
249/249 [==============================] - 2s 8ms/step - loss: 0.2096 - mean_absolute_error: 0.2594 - val_loss: 0.1874 - val_mean_absolute_error: 0.2532
Epoch 6/100
249/249 [==============================] - 2s 7ms/step - loss: 0.1821 - mean_absolute_error: 0.2506 - val_loss: 0.1691 - val_mean_absolute_error: 0.2457
Epoch 7/100
249/249 [==============================] - 2s 7ms/step - loss: 0.1657 - mean_absolute_error: 0.2446 - val_loss: 0.1516 - val_mean_absolute_error: 0.2414
Epoch 8/100
249/249 [==============================] - 2s 7ms/step - loss: 0.1547 - mean_absolute_error: 0.2398 - val_loss: 0.1535 - val_mean_absolute_error: 0.2366
Epoch 9/100
249/249 [==============================] - 2s 8ms/step - loss: 0.1424 - mean_absolute_error: 0.2355 - val_loss: 0.1291 - val_mean_absolute_error: 0.2335
Epoch 10/100
249/249 [==============================] - 2s 7ms/step - loss: 0.1306 - mean_absolute_error: 0.2314 - val_loss: 0.1478 - val_mean_absolute_error: 0.2325
Epoch 11/100
249/249 [==============================] - 2s 7ms/step - loss: 0.1221 - mean_absolute_error: 0.2284 - val_loss: 0.1193 - val_mean_absolute_error: 0.2266
Epoch 12/100
249/249 [==============================] - 2s 7ms/step - loss: 0.1171 - mean_absolute_error: 0.2256 - val_loss: 0.1140 - val_mean_absolute_error: 0.2203
Epoch 13/100
249/249 [==============================] - 2s 7ms/step - loss: 0.1108 - mean_absolute_error: 0.2229 - val_loss: 0.1263 - val_mean_absolute_error: 0.2237
Epoch 14/100
249/249 [==============================] - 2s 7ms/step - loss: 0.1054 - mean_absolute_error: 0.2202 - val_loss: 0.1068 - val_mean_absolute_error: 0.2177
Epoch 15/100
249/249 [==============================] - 2s 7ms/step - loss: 0.1029 - mean_absolute_error: 0.2188 - val_loss: 0.1485 - val_mean_absolute_error: 0.2263
Epoch 16/100
249/249 [==============================] - 2s 7ms/step - loss: 0.1009 - mean_absolute_error: 0.2172 - val_loss: 0.1079 - val_mean_absolute_error: 0.2196
Epoch 17/100
249/249 [==============================] - 2s 7ms/step - loss: 0.0946 - mean_absolute_error: 0.2157 - val_loss: 0.0888 - val_mean_absolute_error: 0.2133
Epoch 18/100
249/249 [==============================] - 2s 7ms/step - loss: 0.0926 - mean_absolute_error: 0.2135 - val_loss: 0.0973 - val_mean_absolute_error: 0.2152
Epoch 19/100
249/249 [==============================] - 2s 7ms/step - loss: 0.0924 - mean_absolute_error: 0.2117 - val_loss: 0.0989 - val_mean_absolute_error: 0.2107
Epoch 20/100
249/249 [==============================] - 2s 7ms/step - loss: 0.0890 - mean_absolute_error: 0.2110 - val_loss: 0.0791 - val_mean_absolute_error: 0.2065
Epoch 21/100
249/249 [==============================] - 2s 7ms/step - loss: 0.0862 - mean_absolute_error: 0.2086 - val_loss: 0.0744 - val_mean_absolute_error: 0.2078
Epoch 22/100
249/249 [==============================] - 2s 7ms/step - loss: 0.0836 - mean_absolute_error: 0.2076 - val_loss: 0.0761 - val_mean_absolute_error: 0.2068
Epoch 23/100
249/249 [==============================] - 2s 7ms/step - loss: 0.0875 - mean_absolute_error: 0.2069 - val_loss: 0.0816 - val_mean_absolute_error: 0.2056
Epoch 24/100
249/249 [==============================] - 2s 7ms/step - loss: 0.0793 - mean_absolute_error: 0.2054 - val_loss: 0.0677 - val_mean_absolute_error: 0.2031
Epoch 25/100
249/249 [==============================] - 2s 7ms/step - loss: 0.0794 - mean_absolute_error: 0.2046 - val_loss: 0.0665 - val_mean_absolute_error: 0.2013
Epoch 26/100
249/249 [==============================] - 2s 7ms/step - loss: 0.0789 - mean_absolute_error: 0.2028 - val_loss: 0.0650 - val_mean_absolute_error: 0.1977
Epoch 27/100
249/249 [==============================] - 2s 7ms/step - loss: 0.0724 - mean_absolute_error: 0.2010 - val_loss: 0.0748 - val_mean_absolute_error: 0.1948
Epoch 28/100
249/249 [==============================] - 2s 8ms/step - loss: 0.0711 - mean_absolute_error: 0.2000 - val_loss: 0.0790 - val_mean_absolute_error: 0.1988
Epoch 29/100
249/249 [==============================] - 2s 7ms/step - loss: 0.0792 - mean_absolute_error: 0.1997 - val_loss: 0.0685 - val_mean_absolute_error: 0.1943
Epoch 30/100
249/249 [==============================] - 2s 7ms/step - loss: 0.0711 - mean_absolute_error: 0.1976 - val_loss: 0.0759 - val_mean_absolute_error: 0.1917
Epoch 31/100
249/249 [==============================] - 2s 7ms/step - loss: 0.0711 - mean_absolute_error: 0.1972 - val_loss: 0.0668 - val_mean_absolute_error: 0.1936
Epoch 32/100
249/249 [==============================] - 2s 7ms/step - loss: 0.0691 - mean_absolute_error: 0.1963 - val_loss: 0.0653 - val_mean_absolute_error: 0.1902
Epoch 33/100
249/249 [==============================] - 2s 7ms/step - loss: 0.0708 - mean_absolute_error: 0.1954 - val_loss: 0.1067 - val_mean_absolute_error: 0.1988
Epoch 34/100
249/249 [==============================] - 2s 7ms/step - loss: 0.0656 - mean_absolute_error: 0.1950 - val_loss: 0.0621 - val_mean_absolute_error: 0.1884
Epoch 35/100
249/249 [==============================] - 2s 7ms/step - loss: 0.0650 - mean_absolute_error: 0.1931 - val_loss: 0.0613 - val_mean_absolute_error: 0.1862
Epoch 36/100
249/249 [==============================] - 2s 7ms/step - loss: 0.0656 - mean_absolute_error: 0.1924 - val_loss: 0.0801 - val_mean_absolute_error: 0.1921
Epoch 37/100
249/249 [==============================] - 2s 8ms/step - loss: 0.0591 - mean_absolute_error: 0.1910 - val_loss: 0.0658 - val_mean_absolute_error: 0.1883
Epoch 38/100
249/249 [==============================] - 2s 7ms/step - loss: 0.0649 - mean_absolute_error: 0.1904 - val_loss: 0.0996 - val_mean_absolute_error: 0.1835
Epoch 39/100
249/249 [==============================] - 2s 7ms/step - loss: 0.0622 - mean_absolute_error: 0.1900 - val_loss: 0.0553 - val_mean_absolute_error: 0.1863
Epoch 40/100
249/249 [==============================] - 2s 7ms/step - loss: 0.0601 - mean_absolute_error: 0.1895 - val_loss: 0.0523 - val_mean_absolute_error: 0.1855
Epoch 41/100
249/249 [==============================] - 2s 7ms/step - loss: 0.0625 - mean_absolute_error: 0.1882 - val_loss: 0.0825 - val_mean_absolute_error: 0.1869
Epoch 42/100
249/249 [==============================] - 2s 7ms/step - loss: 0.0565 - mean_absolute_error: 0.1878 - val_loss: 0.0621 - val_mean_absolute_error: 0.1838
Epoch 43/100
249/249 [==============================] - 2s 7ms/step - loss: 0.0601 - mean_absolute_error: 0.1860 - val_loss: 0.0701 - val_mean_absolute_error: 0.1806
Epoch 44/100
249/249 [==============================] - 2s 7ms/step - loss: 0.0612 - mean_absolute_error: 0.1856 - val_loss: 0.0566 - val_mean_absolute_error: 0.1824
Epoch 45/100
249/249 [==============================] - 2s 7ms/step - loss: 0.0564 - mean_absolute_error: 0.1849 - val_loss: 0.0698 - val_mean_absolute_error: 0.1864
Epoch 46/100
249/249 [==============================] - 2s 7ms/step - loss: 0.0557 - mean_absolute_error: 0.1839 - val_loss: 0.0577 - val_mean_absolute_error: 0.1783
Epoch 47/100
249/249 [==============================] - 2s 7ms/step - loss: 0.0561 - mean_absolute_error: 0.1833 - val_loss: 0.0533 - val_mean_absolute_error: 0.1818
Epoch 48/100
249/249 [==============================] - 2s 7ms/step - loss: 0.0529 - mean_absolute_error: 0.1836 - val_loss: 0.0589 - val_mean_absolute_error: 0.1800
Epoch 49/100
249/249 [==============================] - 2s 8ms/step - loss: 0.0586 - mean_absolute_error: 0.1826 - val_loss: 0.0942 - val_mean_absolute_error: 0.1777
Epoch 50/100
249/249 [==============================] - 2s 7ms/step - loss: 0.0615 - mean_absolute_error: 0.1817 - val_loss: 0.0574 - val_mean_absolute_error: 0.1771
Epoch 51/100
249/249 [==============================] - 2s 7ms/step - loss: 0.0516 - mean_absolute_error: 0.1812 - val_loss: 0.0584 - val_mean_absolute_error: 0.1813
Epoch 52/100
249/249 [==============================] - 2s 7ms/step - loss: 0.0526 - mean_absolute_error: 0.1805 - val_loss: 0.1105 - val_mean_absolute_error: 0.1847
Epoch 53/100
249/249 [==============================] - 2s 7ms/step - loss: 0.0605 - mean_absolute_error: 0.1799 - val_loss: 0.0443 - val_mean_absolute_error: 0.1781
Epoch 54/100
249/249 [==============================] - 2s 7ms/step - loss: 0.0586 - mean_absolute_error: 0.1797 - val_loss: 0.0744 - val_mean_absolute_error: 0.1791
Epoch 55/100
249/249 [==============================] - 2s 8ms/step - loss: 0.0544 - mean_absolute_error: 0.1792 - val_loss: 0.0840 - val_mean_absolute_error: 0.1792
Epoch 56/100
249/249 [==============================] - 2s 7ms/step - loss: 0.0615 - mean_absolute_error: 0.1781 - val_loss: 0.0782 - val_mean_absolute_error: 0.1775
Epoch 57/100
249/249 [==============================] - 2s 7ms/step - loss: 0.0521 - mean_absolute_error: 0.1782 - val_loss: 0.0452 - val_mean_absolute_error: 0.1783
Epoch 58/100
249/249 [==============================] - 2s 8ms/step - loss: 0.0474 - mean_absolute_error: 0.1776 - val_loss: 0.0634 - val_mean_absolute_error: 0.1773
Epoch 59/100
249/249 [==============================] - 2s 7ms/step - loss: 0.0496 - mean_absolute_error: 0.1771 - val_loss: 0.0426 - val_mean_absolute_error: 0.1762
Epoch 60/100
249/249 [==============================] - 2s 7ms/step - loss: 0.0508 - mean_absolute_error: 0.1770 - val_loss: 0.0844 - val_mean_absolute_error: 0.1811
Epoch 61/100
249/249 [==============================] - 2s 7ms/step - loss: 0.0516 - mean_absolute_error: 0.1755 - val_loss: 0.0481 - val_mean_absolute_error: 0.1698
Epoch 62/100
249/249 [==============================] - 2s 7ms/step - loss: 0.0497 - mean_absolute_error: 0.1753 - val_loss: 0.0468 - val_mean_absolute_error: 0.1692
Epoch 63/100
249/249 [==============================] - 2s 6ms/step - loss: 0.0498 - mean_absolute_error: 0.1746 - val_loss: 0.0526 - val_mean_absolute_error: 0.1754
Epoch 64/100
249/249 [==============================] - 1s 6ms/step - loss: 0.0536 - mean_absolute_error: 0.1741 - val_loss: 0.0490 - val_mean_absolute_error: 0.1690
Epoch 65/100
249/249 [==============================] - 1s 6ms/step - loss: 0.0523 - mean_absolute_error: 0.1735 - val_loss: 0.0579 - val_mean_absolute_error: 0.1747
Epoch 66/100
249/249 [==============================] - 1s 6ms/step - loss: 0.0524 - mean_absolute_error: 0.1732 - val_loss: 0.0490 - val_mean_absolute_error: 0.1710
Epoch 67/100
249/249 [==============================] - 1s 6ms/step - loss: 0.0605 - mean_absolute_error: 0.1737 - val_loss: 0.0442 - val_mean_absolute_error: 0.1678
Epoch 68/100
249/249 [==============================] - 2s 7ms/step - loss: 0.0505 - mean_absolute_error: 0.1726 - val_loss: 0.0580 - val_mean_absolute_error: 0.1681
Epoch 69/100
249/249 [==============================] - 2s 8ms/step - loss: 0.0461 - mean_absolute_error: 0.1715 - val_loss: 0.0486 - val_mean_absolute_error: 0.1704
Epoch 70/100
249/249 [==============================] - 2s 8ms/step - loss: 0.0451 - mean_absolute_error: 0.1718 - val_loss: 0.0567 - val_mean_absolute_error: 0.1740
Epoch 71/100
249/249 [==============================] - 2s 8ms/step - loss: 0.0600 - mean_absolute_error: 0.1720 - val_loss: 0.0725 - val_mean_absolute_error: 0.1740
Epoch 72/100
249/249 [==============================] - 2s 7ms/step - loss: 0.0429 - mean_absolute_error: 0.1703 - val_loss: 0.0511 - val_mean_absolute_error: 0.1669
Epoch 73/100
249/249 [==============================] - 2s 10ms/step - loss: 0.0521 - mean_absolute_error: 0.1697 - val_loss: 0.0522 - val_mean_absolute_error: 0.1637
Epoch 74/100
249/249 [==============================] - 3s 10ms/step - loss: 0.0465 - mean_absolute_error: 0.1696 - val_loss: 0.0598 - val_mean_absolute_error: 0.1681
Epoch 75/100
249/249 [==============================] - 2s 9ms/step - loss: 0.0455 - mean_absolute_error: 0.1690 - val_loss: 0.0921 - val_mean_absolute_error: 0.1700
Epoch 76/100
249/249 [==============================] - 2s 8ms/step - loss: 0.0446 - mean_absolute_error: 0.1687 - val_loss: 0.0460 - val_mean_absolute_error: 0.1699
Epoch 77/100
249/249 [==============================] - 2s 7ms/step - loss: 0.0436 - mean_absolute_error: 0.1686 - val_loss: 0.0389 - val_mean_absolute_error: 0.1660
Epoch 78/100
249/249 [==============================] - 2s 9ms/step - loss: 0.0464 - mean_absolute_error: 0.1685 - val_loss: 0.0448 - val_mean_absolute_error: 0.1709
Epoch 79/100
249/249 [==============================] - 2s 9ms/step - loss: 0.0461 - mean_absolute_error: 0.1686 - val_loss: 0.0668 - val_mean_absolute_error: 0.1704
Epoch 80/100
249/249 [==============================] - 2s 8ms/step - loss: 0.0444 - mean_absolute_error: 0.1675 - val_loss: 0.0921 - val_mean_absolute_error: 0.1681
Epoch 81/100
249/249 [==============================] - 2s 7ms/step - loss: 0.0426 - mean_absolute_error: 0.1675 - val_loss: 0.0489 - val_mean_absolute_error: 0.1662
Epoch 82/100
249/249 [==============================] - 2s 6ms/step - loss: 0.0378 - mean_absolute_error: 0.1680 - val_loss: 0.0384 - val_mean_absolute_error: 0.1663
Epoch 83/100
249/249 [==============================] - 2s 8ms/step - loss: 0.0413 - mean_absolute_error: 0.1668 - val_loss: 0.0398 - val_mean_absolute_error: 0.1618
Epoch 84/100
249/249 [==============================] - 2s 10ms/step - loss: 0.0376 - mean_absolute_error: 0.1666 - val_loss: 0.0480 - val_mean_absolute_error: 0.1665
Epoch 85/100
249/249 [==============================] - 2s 7ms/step - loss: 0.0432 - mean_absolute_error: 0.1667 - val_loss: 0.1177 - val_mean_absolute_error: 0.1659
Epoch 86/100
249/249 [==============================] - 2s 7ms/step - loss: 0.0405 - mean_absolute_error: 0.1662 - val_loss: 0.0574 - val_mean_absolute_error: 0.1644
Epoch 87/100
249/249 [==============================] - 2s 7ms/step - loss: 0.0414 - mean_absolute_error: 0.1648 - val_loss: 0.0460 - val_mean_absolute_error: 0.1613
Epoch 88/100
249/249 [==============================] - 2s 8ms/step - loss: 0.0380 - mean_absolute_error: 0.1645 - val_loss: 0.0375 - val_mean_absolute_error: 0.1619
Epoch 89/100
249/249 [==============================] - 2s 7ms/step - loss: 0.0424 - mean_absolute_error: 0.1643 - val_loss: 0.0622 - val_mean_absolute_error: 0.1640
Epoch 90/100
249/249 [==============================] - 2s 7ms/step - loss: 0.0367 - mean_absolute_error: 0.1634 - val_loss: 0.0475 - val_mean_absolute_error: 0.1599
Epoch 91/100
249/249 [==============================] - 1s 6ms/step - loss: 0.0378 - mean_absolute_error: 0.1644 - val_loss: 0.0351 - val_mean_absolute_error: 0.1600
Epoch 92/100
249/249 [==============================] - 2s 6ms/step - loss: 0.0438 - mean_absolute_error: 0.1635 - val_loss: 0.0621 - val_mean_absolute_error: 0.1621
Epoch 93/100
249/249 [==============================] - 2s 6ms/step - loss: 0.0438 - mean_absolute_error: 0.1626 - val_loss: 0.0354 - val_mean_absolute_error: 0.1631
Epoch 94/100
249/249 [==============================] - 2s 7ms/step - loss: 0.0377 - mean_absolute_error: 0.1626 - val_loss: 0.0431 - val_mean_absolute_error: 0.1642
Epoch 95/100
249/249 [==============================] - 2s 7ms/step - loss: 0.0474 - mean_absolute_error: 0.1631 - val_loss: 0.0546 - val_mean_absolute_error: 0.1582
Epoch 96/100
249/249 [==============================] - 2s 7ms/step - loss: 0.0438 - mean_absolute_error: 0.1618 - val_loss: 0.0354 - val_mean_absolute_error: 0.1603
Epoch 97/100
249/249 [==============================] - 2s 7ms/step - loss: 0.0463 - mean_absolute_error: 0.1628 - val_loss: 0.0440 - val_mean_absolute_error: 0.1585
Epoch 98/100
249/249 [==============================] - 1s 6ms/step - loss: 0.0391 - mean_absolute_error: 0.1616 - val_loss: 0.0471 - val_mean_absolute_error: 0.1605
Epoch 99/100
249/249 [==============================] - 1s 6ms/step - loss: 0.0400 - mean_absolute_error: 0.1611 - val_loss: 0.0448 - val_mean_absolute_error: 0.1568
Epoch 100/100
249/249 [==============================] - 2s 6ms/step - loss: 0.0359 - mean_absolute_error: 0.1617 - val_loss: 0.0400 - val_mean_absolute_error: 0.1567
32/32 [==============================] - 0s 4ms/step - loss: 0.0400 - mean_absolute_error: 0.1567
Accuracy:  [0.03999854996800423, 0.15673400461673737]
In [31]:
model.summary()
Model: "sequential_73"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 simple_rnn_73 (SimpleRNN)   (None, 64)                4224      
                                                                 
 dense_146 (Dense)           (None, 32)                2080      
                                                                 
 dense_147 (Dense)           (None, 3)                 99        
                                                                 
=================================================================
Total params: 6403 (25.01 KB)
Trainable params: 6403 (25.01 KB)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________
In [37]:
import joblib
joblib.dump(model, 'model/rnn_model.pkl')
INFO:tensorflow:Assets written to: C:\Users\Asus\AppData\Local\Temp\tmpj5ooq9gc\assets
INFO:tensorflow:Assets written to: C:\Users\Asus\AppData\Local\Temp\tmpj5ooq9gc\assets
Out[37]:
['model/rnn_model.pkl']
In [38]:
plot_model(model, to_file='mlp-mnist.png', show_shapes=True)
You must install pydot (`pip install pydot`) and install graphviz (see instructions at https://graphviz.gitlab.io/download/) for plot_model to work.
In [39]:
predictions = model.predict(X_test)
y_pred = np.argmax(predictions, axis=1)
11/32 [=========>....................] - ETA: 0s32/32 [==============================] - 0s 5ms/step
In [40]:
plt.figure(figsize=(12, 5))
plt.subplot(1, 2, 1)
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.title('Loss over epochs')
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.legend(['Train', 'Validation'], loc='best')

plt.subplot(1, 2, 2)
plt.plot(history.history['mean_absolute_error'])
plt.plot(history.history['val_mean_absolute_error'])
plt.title('Accuracy over epochs')
plt.ylabel('Accuracy')
plt.xlabel('Epoch')
plt.legend(['Train', 'Validation'], loc='best')

plt.tight_layout()
plt.show()
In [41]:
cm = confusion_matrix(y_test, y_pred)
disp = ConfusionMatrixDisplay(confusion_matrix=cm)
disp.plot(cmap='Blues', values_format='d')
plt.title("Confusion Matrix: RNN")
plt.show()

print(classification_report(y_test, y_pred))
              precision    recall  f1-score   support

           0       0.99      0.99      0.99       327
           1       0.98      0.96      0.97       352
           2       0.97      0.99      0.98       317

    accuracy                           0.98       996
   macro avg       0.98      0.98      0.98       996
weighted avg       0.98      0.98      0.98       996