Libraries and Dataset¶

In [1]:
import numpy as np
import pandas as pd
import seaborn  as sns
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder, OneHotEncoder, StandardScaler
from sklearn.model_selection import cross_val_score, GridSearchCV
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix,classification_report, ConfusionMatrixDisplay
In [2]:
dataset = pd.read_csv('data/market_cluster.csv', encoding='latin1')
In [3]:
dataset.head()
Out[3]:
Order ID Customer Name Category Sub Category City Order Date Region Sales Discount Profit State profit_margin Cluster
0 OD1 Harish Oil & Masala Masalas Vellore 11-08-2017 North 1254 0.12 401.28 Tamil Nadu 0.32 Medium
1 OD2 Sudha Beverages Health Drinks Krishnagiri 11-08-2017 South 749 0.18 149.80 Tamil Nadu 0.20 Medium
2 OD3 Hussain Food Grains Atta & Flour Perambalur 06-12-2017 West 2360 0.21 165.20 Tamil Nadu 0.07 Low
3 OD4 Jackson Fruits & Veggies Fresh Vegetables Dharmapuri 10-11-2016 South 896 0.25 89.60 Tamil Nadu 0.10 Low
4 OD5 Ridhesh Food Grains Organic Staples Ooty 10-11-2016 South 2355 0.26 918.45 Tamil Nadu 0.39 High
In [4]:
dataset.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 9994 entries, 0 to 9993
Data columns (total 13 columns):
 #   Column         Non-Null Count  Dtype  
---  ------         --------------  -----  
 0   Order ID       9994 non-null   object 
 1   Customer Name  9994 non-null   object 
 2   Category       9994 non-null   object 
 3   Sub Category   9994 non-null   object 
 4   City           9994 non-null   object 
 5   Order Date     9994 non-null   object 
 6   Region         9994 non-null   object 
 7   Sales          9994 non-null   int64  
 8   Discount       9994 non-null   float64
 9   Profit         9994 non-null   float64
 10  State          9994 non-null   object 
 11  profit_margin  9994 non-null   float64
 12  Cluster        9994 non-null   object 
dtypes: float64(3), int64(1), object(9)
memory usage: 1015.1+ KB
In [5]:
dataset.describe()
Out[5]:
Sales Discount Profit profit_margin
count 9994.000000 9994.000000 9994.000000 9994.000000
mean 1496.596158 0.226817 374.937082 0.250228
std 577.559036 0.074636 239.932881 0.118919
min 500.000000 0.100000 25.250000 0.050000
25% 1000.000000 0.160000 180.022500 0.150000
50% 1498.000000 0.230000 320.780000 0.250000
75% 1994.750000 0.290000 525.627500 0.350000
max 2500.000000 0.350000 1120.950000 0.450000
In [6]:
dataset.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 9994 entries, 0 to 9993
Data columns (total 13 columns):
 #   Column         Non-Null Count  Dtype  
---  ------         --------------  -----  
 0   Order ID       9994 non-null   object 
 1   Customer Name  9994 non-null   object 
 2   Category       9994 non-null   object 
 3   Sub Category   9994 non-null   object 
 4   City           9994 non-null   object 
 5   Order Date     9994 non-null   object 
 6   Region         9994 non-null   object 
 7   Sales          9994 non-null   int64  
 8   Discount       9994 non-null   float64
 9   Profit         9994 non-null   float64
 10  State          9994 non-null   object 
 11  profit_margin  9994 non-null   float64
 12  Cluster        9994 non-null   object 
dtypes: float64(3), int64(1), object(9)
memory usage: 1015.1+ KB

Data Cleaning & Preprocessing¶

In [7]:
dataset.drop(['Order ID'], axis=1, inplace=True)
In [8]:
dataset.isna().sum()
Out[8]:
Customer Name    0
Category         0
Sub Category     0
City             0
Order Date       0
Region           0
Sales            0
Discount         0
Profit           0
State            0
profit_margin    0
Cluster          0
dtype: int64
In [9]:
dataset.dropna(inplace=True)
In [10]:
def remove_outliers(data: pd.DataFrame, column: str) -> pd.Series:
    q3, q1 = np.nanpercentile(data[column], [75, 25])
    iqr = q3 - q1
    upper_bound = q3 + 1.5 * iqr
    lower_bound = q1 - 1.5 * iqr
    data = data[(data[column] > lower_bound) & (data[column] < upper_bound)]

    return data

