import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder, OneHotEncoder, StandardScaler
from sklearn.model_selection import cross_val_score, GridSearchCV
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix,classification_report, ConfusionMatrixDisplay
dataset = pd.read_csv('data/market_cluster.csv', encoding='latin1')
dataset.head()
| Order ID | Customer Name | Category | Sub Category | City | Order Date | Region | Sales | Discount | Profit | State | profit_margin | Cluster | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | OD1 | Harish | Oil & Masala | Masalas | Vellore | 11-08-2017 | North | 1254 | 0.12 | 401.28 | Tamil Nadu | 0.32 | Medium |
| 1 | OD2 | Sudha | Beverages | Health Drinks | Krishnagiri | 11-08-2017 | South | 749 | 0.18 | 149.80 | Tamil Nadu | 0.20 | Medium |
| 2 | OD3 | Hussain | Food Grains | Atta & Flour | Perambalur | 06-12-2017 | West | 2360 | 0.21 | 165.20 | Tamil Nadu | 0.07 | Low |
| 3 | OD4 | Jackson | Fruits & Veggies | Fresh Vegetables | Dharmapuri | 10-11-2016 | South | 896 | 0.25 | 89.60 | Tamil Nadu | 0.10 | Low |
| 4 | OD5 | Ridhesh | Food Grains | Organic Staples | Ooty | 10-11-2016 | South | 2355 | 0.26 | 918.45 | Tamil Nadu | 0.39 | High |
dataset.info()
<class 'pandas.core.frame.DataFrame'> RangeIndex: 9994 entries, 0 to 9993 Data columns (total 13 columns): # Column Non-Null Count Dtype --- ------ -------------- ----- 0 Order ID 9994 non-null object 1 Customer Name 9994 non-null object 2 Category 9994 non-null object 3 Sub Category 9994 non-null object 4 City 9994 non-null object 5 Order Date 9994 non-null object 6 Region 9994 non-null object 7 Sales 9994 non-null int64 8 Discount 9994 non-null float64 9 Profit 9994 non-null float64 10 State 9994 non-null object 11 profit_margin 9994 non-null float64 12 Cluster 9994 non-null object dtypes: float64(3), int64(1), object(9) memory usage: 1015.1+ KB
dataset.describe()
| Sales | Discount | Profit | profit_margin | |
|---|---|---|---|---|
| count | 9994.000000 | 9994.000000 | 9994.000000 | 9994.000000 |
| mean | 1496.596158 | 0.226817 | 374.937082 | 0.250228 |
| std | 577.559036 | 0.074636 | 239.932881 | 0.118919 |
| min | 500.000000 | 0.100000 | 25.250000 | 0.050000 |
| 25% | 1000.000000 | 0.160000 | 180.022500 | 0.150000 |
| 50% | 1498.000000 | 0.230000 | 320.780000 | 0.250000 |
| 75% | 1994.750000 | 0.290000 | 525.627500 | 0.350000 |
| max | 2500.000000 | 0.350000 | 1120.950000 | 0.450000 |
dataset.info()
<class 'pandas.core.frame.DataFrame'> RangeIndex: 9994 entries, 0 to 9993 Data columns (total 13 columns): # Column Non-Null Count Dtype --- ------ -------------- ----- 0 Order ID 9994 non-null object 1 Customer Name 9994 non-null object 2 Category 9994 non-null object 3 Sub Category 9994 non-null object 4 City 9994 non-null object 5 Order Date 9994 non-null object 6 Region 9994 non-null object 7 Sales 9994 non-null int64 8 Discount 9994 non-null float64 9 Profit 9994 non-null float64 10 State 9994 non-null object 11 profit_margin 9994 non-null float64 12 Cluster 9994 non-null object dtypes: float64(3), int64(1), object(9) memory usage: 1015.1+ KB
dataset.drop(['Order ID'], axis=1, inplace=True)
dataset.isna().sum()
Customer Name 0 Category 0 Sub Category 0 City 0 Order Date 0 Region 0 Sales 0 Discount 0 Profit 0 State 0 profit_margin 0 Cluster 0 dtype: int64
dataset.dropna(inplace=True)
def remove_outliers(data: pd.DataFrame, column: str) -> pd.Series:
q3, q1 = np.nanpercentile(data[column], [75, 25])
iqr = q3 - q1
upper_bound = q3 + 1.5 * iqr
lower_bound = q1 - 1.5 * iqr
data = data[(data[column] > lower_bound) & (data[column] < upper_bound)]
return data
dataset = remove_outliers(dataset, 'Discount')
dataset = remove_outliers(dataset, 'Sales')
dataset = remove_outliers(dataset, 'Profit')
dataset.head()
| Customer Name | Category | Sub Category | City | Order Date | Region | Sales | Discount | Profit | State | profit_margin | Cluster | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | Harish | Oil & Masala | Masalas | Vellore | 11-08-2017 | North | 1254 | 0.12 | 401.28 | Tamil Nadu | 0.32 | Medium |
| 1 | Sudha | Beverages | Health Drinks | Krishnagiri | 11-08-2017 | South | 749 | 0.18 | 149.80 | Tamil Nadu | 0.20 | Medium |
| 2 | Hussain | Food Grains | Atta & Flour | Perambalur | 06-12-2017 | West | 2360 | 0.21 | 165.20 | Tamil Nadu | 0.07 | Low |
| 3 | Jackson | Fruits & Veggies | Fresh Vegetables | Dharmapuri | 10-11-2016 | South | 896 | 0.25 | 89.60 | Tamil Nadu | 0.10 | Low |
| 4 | Ridhesh | Food Grains | Organic Staples | Ooty | 10-11-2016 | South | 2355 | 0.26 | 918.45 | Tamil Nadu | 0.39 | High |
sns.histplot(dataset['Cluster'])
<AxesSubplot:xlabel='Cluster', ylabel='Count'>
encoder = LabelEncoder()
scaler = StandardScaler()
onehot = OneHotEncoder()
dataset["Order Date"] = pd.to_datetime(dataset["Order Date"])
dataset["Order Date"] = dataset["Order Date"].dt.month
dataset["Customer Name"] = encoder.fit_transform(dataset["Customer Name"])
dataset["Category"] = encoder.fit_transform(dataset["Category"])
dataset["City"] = encoder.fit_transform(dataset["City"])
dataset["Region"] = encoder.fit_transform(dataset["Region"])
dataset["State"] = encoder.fit_transform(dataset["State"])
dataset["Sub Category"] = encoder.fit_transform(dataset["Sub Category"])
# dataset["Customer Name"] = onehot.fit_transform(dataset["Customer Name"].values.reshape(-1, 1)).toarray()
# dataset["Category"] = onehot.fit_transform(dataset["Category"].values.reshape(-1, 1)).toarray()
# dataset["City"] = onehot.fit_transform(dataset["City"].values.reshape(-1, 1)).toarray()
# dataset["Region"] = onehot.fit_transform(dataset["Region"].values.reshape(-1, 1)).toarray()
# dataset["State"] = onehot.fit_transform(dataset["State"].values.reshape(-1, 1)).toarray()
# dataset["Sub Category"] = onehot.fit_transform(dataset["Sub Category"].values.reshape(-1, 1)).toarray()
dataset["Order Date"] = encoder.fit_transform(dataset["Order Date"])
dataset[["Sales", "Discount", "profit_margin","Profit"]] = scaler.fit_transform(dataset[["Sales", "Discount", "profit_margin","Profit"]])
class_to_numeric = {'Low': 0, 'Medium': 1, 'High': 2}
dataset['Cluster'] = [class_to_numeric[label] for label in dataset['Cluster']]
dataset.head()
| Customer Name | Category | Sub Category | City | Order Date | Region | Sales | Discount | Profit | State | profit_margin | Cluster | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 12 | 5 | 14 | 21 | 10 | 2 | -0.414559 | -1.430908 | 0.124389 | 0 | 0.595874 | 1 |
| 1 | 37 | 1 | 13 | 8 | 10 | 3 | -1.291968 | -0.627370 | -0.941183 | 0 | -0.416872 | 1 |
| 2 | 14 | 3 | 0 | 13 | 5 | 4 | 1.507054 | -0.225601 | -0.875930 | 0 | -1.514014 | 0 |
| 3 | 15 | 4 | 12 | 4 | 9 | 3 | -1.036563 | 0.310092 | -1.196262 | 0 | -1.260827 | 0 |
| 4 | 28 | 3 | 18 | 12 | 9 | 3 | 1.498367 | 0.444015 | 2.315743 | 0 | 1.186643 | 2 |
X = dataset.drop(['Cluster','Sub Category','State','profit_margin'],axis=1)
y = dataset['Cluster']
X.head()
| Customer Name | Category | City | Order Date | Region | Sales | Discount | Profit | |
|---|---|---|---|---|---|---|---|---|
| 0 | 12 | 5 | 21 | 10 | 2 | -0.414559 | -1.430908 | 0.124389 |
| 1 | 37 | 1 | 8 | 10 | 3 | -1.291968 | -0.627370 | -0.941183 |
| 2 | 14 | 3 | 13 | 5 | 4 | 1.507054 | -0.225601 | -0.875930 |
| 3 | 15 | 4 | 4 | 9 | 3 | -1.036563 | 0.310092 | -1.196262 |
| 4 | 28 | 3 | 12 | 9 | 3 | 1.498367 | 0.444015 | 2.315743 |
X_train, X_temp, y_train, y_temp = train_test_split(X, y, test_size=0.2, random_state=42)
X_val, X_test, y_val, y_test = train_test_split(X_temp, y_temp, test_size=0.5, random_state=42)
heatcol = X.corr()
sns.heatmap(heatcol,cmap="BrBG",annot=True)
<AxesSubplot:>
print("Dimension of Train set",X_train.shape)
print("Dimension of Val set",X_val.shape)
print("Dimension of Test set",X_test.shape,"\n")
num_cols = X_train._get_numeric_data().columns
print("Number of numeric features:",num_cols.size)
Dimension of Train set (7960, 8) Dimension of Val set (995, 8) Dimension of Test set (996, 8) Number of numeric features: 8
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM, Dense
from tensorflow.keras.utils import to_categorical
from scikeras.wrappers import KerasClassifier
from sklearn.model_selection import GridSearchCV
y_train = to_categorical(y_train, num_classes=3)
y_val = to_categorical(y_val, num_classes=3)
def create_model(units=64,optimizer='adam',loss='categorical_crossentropy'):
model = Sequential()
model.add(LSTM(units, input_shape=(X_train.shape[1], 1)))
model.add(Dense(units, activation='relu'))
model.add(Dense(units=3, activation='sigmoid'))
model.compile(loss=loss, optimizer=optimizer, metrics=['categorical_accuracy'])
return model
model = KerasClassifier(build_fn=create_model, units=32, epochs=100, batch_size=32, verbose=0)
param_grid = {
'optimizer': ['adam', 'sgd', 'rmsprop'],
'units': [32, 64, 128],
'loss' : ['categorical_crossentropy','binary_crossentropy','hinge','kullback_leibler_divergence']
}
grid = GridSearchCV(estimator=model, param_grid=param_grid, cv=2, scoring='accuracy')
grid_result = grid.fit(X_train, y_train)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\sklearn\model_selection\_validation.py:425: FitFailedWarning:
54 fits failed out of a total of 72.
The score on these train-test partitions for these parameters will be set to nan.
If these failures are not expected, you can try to debug them by setting error_score='raise'.
Below are more details about the failures:
--------------------------------------------------------------------------------
18 fits failed with the following error:
Traceback (most recent call last):
File "c:\Users\Asus\anaconda3\lib\site-packages\sklearn\model_selection\_validation.py", line 729, in _fit_and_score
estimator.fit(X_train, y_train, **fit_params)
File "c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py", line 1491, in fit
super().fit(X=X, y=y, sample_weight=sample_weight, **kwargs)
File "c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py", line 760, in fit
self._fit(
File "c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py", line 926, in _fit
self._check_model_compatibility(y)
File "c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py", line 569, in _check_model_compatibility
raise ValueError(
ValueError: loss=binary_crossentropy but model compiled with categorical_crossentropy. Data may not match loss function!
--------------------------------------------------------------------------------
18 fits failed with the following error:
Traceback (most recent call last):
File "c:\Users\Asus\anaconda3\lib\site-packages\sklearn\model_selection\_validation.py", line 729, in _fit_and_score
estimator.fit(X_train, y_train, **fit_params)
File "c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py", line 1491, in fit
super().fit(X=X, y=y, sample_weight=sample_weight, **kwargs)
File "c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py", line 760, in fit
self._fit(
File "c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py", line 926, in _fit
self._check_model_compatibility(y)
File "c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py", line 569, in _check_model_compatibility
raise ValueError(
ValueError: loss=hinge but model compiled with categorical_crossentropy. Data may not match loss function!
--------------------------------------------------------------------------------
18 fits failed with the following error:
Traceback (most recent call last):
File "c:\Users\Asus\anaconda3\lib\site-packages\sklearn\model_selection\_validation.py", line 729, in _fit_and_score
estimator.fit(X_train, y_train, **fit_params)
File "c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py", line 1491, in fit
super().fit(X=X, y=y, sample_weight=sample_weight, **kwargs)
File "c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py", line 760, in fit
self._fit(
File "c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py", line 926, in _fit
self._check_model_compatibility(y)
File "c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py", line 569, in _check_model_compatibility
raise ValueError(
ValueError: loss=kullback_leibler_divergence but model compiled with categorical_crossentropy. Data may not match loss function!
warnings.warn(some_fits_failed_message, FitFailedWarning)
c:\Users\Asus\anaconda3\lib\site-packages\sklearn\model_selection\_search.py:979: UserWarning: One or more of the test scores are non-finite: [0.98856784 0.9781407 0.97512563 0.98153266 0.97022613 0.97248744
0.98592965 0.98090452 0.97374372 nan nan nan
nan nan nan nan nan nan
nan nan nan nan nan nan
nan nan nan nan nan nan
nan nan nan nan nan nan]
warnings.warn(
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
print("Best: %f using %s" % (grid_result.best_score_, grid_result.best_params_))
# best_model = grid_result.best_estimator_
# accuracy = best_model.score(X_val, y_val)
# print("Validation Accuracy: %.2f%%" % (accuracy * 100))
Best: 0.988568 using {'loss': 'categorical_crossentropy', 'optimizer': 'adam', 'units': 32}
optimizer = grid_result.best_params_['optimizer']
units = grid_result.best_params_['units']
loss = grid_result.best_params_['loss']
model = Sequential()
model.add(LSTM(units, return_sequences = False, input_shape=(X_train.shape[1], 1)))
model.add(Dense(units, activation='relu'))
model.add(Dense(units=3, activation='sigmoid'))
model.compile(loss=loss, optimizer=optimizer, metrics=['mean_absolute_error'])
history = model.fit(X_train, y_train, epochs=100, batch_size=32, validation_data=(X_val, y_val))
accuracy = model.evaluate(X_val, y_val)
print("Accuracy: ", accuracy)
Epoch 1/100 249/249 [==============================] - 5s 10ms/step - loss: 0.8214 - mean_absolute_error: 0.4385 - val_loss: 0.4796 - val_mean_absolute_error: 0.3500 Epoch 2/100 249/249 [==============================] - 2s 7ms/step - loss: 0.3203 - mean_absolute_error: 0.3053 - val_loss: 0.2304 - val_mean_absolute_error: 0.2844 Epoch 3/100 249/249 [==============================] - 2s 7ms/step - loss: 0.1933 - mean_absolute_error: 0.2768 - val_loss: 0.1760 - val_mean_absolute_error: 0.2717 Epoch 4/100 249/249 [==============================] - 2s 7ms/step - loss: 0.1469 - mean_absolute_error: 0.2691 - val_loss: 0.1905 - val_mean_absolute_error: 0.2692 Epoch 5/100 249/249 [==============================] - 2s 8ms/step - loss: 0.1205 - mean_absolute_error: 0.2635 - val_loss: 0.1255 - val_mean_absolute_error: 0.2611 Epoch 6/100 249/249 [==============================] - 2s 6ms/step - loss: 0.1026 - mean_absolute_error: 0.2611 - val_loss: 0.0962 - val_mean_absolute_error: 0.2578 Epoch 7/100 249/249 [==============================] - 2s 7ms/step - loss: 0.0975 - mean_absolute_error: 0.2588 - val_loss: 0.0865 - val_mean_absolute_error: 0.2549 Epoch 8/100 249/249 [==============================] - 2s 6ms/step - loss: 0.0809 - mean_absolute_error: 0.2571 - val_loss: 0.0827 - val_mean_absolute_error: 0.2546 Epoch 9/100 249/249 [==============================] - 2s 6ms/step - loss: 0.0881 - mean_absolute_error: 0.2562 - val_loss: 0.0832 - val_mean_absolute_error: 0.2525 Epoch 10/100 249/249 [==============================] - 2s 7ms/step - loss: 0.0696 - mean_absolute_error: 0.2545 - val_loss: 0.0757 - val_mean_absolute_error: 0.2496 Epoch 11/100 249/249 [==============================] - 2s 7ms/step - loss: 0.0682 - mean_absolute_error: 0.2536 - val_loss: 0.0779 - val_mean_absolute_error: 0.2530 Epoch 12/100 249/249 [==============================] - 2s 7ms/step - loss: 0.0585 - mean_absolute_error: 0.2525 - val_loss: 0.0679 - val_mean_absolute_error: 0.2502 Epoch 13/100 249/249 [==============================] - 2s 7ms/step - loss: 0.0592 - mean_absolute_error: 0.2518 - val_loss: 0.0727 - val_mean_absolute_error: 0.2493 Epoch 14/100 249/249 [==============================] - 2s 7ms/step - loss: 0.0560 - mean_absolute_error: 0.2506 - val_loss: 0.0579 - val_mean_absolute_error: 0.2476 Epoch 15/100 249/249 [==============================] - 2s 7ms/step - loss: 0.0588 - mean_absolute_error: 0.2495 - val_loss: 0.0666 - val_mean_absolute_error: 0.2486 Epoch 16/100 249/249 [==============================] - 2s 7ms/step - loss: 0.0535 - mean_absolute_error: 0.2495 - val_loss: 0.0532 - val_mean_absolute_error: 0.2443 Epoch 17/100 249/249 [==============================] - 2s 7ms/step - loss: 0.0585 - mean_absolute_error: 0.2485 - val_loss: 0.0676 - val_mean_absolute_error: 0.2469 Epoch 18/100 249/249 [==============================] - 2s 7ms/step - loss: 0.0489 - mean_absolute_error: 0.2475 - val_loss: 0.0513 - val_mean_absolute_error: 0.2443 Epoch 19/100 249/249 [==============================] - 2s 7ms/step - loss: 0.0465 - mean_absolute_error: 0.2474 - val_loss: 0.0567 - val_mean_absolute_error: 0.2424 Epoch 20/100 249/249 [==============================] - 2s 7ms/step - loss: 0.0439 - mean_absolute_error: 0.2456 - val_loss: 0.0445 - val_mean_absolute_error: 0.2426 Epoch 21/100 249/249 [==============================] - 2s 6ms/step - loss: 0.0455 - mean_absolute_error: 0.2457 - val_loss: 0.0584 - val_mean_absolute_error: 0.2430 Epoch 22/100 249/249 [==============================] - 2s 7ms/step - loss: 0.0436 - mean_absolute_error: 0.2452 - val_loss: 0.0499 - val_mean_absolute_error: 0.2418 Epoch 23/100 249/249 [==============================] - 2s 7ms/step - loss: 0.0531 - mean_absolute_error: 0.2442 - val_loss: 0.0402 - val_mean_absolute_error: 0.2393 Epoch 24/100 249/249 [==============================] - 2s 7ms/step - loss: 0.0385 - mean_absolute_error: 0.2438 - val_loss: 0.0516 - val_mean_absolute_error: 0.2395 Epoch 25/100 249/249 [==============================] - 2s 6ms/step - loss: 0.0394 - mean_absolute_error: 0.2424 - val_loss: 0.0575 - val_mean_absolute_error: 0.2393 Epoch 26/100 249/249 [==============================] - 2s 7ms/step - loss: 0.0411 - mean_absolute_error: 0.2425 - val_loss: 0.0444 - val_mean_absolute_error: 0.2400 Epoch 27/100 249/249 [==============================] - 2s 7ms/step - loss: 0.0351 - mean_absolute_error: 0.2419 - val_loss: 0.0420 - val_mean_absolute_error: 0.2366 Epoch 28/100 249/249 [==============================] - 2s 7ms/step - loss: 0.0420 - mean_absolute_error: 0.2402 - val_loss: 0.0507 - val_mean_absolute_error: 0.2382 Epoch 29/100 249/249 [==============================] - 2s 6ms/step - loss: 0.0434 - mean_absolute_error: 0.2399 - val_loss: 0.0446 - val_mean_absolute_error: 0.2379 Epoch 30/100 249/249 [==============================] - 2s 6ms/step - loss: 0.0375 - mean_absolute_error: 0.2393 - val_loss: 0.0441 - val_mean_absolute_error: 0.2338 Epoch 31/100 249/249 [==============================] - 2s 7ms/step - loss: 0.0357 - mean_absolute_error: 0.2406 - val_loss: 0.0522 - val_mean_absolute_error: 0.2386 Epoch 32/100 249/249 [==============================] - 2s 7ms/step - loss: 0.0335 - mean_absolute_error: 0.2388 - val_loss: 0.0541 - val_mean_absolute_error: 0.2342 Epoch 33/100 249/249 [==============================] - 2s 8ms/step - loss: 0.0524 - mean_absolute_error: 0.2387 - val_loss: 0.0384 - val_mean_absolute_error: 0.2361 Epoch 34/100 249/249 [==============================] - 2s 7ms/step - loss: 0.0330 - mean_absolute_error: 0.2395 - val_loss: 0.0745 - val_mean_absolute_error: 0.2381 Epoch 35/100 249/249 [==============================] - 2s 7ms/step - loss: 0.0359 - mean_absolute_error: 0.2383 - val_loss: 0.0428 - val_mean_absolute_error: 0.2313 Epoch 36/100 249/249 [==============================] - 2s 7ms/step - loss: 0.0429 - mean_absolute_error: 0.2371 - val_loss: 0.0392 - val_mean_absolute_error: 0.2336 Epoch 37/100 249/249 [==============================] - 2s 7ms/step - loss: 0.0340 - mean_absolute_error: 0.2386 - val_loss: 0.0334 - val_mean_absolute_error: 0.2321 Epoch 38/100 249/249 [==============================] - 2s 7ms/step - loss: 0.0350 - mean_absolute_error: 0.2377 - val_loss: 0.0338 - val_mean_absolute_error: 0.2342 Epoch 39/100 249/249 [==============================] - 2s 7ms/step - loss: 0.0276 - mean_absolute_error: 0.2376 - val_loss: 0.0299 - val_mean_absolute_error: 0.2318 Epoch 40/100 249/249 [==============================] - 2s 7ms/step - loss: 0.0360 - mean_absolute_error: 0.2357 - val_loss: 0.0452 - val_mean_absolute_error: 0.2338 Epoch 41/100 249/249 [==============================] - 2s 7ms/step - loss: 0.0362 - mean_absolute_error: 0.2363 - val_loss: 0.0537 - val_mean_absolute_error: 0.2348 Epoch 42/100 249/249 [==============================] - 2s 7ms/step - loss: 0.0422 - mean_absolute_error: 0.2369 - val_loss: 0.0433 - val_mean_absolute_error: 0.2347 Epoch 43/100 249/249 [==============================] - 2s 7ms/step - loss: 0.0251 - mean_absolute_error: 0.2362 - val_loss: 0.0413 - val_mean_absolute_error: 0.2318 Epoch 44/100 249/249 [==============================] - 2s 7ms/step - loss: 0.0279 - mean_absolute_error: 0.2352 - val_loss: 0.0522 - val_mean_absolute_error: 0.2334 Epoch 45/100 249/249 [==============================] - 2s 7ms/step - loss: 0.0307 - mean_absolute_error: 0.2347 - val_loss: 0.0717 - val_mean_absolute_error: 0.2341 Epoch 46/100 249/249 [==============================] - 2s 6ms/step - loss: 0.0377 - mean_absolute_error: 0.2348 - val_loss: 0.0392 - val_mean_absolute_error: 0.2332 Epoch 47/100 249/249 [==============================] - 2s 7ms/step - loss: 0.0279 - mean_absolute_error: 0.2345 - val_loss: 0.0455 - val_mean_absolute_error: 0.2318 Epoch 48/100 249/249 [==============================] - 2s 7ms/step - loss: 0.0363 - mean_absolute_error: 0.2328 - val_loss: 0.0339 - val_mean_absolute_error: 0.2297 Epoch 49/100 249/249 [==============================] - 2s 7ms/step - loss: 0.0233 - mean_absolute_error: 0.2325 - val_loss: 0.0372 - val_mean_absolute_error: 0.2295 Epoch 50/100 249/249 [==============================] - 2s 7ms/step - loss: 0.0347 - mean_absolute_error: 0.2316 - val_loss: 0.0270 - val_mean_absolute_error: 0.2274 Epoch 51/100 249/249 [==============================] - 2s 6ms/step - loss: 0.0301 - mean_absolute_error: 0.2324 - val_loss: 0.0430 - val_mean_absolute_error: 0.2271 Epoch 52/100 249/249 [==============================] - 2s 7ms/step - loss: 0.0227 - mean_absolute_error: 0.2313 - val_loss: 0.0330 - val_mean_absolute_error: 0.2262 Epoch 53/100 249/249 [==============================] - 2s 7ms/step - loss: 0.0311 - mean_absolute_error: 0.2311 - val_loss: 0.0384 - val_mean_absolute_error: 0.2257 Epoch 54/100 249/249 [==============================] - 2s 7ms/step - loss: 0.0352 - mean_absolute_error: 0.2300 - val_loss: 0.1161 - val_mean_absolute_error: 0.2291 Epoch 55/100 249/249 [==============================] - 2s 7ms/step - loss: 0.0304 - mean_absolute_error: 0.2321 - val_loss: 0.0405 - val_mean_absolute_error: 0.2234 Epoch 56/100 249/249 [==============================] - 2s 6ms/step - loss: 0.0242 - mean_absolute_error: 0.2299 - val_loss: 0.0299 - val_mean_absolute_error: 0.2261 Epoch 57/100 249/249 [==============================] - 2s 7ms/step - loss: 0.0227 - mean_absolute_error: 0.2302 - val_loss: 0.0386 - val_mean_absolute_error: 0.2276 Epoch 58/100 249/249 [==============================] - 2s 7ms/step - loss: 0.0351 - mean_absolute_error: 0.2297 - val_loss: 0.0322 - val_mean_absolute_error: 0.2256 Epoch 59/100 249/249 [==============================] - 2s 7ms/step - loss: 0.0238 - mean_absolute_error: 0.2296 - val_loss: 0.0269 - val_mean_absolute_error: 0.2253 Epoch 60/100 249/249 [==============================] - 2s 7ms/step - loss: 0.0211 - mean_absolute_error: 0.2272 - val_loss: 0.0383 - val_mean_absolute_error: 0.2238 Epoch 61/100 249/249 [==============================] - 2s 7ms/step - loss: 0.0537 - mean_absolute_error: 0.2279 - val_loss: 0.0250 - val_mean_absolute_error: 0.2276 Epoch 62/100 249/249 [==============================] - 2s 7ms/step - loss: 0.0236 - mean_absolute_error: 0.2309 - val_loss: 0.0517 - val_mean_absolute_error: 0.2265 Epoch 63/100 249/249 [==============================] - 2s 6ms/step - loss: 0.0195 - mean_absolute_error: 0.2293 - val_loss: 0.0255 - val_mean_absolute_error: 0.2214 Epoch 64/100 249/249 [==============================] - 2s 7ms/step - loss: 0.0258 - mean_absolute_error: 0.2272 - val_loss: 0.0445 - val_mean_absolute_error: 0.2200 Epoch 65/100 249/249 [==============================] - 2s 7ms/step - loss: 0.0242 - mean_absolute_error: 0.2260 - val_loss: 0.0271 - val_mean_absolute_error: 0.2212 Epoch 66/100 249/249 [==============================] - 2s 7ms/step - loss: 0.0352 - mean_absolute_error: 0.2268 - val_loss: 0.0640 - val_mean_absolute_error: 0.2225 Epoch 67/100 249/249 [==============================] - 2s 7ms/step - loss: 0.0333 - mean_absolute_error: 0.2280 - val_loss: 0.0388 - val_mean_absolute_error: 0.2244 Epoch 68/100 249/249 [==============================] - 2s 7ms/step - loss: 0.0295 - mean_absolute_error: 0.2284 - val_loss: 0.0303 - val_mean_absolute_error: 0.2272 Epoch 69/100 249/249 [==============================] - 2s 6ms/step - loss: 0.0292 - mean_absolute_error: 0.2297 - val_loss: 0.0283 - val_mean_absolute_error: 0.2262 Epoch 70/100 249/249 [==============================] - 2s 7ms/step - loss: 0.0225 - mean_absolute_error: 0.2289 - val_loss: 0.0329 - val_mean_absolute_error: 0.2233 Epoch 71/100 249/249 [==============================] - 2s 7ms/step - loss: 0.0240 - mean_absolute_error: 0.2276 - val_loss: 0.0302 - val_mean_absolute_error: 0.2244 Epoch 72/100 249/249 [==============================] - 2s 7ms/step - loss: 0.0231 - mean_absolute_error: 0.2260 - val_loss: 0.0262 - val_mean_absolute_error: 0.2204 Epoch 73/100 249/249 [==============================] - 2s 7ms/step - loss: 0.0220 - mean_absolute_error: 0.2251 - val_loss: 0.0422 - val_mean_absolute_error: 0.2158 Epoch 74/100 249/249 [==============================] - 2s 7ms/step - loss: 0.0208 - mean_absolute_error: 0.2238 - val_loss: 0.0219 - val_mean_absolute_error: 0.2195 Epoch 75/100 249/249 [==============================] - 2s 7ms/step - loss: 0.0350 - mean_absolute_error: 0.2240 - val_loss: 0.0579 - val_mean_absolute_error: 0.2176 Epoch 76/100 249/249 [==============================] - 2s 7ms/step - loss: 0.0196 - mean_absolute_error: 0.2248 - val_loss: 0.0183 - val_mean_absolute_error: 0.2215 Epoch 77/100 249/249 [==============================] - 2s 7ms/step - loss: 0.0227 - mean_absolute_error: 0.2244 - val_loss: 0.0326 - val_mean_absolute_error: 0.2187 Epoch 78/100 249/249 [==============================] - 2s 7ms/step - loss: 0.0204 - mean_absolute_error: 0.2228 - val_loss: 0.0266 - val_mean_absolute_error: 0.2135 Epoch 79/100 249/249 [==============================] - 2s 7ms/step - loss: 0.0192 - mean_absolute_error: 0.2210 - val_loss: 0.0196 - val_mean_absolute_error: 0.2137 Epoch 80/100 249/249 [==============================] - 2s 7ms/step - loss: 0.0187 - mean_absolute_error: 0.2187 - val_loss: 0.0385 - val_mean_absolute_error: 0.2141 Epoch 81/100 249/249 [==============================] - 2s 7ms/step - loss: 0.0382 - mean_absolute_error: 0.2200 - val_loss: 0.0182 - val_mean_absolute_error: 0.2166 Epoch 82/100 249/249 [==============================] - 2s 7ms/step - loss: 0.0188 - mean_absolute_error: 0.2186 - val_loss: 0.0206 - val_mean_absolute_error: 0.2096 Epoch 83/100 249/249 [==============================] - 2s 6ms/step - loss: 0.0266 - mean_absolute_error: 0.2177 - val_loss: 0.0347 - val_mean_absolute_error: 0.2134 Epoch 84/100 249/249 [==============================] - 2s 6ms/step - loss: 0.0163 - mean_absolute_error: 0.2192 - val_loss: 0.0216 - val_mean_absolute_error: 0.2145 Epoch 85/100 249/249 [==============================] - 2s 6ms/step - loss: 0.0408 - mean_absolute_error: 0.2183 - val_loss: 0.0270 - val_mean_absolute_error: 0.2181 Epoch 86/100 249/249 [==============================] - 2s 7ms/step - loss: 0.0214 - mean_absolute_error: 0.2200 - val_loss: 0.0218 - val_mean_absolute_error: 0.2142 Epoch 87/100 249/249 [==============================] - 2s 7ms/step - loss: 0.0204 - mean_absolute_error: 0.2180 - val_loss: 0.0318 - val_mean_absolute_error: 0.2120 Epoch 88/100 249/249 [==============================] - 2s 7ms/step - loss: 0.0226 - mean_absolute_error: 0.2189 - val_loss: 0.0191 - val_mean_absolute_error: 0.2154 Epoch 89/100 249/249 [==============================] - 2s 7ms/step - loss: 0.0175 - mean_absolute_error: 0.2185 - val_loss: 0.0205 - val_mean_absolute_error: 0.2127 Epoch 90/100 249/249 [==============================] - 2s 7ms/step - loss: 0.0243 - mean_absolute_error: 0.2166 - val_loss: 0.0234 - val_mean_absolute_error: 0.2111 Epoch 91/100 249/249 [==============================] - 2s 7ms/step - loss: 0.0175 - mean_absolute_error: 0.2152 - val_loss: 0.0233 - val_mean_absolute_error: 0.2064 Epoch 92/100 249/249 [==============================] - 2s 6ms/step - loss: 0.0235 - mean_absolute_error: 0.2139 - val_loss: 0.0350 - val_mean_absolute_error: 0.2084 Epoch 93/100 249/249 [==============================] - 2s 7ms/step - loss: 0.0175 - mean_absolute_error: 0.2140 - val_loss: 0.0553 - val_mean_absolute_error: 0.2101 Epoch 94/100 249/249 [==============================] - 2s 7ms/step - loss: 0.0222 - mean_absolute_error: 0.2128 - val_loss: 0.0252 - val_mean_absolute_error: 0.2075 Epoch 95/100 249/249 [==============================] - 2s 6ms/step - loss: 0.0218 - mean_absolute_error: 0.2121 - val_loss: 0.0423 - val_mean_absolute_error: 0.2058 Epoch 96/100 249/249 [==============================] - 2s 6ms/step - loss: 0.0313 - mean_absolute_error: 0.2129 - val_loss: 0.0392 - val_mean_absolute_error: 0.2037 Epoch 97/100 249/249 [==============================] - 2s 6ms/step - loss: 0.0184 - mean_absolute_error: 0.2131 - val_loss: 0.0209 - val_mean_absolute_error: 0.2042 Epoch 98/100 249/249 [==============================] - 2s 6ms/step - loss: 0.0164 - mean_absolute_error: 0.2130 - val_loss: 0.0389 - val_mean_absolute_error: 0.2080 Epoch 99/100 249/249 [==============================] - 2s 6ms/step - loss: 0.0150 - mean_absolute_error: 0.2102 - val_loss: 0.0265 - val_mean_absolute_error: 0.2056 Epoch 100/100 249/249 [==============================] - 2s 6ms/step - loss: 0.0330 - mean_absolute_error: 0.2109 - val_loss: 0.0245 - val_mean_absolute_error: 0.2059 32/32 [==============================] - 0s 3ms/step - loss: 0.0245 - mean_absolute_error: 0.2059 Accuracy: [0.02451775223016739, 0.20593295991420746]
from tensorflow.keras.utils import plot_model
model.summary()
plot_model(model, to_file='mlp-mnist.png', show_shapes=True)
Model: "sequential_73"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
lstm_73 (LSTM) (None, 32) 4352
dense_146 (Dense) (None, 32) 1056
dense_147 (Dense) (None, 3) 99
=================================================================
Total params: 5507 (21.51 KB)
Trainable params: 5507 (21.51 KB)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________
You must install pydot (`pip install pydot`) and install graphviz (see instructions at https://graphviz.gitlab.io/download/) for plot_model to work.
import joblib
joblib.dump(model, 'model/lstm_model.pkl')
INFO:tensorflow:Assets written to: C:\Users\Asus\AppData\Local\Temp\tmp036863h3\assets
INFO:tensorflow:Assets written to: C:\Users\Asus\AppData\Local\Temp\tmp036863h3\assets
['model/lstm_model.pkl']
predictions = model.predict(X_test)
y_pred = np.argmax(predictions, axis=1)
32/32 [==============================] - 1s 3ms/step
plt.figure(figsize=(12, 5))
plt.subplot(1, 2, 1)
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.title('Loss over epochs')
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.legend(['Train', 'Validation'], loc='best')
plt.subplot(1, 2, 2)
plt.plot(history.history['mean_absolute_error'])
plt.plot(history.history['val_mean_absolute_error'])
plt.title('Accuracy over epochs')
plt.ylabel('Accuracy')
plt.xlabel('Epoch')
plt.legend(['Train', 'Validation'], loc='best')
plt.tight_layout()
plt.show()
cm = confusion_matrix(y_test, y_pred)
disp = ConfusionMatrixDisplay(confusion_matrix=cm)
disp.plot(cmap='Blues', values_format='d')
plt.title("Confusion Matrix: RNN")
plt.show()
print(classification_report(y_test, y_pred))
precision recall f1-score support
0 1.00 0.99 1.00 327
1 0.98 0.98 0.98 352
2 0.98 0.99 0.98 317
accuracy 0.99 996
macro avg 0.99 0.99 0.99 996
weighted avg 0.99 0.99 0.99 996