dataset = remove_outliers(dataset, 'Discount')
dataset = remove_outliers(dataset, 'Sales')
dataset = remove_outliers(dataset, 'Profit')
In [11]:
dataset.head()
Out[11]:
Customer Name Category Sub Category City Order Date Region Sales Discount Profit State profit_margin Cluster
0 Harish Oil & Masala Masalas Vellore 11-08-2017 North 1254 0.12 401.28 Tamil Nadu 0.32 Medium
1 Sudha Beverages Health Drinks Krishnagiri 11-08-2017 South 749 0.18 149.80 Tamil Nadu 0.20 Medium
2 Hussain Food Grains Atta & Flour Perambalur 06-12-2017 West 2360 0.21 165.20 Tamil Nadu 0.07 Low
3 Jackson Fruits & Veggies Fresh Vegetables Dharmapuri 10-11-2016 South 896 0.25 89.60 Tamil Nadu 0.10 Low
4 Ridhesh Food Grains Organic Staples Ooty 10-11-2016 South 2355 0.26 918.45 Tamil Nadu 0.39 High
In [12]:
sns.histplot(dataset['Cluster'])
Out[12]:
<AxesSubplot:xlabel='Cluster', ylabel='Count'>
In [13]:
encoder = LabelEncoder()
scaler = StandardScaler()
onehot = OneHotEncoder()
In [14]:
dataset["Order Date"] = pd.to_datetime(dataset["Order Date"])
dataset["Order Date"] = dataset["Order Date"].dt.month

dataset["Customer Name"] = encoder.fit_transform(dataset["Customer Name"])
dataset["Category"] = encoder.fit_transform(dataset["Category"])
dataset["City"] = encoder.fit_transform(dataset["City"])
dataset["Region"] = encoder.fit_transform(dataset["Region"])
dataset["State"] = encoder.fit_transform(dataset["State"])
dataset["Sub Category"] = encoder.fit_transform(dataset["Sub Category"])

# dataset["Customer Name"] = onehot.fit_transform(dataset["Customer Name"].values.reshape(-1, 1)).toarray()
# dataset["Category"] = onehot.fit_transform(dataset["Category"].values.reshape(-1, 1)).toarray()
# dataset["City"] = onehot.fit_transform(dataset["City"].values.reshape(-1, 1)).toarray()
# dataset["Region"] = onehot.fit_transform(dataset["Region"].values.reshape(-1, 1)).toarray()
# dataset["State"] = onehot.fit_transform(dataset["State"].values.reshape(-1, 1)).toarray()
# dataset["Sub Category"] = onehot.fit_transform(dataset["Sub Category"].values.reshape(-1, 1)).toarray()

dataset["Order Date"] = encoder.fit_transform(dataset["Order Date"])
In [15]:
dataset[["Sales", "Discount", "profit_margin","Profit"]] = scaler.fit_transform(dataset[["Sales", "Discount", "profit_margin","Profit"]])
In [16]:
class_to_numeric = {'Low': 0, 'Medium': 1, 'High': 2}
dataset['Cluster'] = [class_to_numeric[label] for label in dataset['Cluster']]
In [17]:
dataset.head()
Out[17]:
Customer Name Category Sub Category City Order Date Region Sales Discount Profit State profit_margin Cluster
0 12 5 14 21 10 2 -0.414559 -1.430908 0.124389 0 0.595874 1
1 37 1 13 8 10 3 -1.291968 -0.627370 -0.941183 0 -0.416872 1
2 14 3 0 13 5 4 1.507054 -0.225601 -0.875930 0 -1.514014 0
3 15 4 12 4 9 3 -1.036563 0.310092 -1.196262 0 -1.260827 0
4 28 3 18 12 9 3 1.498367 0.444015 2.315743 0 1.186643 2

Split Data¶

In [18]:
X = dataset.drop(['Cluster','Sub Category','State','profit_margin'],axis=1)
y = dataset['Cluster']
In [19]:
X.head()
Out[19]:
Customer Name Category City Order Date Region Sales Discount Profit
0 12 5 21 10 2 -0.414559 -1.430908 0.124389
1 37 1 8 10 3 -1.291968 -0.627370 -0.941183
2 14 3 13 5 4 1.507054 -0.225601 -0.875930
3 15 4 4 9 3 -1.036563 0.310092 -1.196262
4 28 3 12 9 3 1.498367 0.444015 2.315743
In [20]:
X_train, X_temp, y_train, y_temp = train_test_split(X, y, test_size=0.2, random_state=42)
X_val, X_test, y_val, y_test = train_test_split(X_temp, y_temp, test_size=0.5, random_state=42)
In [21]:
heatcol = X.corr()
sns.heatmap(heatcol,cmap="BrBG",annot=True)
Out[21]:
<AxesSubplot:>
In [22]:
print("Dimension of Train set",X_train.shape)
print("Dimension of Val set",X_val.shape)
print("Dimension of Test set",X_test.shape,"\n")

num_cols = X_train._get_numeric_data().columns
print("Number of numeric features:",num_cols.size)
Dimension of Train set (7960, 8)
Dimension of Val set (995, 8)
Dimension of Test set (996, 8) 

Number of numeric features: 8

LSTM¶

In [23]:
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM, Dense
from tensorflow.keras.utils import to_categorical
from scikeras.wrappers import KerasClassifier
from sklearn.model_selection import GridSearchCV

y_train = to_categorical(y_train, num_classes=3)
y_val = to_categorical(y_val, num_classes=3)
In [24]:
def create_model(units=64,optimizer='adam',loss='categorical_crossentropy'):
    model = Sequential()
    model.add(LSTM(units, input_shape=(X_train.shape[1], 1)))

    model.add(Dense(units, activation='relu'))
    model.add(Dense(units=3, activation='sigmoid'))
    model.compile(loss=loss, optimizer=optimizer, metrics=['categorical_accuracy'])
    return model
In [25]:
model = KerasClassifier(build_fn=create_model, units=32, epochs=100, batch_size=32, verbose=0)

param_grid = {
    'optimizer': ['adam', 'sgd', 'rmsprop'],
    'units': [32, 64, 128],
    'loss' : ['categorical_crossentropy','binary_crossentropy','hinge','kullback_leibler_divergence']
}
grid = GridSearchCV(estimator=model, param_grid=param_grid, cv=2, scoring='accuracy')
grid_result = grid.fit(X_train, y_train)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\sklearn\model_selection\_validation.py:425: FitFailedWarning: 
54 fits failed out of a total of 72.
The score on these train-test partitions for these parameters will be set to nan.
If these failures are not expected, you can try to debug them by setting error_score='raise'.

Below are more details about the failures:
--------------------------------------------------------------------------------
18 fits failed with the following error:
Traceback (most recent call last):
  File "c:\Users\Asus\anaconda3\lib\site-packages\sklearn\model_selection\_validation.py", line 729, in _fit_and_score
    estimator.fit(X_train, y_train, **fit_params)
  File "c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py", line 1491, in fit
    super().fit(X=X, y=y, sample_weight=sample_weight, **kwargs)
  File "c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py", line 760, in fit
    self._fit(
  File "c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py", line 926, in _fit
    self._check_model_compatibility(y)
  File "c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py", line 569, in _check_model_compatibility
    raise ValueError(
ValueError: loss=binary_crossentropy but model compiled with categorical_crossentropy. Data may not match loss function!

--------------------------------------------------------------------------------
18 fits failed with the following error:
Traceback (most recent call last):
  File "c:\Users\Asus\anaconda3\lib\site-packages\sklearn\model_selection\_validation.py", line 729, in _fit_and_score
    estimator.fit(X_train, y_train, **fit_params)
  File "c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py", line 1491, in fit
    super().fit(X=X, y=y, sample_weight=sample_weight, **kwargs)
  File "c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py", line 760, in fit
    self._fit(
  File "c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py", line 926, in _fit
    self._check_model_compatibility(y)
  File "c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py", line 569, in _check_model_compatibility
    raise ValueError(
ValueError: loss=hinge but model compiled with categorical_crossentropy. Data may not match loss function!

--------------------------------------------------------------------------------
18 fits failed with the following error:
Traceback (most recent call last):
  File "c:\Users\Asus\anaconda3\lib\site-packages\sklearn\model_selection\_validation.py", line 729, in _fit_and_score
    estimator.fit(X_train, y_train, **fit_params)
  File "c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py", line 1491, in fit
    super().fit(X=X, y=y, sample_weight=sample_weight, **kwargs)
  File "c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py", line 760, in fit
    self._fit(
  File "c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py", line 926, in _fit
    self._check_model_compatibility(y)
  File "c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py", line 569, in _check_model_compatibility
    raise ValueError(
ValueError: loss=kullback_leibler_divergence but model compiled with categorical_crossentropy. Data may not match loss function!

  warnings.warn(some_fits_failed_message, FitFailedWarning)
c:\Users\Asus\anaconda3\lib\site-packages\sklearn\model_selection\_search.py:979: UserWarning: One or more of the test scores are non-finite: [0.98856784 0.9781407  0.97512563 0.98153266 0.97022613 0.97248744
 0.98592965 0.98090452 0.97374372        nan        nan        nan
        nan        nan        nan        nan        nan        nan
        nan        nan        nan        nan        nan        nan
        nan        nan        nan        nan        nan        nan
        nan        nan        nan        nan        nan        nan]
  warnings.warn(
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
  X, y = self._initialize(X, y)
In [26]:
print("Best: %f using %s" % (grid_result.best_score_, grid_result.best_params_))

# best_model = grid_result.best_estimator_
# accuracy = best_model.score(X_val, y_val)
# print("Validation Accuracy: %.2f%%" % (accuracy * 100))
Best: 0.988568 using {'loss': 'categorical_crossentropy', 'optimizer': 'adam', 'units': 32}
In [27]:
optimizer = grid_result.best_params_['optimizer']
units = grid_result.best_params_['units']
loss = grid_result.best_params_['loss']
In [28]:
model = Sequential()
model.add(LSTM(units, return_sequences = False, input_shape=(X_train.shape[1], 1)))

model.add(Dense(units, activation='relu'))
model.add(Dense(units=3, activation='sigmoid'))

model.compile(loss=loss, optimizer=optimizer, metrics=['mean_absolute_error'])

history = model.fit(X_train, y_train, epochs=100, batch_size=32, validation_data=(X_val, y_val))

accuracy = model.evaluate(X_val, y_val)
print("Accuracy: ", accuracy)
Epoch 1/100
249/249 [==============================] - 5s 10ms/step - loss: 0.8214 - mean_absolute_error: 0.4385 - val_loss: 0.4796 - val_mean_absolute_error: 0.3500
Epoch 2/100
249/249 [==============================] - 2s 7ms/step - loss: 0.3203 - mean_absolute_error: 0.3053 - val_loss: 0.2304 - val_mean_absolute_error: 0.2844
Epoch 3/100
249/249 [==============================] - 2s 7ms/step - loss: 0.1933 - mean_absolute_error: 0.2768 - val_loss: 0.1760 - val_mean_absolute_error: 0.2717
Epoch 4/100
249/249 [==============================] - 2s 7ms/step - loss: 0.1469 - mean_absolute_error: 0.2691 - val_loss: 0.1905 - val_mean_absolute_error: 0.2692
Epoch 5/100
249/249 [==============================] - 2s 8ms/step - loss: 0.1205 - mean_absolute_error: 0.2635 - val_loss: 0.1255 - val_mean_absolute_error: 0.2611
Epoch 6/100
249/249 [==============================] - 2s 6ms/step - loss: 0.1026 - mean_absolute_error: 0.2611 - val_loss: 0.0962 - val_mean_absolute_error: 0.2578
Epoch 7/100
249/249 [==============================] - 2s 7ms/step - loss: 0.0975 - mean_absolute_error: 0.2588 - val_loss: 0.0865 - val_mean_absolute_error: 0.2549
Epoch 8/100
249/249 [==============================] - 2s 6ms/step - loss: 0.0809 - mean_absolute_error: 0.2571 - val_loss: 0.0827 - val_mean_absolute_error: 0.2546
Epoch 9/100
249/249 [==============================] - 2s 6ms/step - loss: 0.0881 - mean_absolute_error: 0.2562 - val_loss: 0.0832 - val_mean_absolute_error: 0.2525
Epoch 10/100
249/249 [==============================] - 2s 7ms/step - loss: 0.0696 - mean_absolute_error: 0.2545 - val_loss: 0.0757 - val_mean_absolute_error: 0.2496
Epoch 11/100
249/249 [==============================] - 2s 7ms/step - loss: 0.0682 - mean_absolute_error: 0.2536 - val_loss: 0.0779 - val_mean_absolute_error: 0.2530
Epoch 12/100
249/249 [==============================] - 2s 7ms/step - loss: 0.0585 - mean_absolute_error: 0.2525 - val_loss: 0.0679 - val_mean_absolute_error: 0.2502
Epoch 13/100
249/249 [==============================] - 2s 7ms/step - loss: 0.0592 - mean_absolute_error: 0.2518 - val_loss: 0.0727 - val_mean_absolute_error: 0.2493
Epoch 14/100
249/249 [==============================] - 2s 7ms/step - loss: 0.0560 - mean_absolute_error: 0.2506 - val_loss: 0.0579 - val_mean_absolute_error: 0.2476
Epoch 15/100
249/249 [==============================] - 2s 7ms/step - loss: 0.0588 - mean_absolute_error: 0.2495 - val_loss: 0.0666 - val_mean_absolute_error: 0.2486
Epoch 16/100
249/249 [==============================] - 2s 7ms/step - loss: 0.0535 - mean_absolute_error: 0.2495 - val_loss: 0.0532 - val_mean_absolute_error: 0.2443
Epoch 17/100
249/249 [==============================] - 2s 7ms/step - loss: 0.0585 - mean_absolute_error: 0.2485 - val_loss: 0.0676 - val_mean_absolute_error: 0.2469
Epoch 18/100
249/249 [==============================] - 2s 7ms/step - loss: 0.0489 - mean_absolute_error: 0.2475 - val_loss: 0.0513 - val_mean_absolute_error: 0.2443
Epoch 19/100
249/249 [==============================] - 2s 7ms/step - loss: 0.0465 - mean_absolute_error: 0.2474 - val_loss: 0.0567 - val_mean_absolute_error: 0.2424
Epoch 20/100
249/249 [==============================] - 2s 7ms/step - loss: 0.0439 - mean_absolute_error: 0.2456 - val_loss: 0.0445 - val_mean_absolute_error: 0.2426
Epoch 21/100
249/249 [==============================] - 2s 6ms/step - loss: 0.0455 - mean_absolute_error: 0.2457 - val_loss: 0.0584 - val_mean_absolute_error: 0.2430
Epoch 22/100
249/249 [==============================] - 2s 7ms/step - loss: 0.0436 - mean_absolute_error: 0.2452 - val_loss: 0.0499 - val_mean_absolute_error: 0.2418
Epoch 23/100
249/249 [==============================] - 2s 7ms/step - loss: 0.0531 - mean_absolute_error: 0.2442 - val_loss: 0.0402 - val_mean_absolute_error: 0.2393
Epoch 24/100
249/249 [==============================] - 2s 7ms/step - loss: 0.0385 - mean_absolute_error: 0.2438 - val_loss: 0.0516 - val_mean_absolute_error: 0.2395
Epoch 25/100
249/249 [==============================] - 2s 6ms/step - loss: 0.0394 - mean_absolute_error: 0.2424 - val_loss: 0.0575 - val_mean_absolute_error: 0.2393
Epoch 26/100
249/249 [==============================] - 2s 7ms/step - loss: 0.0411 - mean_absolute_error: 0.2425 - val_loss: 0.0444 - val_mean_absolute_error: 0.2400
Epoch 27/100
249/249 [==============================] - 2s 7ms/step - loss: 0.0351 - mean_absolute_error: 0.2419 - val_loss: 0.0420 - val_mean_absolute_error: 0.2366
Epoch 28/100
249/249 [==============================] - 2s 7ms/step - loss: 0.0420 - mean_absolute_error: 0.2402 - val_loss: 0.0507 - val_mean_absolute_error: 0.2382
Epoch 29/100
249/249 [==============================] - 2s 6ms/step - loss: 0.0434 - mean_absolute_error: 0.2399 - val_loss: 0.0446 - val_mean_absolute_error: 0.2379
Epoch 30/100
249/249 [==============================] - 2s 6ms/step - loss: 0.0375 - mean_absolute_error: 0.2393 - val_loss: 0.0441 - val_mean_absolute_error: 0.2338
Epoch 31/100
249/249 [==============================] - 2s 7ms/step - loss: 0.0357 - mean_absolute_error: 0.2406 - val_loss: 0.0522 - val_mean_absolute_error: 0.2386
Epoch 32/100
249/249 [==============================] - 2s 7ms/step - loss: 0.0335 - mean_absolute_error: 0.2388 - val_loss: 0.0541 - val_mean_absolute_error: 0.2342
Epoch 33/100
249/249 [==============================] - 2s 8ms/step - loss: 0.0524 - mean_absolute_error: 0.2387 - val_loss: 0.0384 - val_mean_absolute_error: 0.2361
Epoch 34/100
249/249 [==============================] - 2s 7ms/step - loss: 0.0330 - mean_absolute_error: 0.2395 - val_loss: 0.0745 - val_mean_absolute_error: 0.2381
Epoch 35/100
249/249 [==============================] - 2s 7ms/step - loss: 0.0359 - mean_absolute_error: 0.2383 - val_loss: 0.0428 - val_mean_absolute_error: 0.2313
Epoch 36/100
249/249 [==============================] - 2s 7ms/step - loss: 0.0429 - mean_absolute_error: 0.2371 - val_loss: 0.0392 - val_mean_absolute_error: 0.2336
Epoch 37/100
249/249 [==============================] - 2s 7ms/step - loss: 0.0340 - mean_absolute_error: 0.2386 - val_loss: 0.0334 - val_mean_absolute_error: 0.2321
Epoch 38/100
249/249 [==============================] - 2s 7ms/step - loss: 0.0350 - mean_absolute_error: 0.2377 - val_loss: 0.0338 - val_mean_absolute_error: 0.2342
Epoch 39/100
249/249 [==============================] - 2s 7ms/step - loss: 0.0276 - mean_absolute_error: 0.2376 - val_loss: 0.0299 - val_mean_absolute_error: 0.2318
Epoch 40/100
249/249 [==============================] - 2s 7ms/step - loss: 0.0360 - mean_absolute_error: 0.2357 - val_loss: 0.0452 - val_mean_absolute_error: 0.2338
Epoch 41/100
249/249 [==============================] - 2s 7ms/step - loss: 0.0362 - mean_absolute_error: 0.2363 - val_loss: 0.0537 - val_mean_absolute_error: 0.2348
Epoch 42/100
249/249 [==============================] - 2s 7ms/step - loss: 0.0422 - mean_absolute_error: 0.2369 - val_loss: 0.0433 - val_mean_absolute_error: 0.2347
Epoch 43/100
249/249 [==============================] - 2s 7ms/step - loss: 0.0251 - mean_absolute_error: 0.2362 - val_loss: 0.0413 - val_mean_absolute_error: 0.2318
Epoch 44/100
249/249 [==============================] - 2s 7ms/step - loss: 0.0279 - mean_absolute_error: 0.2352 - val_loss: 0.0522 - val_mean_absolute_error: 0.2334
Epoch 45/100
249/249 [==============================] - 2s 7ms/step - loss: 0.0307 - mean_absolute_error: 0.2347 - val_loss: 0.0717 - val_mean_absolute_error: 0.2341
Epoch 46/100
249/249 [==============================] - 2s 6ms/step - loss: 0.0377 - mean_absolute_error: 0.2348 - val_loss: 0.0392 - val_mean_absolute_error: 0.2332
Epoch 47/100
249/249 [==============================] - 2s 7ms/step - loss: 0.0279 - mean_absolute_error: 0.2345 - val_loss: 0.0455 - val_mean_absolute_error: 0.2318
Epoch 48/100
249/249 [==============================] - 2s 7ms/step - loss: 0.0363 - mean_absolute_error: 0.2328 - val_loss: 0.0339 - val_mean_absolute_error: 0.2297
Epoch 49/100
249/249 [==============================] - 2s 7ms/step - loss: 0.0233 - mean_absolute_error: 0.2325 - val_loss: 0.0372 - val_mean_absolute_error: 0.2295
Epoch 50/100
249/249 [==============================] - 2s 7ms/step - loss: 0.0347 - mean_absolute_error: 0.2316 - val_loss: 0.0270 - val_mean_absolute_error: 0.2274
Epoch 51/100
249/249 [==============================] - 2s 6ms/step - loss: 0.0301 - mean_absolute_error: 0.2324 - val_loss: 0.0430 - val_mean_absolute_error: 0.2271
Epoch 52/100
249/249 [==============================] - 2s 7ms/step - loss: 0.0227 - mean_absolute_error: 0.2313 - val_loss: 0.0330 - val_mean_absolute_error: 0.2262
Epoch 53/100
249/249 [==============================] - 2s 7ms/step - loss: 0.0311 - mean_absolute_error: 0.2311 - val_loss: 0.0384 - val_mean_absolute_error: 0.2257
Epoch 54/100
249/249 [==============================] - 2s 7ms/step - loss: 0.0352 - mean_absolute_error: 0.2300 - val_loss: 0.1161 - val_mean_absolute_error: 0.2291
Epoch 55/100
249/249 [==============================] - 2s 7ms/step - loss: 0.0304 - mean_absolute_error: 0.2321 - val_loss: 0.0405 - val_mean_absolute_error: 0.2234
Epoch 56/100
249/249 [==============================] - 2s 6ms/step - loss: 0.0242 - mean_absolute_error: 0.2299 - val_loss: 0.0299 - val_mean_absolute_error: 0.2261
Epoch 57/100
249/249 [==============================] - 2s 7ms/step - loss: 0.0227 - mean_absolute_error: 0.2302 - val_loss: 0.0386 - val_mean_absolute_error: 0.2276
Epoch 58/100
249/249 [==============================] - 2s 7ms/step - loss: 0.0351 - mean_absolute_error: 0.2297 - val_loss: 0.0322 - val_mean_absolute_error: 0.2256
Epoch 59/100
249/249 [==============================] - 2s 7ms/step - loss: 0.0238 - mean_absolute_error: 0.2296 - val_loss: 0.0269 - val_mean_absolute_error: 0.2253
Epoch 60/100
249/249 [==============================] - 2s 7ms/step - loss: 0.0211 - mean_absolute_error: 0.2272 - val_loss: 0.0383 - val_mean_absolute_error: 0.2238
Epoch 61/100
249/249 [==============================] - 2s 7ms/step - loss: 0.0537 - mean_absolute_error: 0.2279 - val_loss: 0.0250 - val_mean_absolute_error: 0.2276
Epoch 62/100
249/249 [==============================] - 2s 7ms/step - loss: 0.0236 - mean_absolute_error: 0.2309 - val_loss: 0.0517 - val_mean_absolute_error: 0.2265
Epoch 63/100
249/249 [==============================] - 2s 6ms/step - loss: 0.0195 - mean_absolute_error: 0.2293 - val_loss: 0.0255 - val_mean_absolute_error: 0.2214
Epoch 64/100
249/249 [==============================] - 2s 7ms/step - loss: 0.0258 - mean_absolute_error: 0.2272 - val_loss: 0.0445 - val_mean_absolute_error: 0.2200
Epoch 65/100
249/249 [==============================] - 2s 7ms/step - loss: 0.0242 - mean_absolute_error: 0.2260 - val_loss: 0.0271 - val_mean_absolute_error: 0.2212
Epoch 66/100
249/249 [==============================] - 2s 7ms/step - loss: 0.0352 - mean_absolute_error: 0.2268 - val_loss: 0.0640 - val_mean_absolute_error: 0.2225
Epoch 67/100
249/249 [==============================] - 2s 7ms/step - loss: 0.0333 - mean_absolute_error: 0.2280 - val_loss: 0.0388 - val_mean_absolute_error: 0.2244
Epoch 68/100
249/249 [==============================] - 2s 7ms/step - loss: 0.0295 - mean_absolute_error: 0.2284 - val_loss: 0.0303 - val_mean_absolute_error: 0.2272
Epoch 69/100
249/249 [==============================] - 2s 6ms/step - loss: 0.0292 - mean_absolute_error: 0.2297 - val_loss: 0.0283 - val_mean_absolute_error: 0.2262
Epoch 70/100
249/249 [==============================] - 2s 7ms/step - loss: 0.0225 - mean_absolute_error: 0.2289 - val_loss: 0.0329 - val_mean_absolute_error: 0.2233
Epoch 71/100
249/249 [==============================] - 2s 7ms/step - loss: 0.0240 - mean_absolute_error: 0.2276 - val_loss: 0.0302 - val_mean_absolute_error: 0.2244
Epoch 72/100
249/249 [==============================] - 2s 7ms/step - loss: 0.0231 - mean_absolute_error: 0.2260 - val_loss: 0.0262 - val_mean_absolute_error: 0.2204
Epoch 73/100
249/249 [==============================] - 2s 7ms/step - loss: 0.0220 - mean_absolute_error: 0.2251 - val_loss: 0.0422 - val_mean_absolute_error: 0.2158
Epoch 74/100
249/249 [==============================] - 2s 7ms/step - loss: 0.0208 - mean_absolute_error: 0.2238 - val_loss: 0.0219 - val_mean_absolute_error: 0.2195
Epoch 75/100
249/249 [==============================] - 2s 7ms/step - loss: 0.0350 - mean_absolute_error: 0.2240 - val_loss: 0.0579 - val_mean_absolute_error: 0.2176
Epoch 76/100
249/249 [==============================] - 2s 7ms/step - loss: 0.0196 - mean_absolute_error: 0.2248 - val_loss: 0.0183 - val_mean_absolute_error: 0.2215
Epoch 77/100
249/249 [==============================] - 2s 7ms/step - loss: 0.0227 - mean_absolute_error: 0.2244 - val_loss: 0.0326 - val_mean_absolute_error: 0.2187
Epoch 78/100
249/249 [==============================] - 2s 7ms/step - loss: 0.0204 - mean_absolute_error: 0.2228 - val_loss: 0.0266 - val_mean_absolute_error: 0.2135
Epoch 79/100
249/249 [==============================] - 2s 7ms/step - loss: 0.0192 - mean_absolute_error: 0.2210 - val_loss: 0.0196 - val_mean_absolute_error: 0.2137
Epoch 80/100
249/249 [==============================] - 2s 7ms/step - loss: 0.0187 - mean_absolute_error: 0.2187 - val_loss: 0.0385 - val_mean_absolute_error: 0.2141
Epoch 81/100
249/249 [==============================] - 2s 7ms/step - loss: 0.0382 - mean_absolute_error: 0.2200 - val_loss: 0.0182 - val_mean_absolute_error: 0.2166
Epoch 82/100
249/249 [==============================] - 2s 7ms/step - loss: 0.0188 - mean_absolute_error: 0.2186 - val_loss: 0.0206 - val_mean_absolute_error: 0.2096
Epoch 83/100
249/249 [==============================] - 2s 6ms/step - loss: 0.0266 - mean_absolute_error: 0.2177 - val_loss: 0.0347 - val_mean_absolute_error: 0.2134
Epoch 84/100
249/249 [==============================] - 2s 6ms/step - loss: 0.0163 - mean_absolute_error: 0.2192 - val_loss: 0.0216 - val_mean_absolute_error: 0.2145
Epoch 85/100
249/249 [==============================] - 2s 6ms/step - loss: 0.0408 - mean_absolute_error: 0.2183 - val_loss: 0.0270 - val_mean_absolute_error: 0.2181
Epoch 86/100
249/249 [==============================] - 2s 7ms/step - loss: 0.0214 - mean_absolute_error: 0.2200 - val_loss: 0.0218 - val_mean_absolute_error: 0.2142
Epoch 87/100
249/249 [==============================] - 2s 7ms/step - loss: 0.0204 - mean_absolute_error: 0.2180 - val_loss: 0.0318 - val_mean_absolute_error: 0.2120
Epoch 88/100
249/249 [==============================] - 2s 7ms/step - loss: 0.0226 - mean_absolute_error: 0.2189 - val_loss: 0.0191 - val_mean_absolute_error: 0.2154
Epoch 89/100
249/249 [==============================] - 2s 7ms/step - loss: 0.0175 - mean_absolute_error: 0.2185 - val_loss: 0.0205 - val_mean_absolute_error: 0.2127
Epoch 90/100
249/249 [==============================] - 2s 7ms/step - loss: 0.0243 - mean_absolute_error: 0.2166 - val_loss: 0.0234 - val_mean_absolute_error: 0.2111
Epoch 91/100
249/249 [==============================] - 2s 7ms/step - loss: 0.0175 - mean_absolute_error: 0.2152 - val_loss: 0.0233 - val_mean_absolute_error: 0.2064
Epoch 92/100
249/249 [==============================] - 2s 6ms/step - loss: 0.0235 - mean_absolute_error: 0.2139 - val_loss: 0.0350 - val_mean_absolute_error: 0.2084
Epoch 93/100
249/249 [==============================] - 2s 7ms/step - loss: 0.0175 - mean_absolute_error: 0.2140 - val_loss: 0.0553 - val_mean_absolute_error: 0.2101
Epoch 94/100
249/249 [==============================] - 2s 7ms/step - loss: 0.0222 - mean_absolute_error: 0.2128 - val_loss: 0.0252 - val_mean_absolute_error: 0.2075
Epoch 95/100
249/249 [==============================] - 2s 6ms/step - loss: 0.0218 - mean_absolute_error: 0.2121 - val_loss: 0.0423 - val_mean_absolute_error: 0.2058
Epoch 96/100
249/249 [==============================] - 2s 6ms/step - loss: 0.0313 - mean_absolute_error: 0.2129 - val_loss: 0.0392 - val_mean_absolute_error: 0.2037
Epoch 97/100
249/249 [==============================] - 2s 6ms/step - loss: 0.0184 - mean_absolute_error: 0.2131 - val_loss: 0.0209 - val_mean_absolute_error: 0.2042
Epoch 98/100
249/249 [==============================] - 2s 6ms/step - loss: 0.0164 - mean_absolute_error: 0.2130 - val_loss: 0.0389 - val_mean_absolute_error: 0.2080
Epoch 99/100
249/249 [==============================] - 2s 6ms/step - loss: 0.0150 - mean_absolute_error: 0.2102 - val_loss: 0.0265 - val_mean_absolute_error: 0.2056
Epoch 100/100
249/249 [==============================] - 2s 6ms/step - loss: 0.0330 - mean_absolute_error: 0.2109 - val_loss: 0.0245 - val_mean_absolute_error: 0.2059
32/32 [==============================] - 0s 3ms/step - loss: 0.0245 - mean_absolute_error: 0.2059
Accuracy:  [0.02451775223016739, 0.20593295991420746]
In [29]:
from tensorflow.keras.utils import plot_model
model.summary()
plot_model(model, to_file='mlp-mnist.png', show_shapes=True)
Model: "sequential_73"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 lstm_73 (LSTM)              (None, 32)                4352      
                                                                 
 dense_146 (Dense)           (None, 32)                1056      
                                                                 
 dense_147 (Dense)           (None, 3)                 99        
                                                                 
=================================================================
Total params: 5507 (21.51 KB)
Trainable params: 5507 (21.51 KB)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________
You must install pydot (`pip install pydot`) and install graphviz (see instructions at https://graphviz.gitlab.io/download/) for plot_model to work.
In [30]:
import joblib
joblib.dump(model, 'model/lstm_model.pkl')
INFO:tensorflow:Assets written to: C:\Users\Asus\AppData\Local\Temp\tmp036863h3\assets
INFO:tensorflow:Assets written to: C:\Users\Asus\AppData\Local\Temp\tmp036863h3\assets
Out[30]:
['model/lstm_model.pkl']
In [31]:
predictions = model.predict(X_test)
y_pred = np.argmax(predictions, axis=1)
32/32 [==============================] - 1s 3ms/step
In [32]:
plt.figure(figsize=(12, 5))
plt.subplot(1, 2, 1)
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.title('Loss over epochs')
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.legend(['Train', 'Validation'], loc='best')

plt.subplot(1, 2, 2)
plt.plot(history.history['mean_absolute_error'])
plt.plot(history.history['val_mean_absolute_error'])
plt.title('Accuracy over epochs')
plt.ylabel('Accuracy')
plt.xlabel('Epoch')
plt.legend(['Train', 'Validation'], loc='best')

plt.tight_layout()
plt.show()
In [33]:
cm = confusion_matrix(y_test, y_pred)
disp = ConfusionMatrixDisplay(confusion_matrix=cm)
disp.plot(cmap='Blues', values_format='d')
plt.title("Confusion Matrix: RNN")
plt.show()

print(classification_report(y_test, y_pred))
              precision    recall  f1-score   support

           0       1.00      0.99      1.00       327
           1       0.98      0.98      0.98       352
           2       0.98      0.99      0.98       317

    accuracy                           0.99       996
   macro avg       0.99      0.99      0.99       996
weighted avg       0.99      0.99      0.99       996