import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder, OneHotEncoder, StandardScaler,MinMaxScaler
from sklearn.model_selection import cross_val_score, GridSearchCV
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix,classification_report, ConfusionMatrixDisplay
dataset = pd.read_csv('data/market_cluster.csv', encoding='latin1')
dataset.head()
| Order ID | Customer Name | Category | Sub Category | City | Order Date | Region | Sales | Discount | Profit | State | profit_margin | Cluster | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | OD1 | Harish | Oil & Masala | Masalas | Vellore | 11-08-2017 | North | 1254 | 0.12 | 401.28 | Tamil Nadu | 0.32 | Medium |
| 1 | OD2 | Sudha | Beverages | Health Drinks | Krishnagiri | 11-08-2017 | South | 749 | 0.18 | 149.80 | Tamil Nadu | 0.20 | Medium |
| 2 | OD3 | Hussain | Food Grains | Atta & Flour | Perambalur | 06-12-2017 | West | 2360 | 0.21 | 165.20 | Tamil Nadu | 0.07 | Low |
| 3 | OD4 | Jackson | Fruits & Veggies | Fresh Vegetables | Dharmapuri | 10-11-2016 | South | 896 | 0.25 | 89.60 | Tamil Nadu | 0.10 | Low |
| 4 | OD5 | Ridhesh | Food Grains | Organic Staples | Ooty | 10-11-2016 | South | 2355 | 0.26 | 918.45 | Tamil Nadu | 0.39 | High |
dataset.info()
<class 'pandas.core.frame.DataFrame'> RangeIndex: 9994 entries, 0 to 9993 Data columns (total 13 columns): # Column Non-Null Count Dtype --- ------ -------------- ----- 0 Order ID 9994 non-null object 1 Customer Name 9994 non-null object 2 Category 9994 non-null object 3 Sub Category 9994 non-null object 4 City 9994 non-null object 5 Order Date 9994 non-null object 6 Region 9994 non-null object 7 Sales 9994 non-null int64 8 Discount 9994 non-null float64 9 Profit 9994 non-null float64 10 State 9994 non-null object 11 profit_margin 9994 non-null float64 12 Cluster 9994 non-null object dtypes: float64(3), int64(1), object(9) memory usage: 1015.1+ KB
dataset.describe()
| Sales | Discount | Profit | profit_margin | |
|---|---|---|---|---|
| count | 9994.000000 | 9994.000000 | 9994.000000 | 9994.000000 |
| mean | 1496.596158 | 0.226817 | 374.937082 | 0.250228 |
| std | 577.559036 | 0.074636 | 239.932881 | 0.118919 |
| min | 500.000000 | 0.100000 | 25.250000 | 0.050000 |
| 25% | 1000.000000 | 0.160000 | 180.022500 | 0.150000 |
| 50% | 1498.000000 | 0.230000 | 320.780000 | 0.250000 |
| 75% | 1994.750000 | 0.290000 | 525.627500 | 0.350000 |
| max | 2500.000000 | 0.350000 | 1120.950000 | 0.450000 |
dataset.info()
<class 'pandas.core.frame.DataFrame'> RangeIndex: 9994 entries, 0 to 9993 Data columns (total 13 columns): # Column Non-Null Count Dtype --- ------ -------------- ----- 0 Order ID 9994 non-null object 1 Customer Name 9994 non-null object 2 Category 9994 non-null object 3 Sub Category 9994 non-null object 4 City 9994 non-null object 5 Order Date 9994 non-null object 6 Region 9994 non-null object 7 Sales 9994 non-null int64 8 Discount 9994 non-null float64 9 Profit 9994 non-null float64 10 State 9994 non-null object 11 profit_margin 9994 non-null float64 12 Cluster 9994 non-null object dtypes: float64(3), int64(1), object(9) memory usage: 1015.1+ KB
dataset.drop(['Order ID'], axis=1, inplace=True)
dataset.isna().sum()
Customer Name 0 Category 0 Sub Category 0 City 0 Order Date 0 Region 0 Sales 0 Discount 0 Profit 0 State 0 profit_margin 0 Cluster 0 dtype: int64
dataset.dropna(inplace=True)
def remove_outliers(data: pd.DataFrame, column: str) -> pd.Series:
q3, q1 = np.nanpercentile(data[column], [75, 25])
iqr = q3 - q1
upper_bound = q3 + 1.5 * iqr
lower_bound = q1 - 1.5 * iqr
data = data[(data[column] > lower_bound) & (data[column] < upper_bound)]
return data
dataset = remove_outliers(dataset, 'Discount')
dataset = remove_outliers(dataset, 'Sales')
dataset = remove_outliers(dataset, 'Profit')
dataset.head()
| Customer Name | Category | Sub Category | City | Order Date | Region | Sales | Discount | Profit | State | profit_margin | Cluster | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | Harish | Oil & Masala | Masalas | Vellore | 11-08-2017 | North | 1254 | 0.12 | 401.28 | Tamil Nadu | 0.32 | Medium |
| 1 | Sudha | Beverages | Health Drinks | Krishnagiri | 11-08-2017 | South | 749 | 0.18 | 149.80 | Tamil Nadu | 0.20 | Medium |
| 2 | Hussain | Food Grains | Atta & Flour | Perambalur | 06-12-2017 | West | 2360 | 0.21 | 165.20 | Tamil Nadu | 0.07 | Low |
| 3 | Jackson | Fruits & Veggies | Fresh Vegetables | Dharmapuri | 10-11-2016 | South | 896 | 0.25 | 89.60 | Tamil Nadu | 0.10 | Low |
| 4 | Ridhesh | Food Grains | Organic Staples | Ooty | 10-11-2016 | South | 2355 | 0.26 | 918.45 | Tamil Nadu | 0.39 | High |
sns.histplot(dataset['Cluster'])
<AxesSubplot:xlabel='Cluster', ylabel='Count'>
encoder = LabelEncoder()
scaler = StandardScaler()
onehot = OneHotEncoder()
minmaxscaler = MinMaxScaler()
dataset["Order Date"] = pd.to_datetime(dataset["Order Date"])
dataset["Order Date"] = dataset["Order Date"].dt.month
# dataset["Customer Name"] = onehot.fit_transform(dataset["Customer Name"].values.reshape(-1, 1)).toarray()
# dataset["Category"] = onehot.fit_transform(dataset["Category"].values.reshape(-1, 1)).toarray()
# dataset["City"] = onehot.fit_transform(dataset["City"].values.reshape(-1, 1)).toarray()
# dataset["Region"] = onehot.fit_transform(dataset["Region"].values.reshape(-1, 1)).toarray()
# dataset["State"] = onehot.fit_transform(dataset["State"].values.reshape(-1, 1)).toarray()
# dataset["Sub Category"] = onehot.fit_transform(dataset["Sub Category"].values.reshape(-1, 1)).toarray()
dataset["Customer Name"] = encoder.fit_transform(dataset["Customer Name"])
dataset["Category"] = encoder.fit_transform(dataset["Category"])
dataset["City"] = encoder.fit_transform(dataset["City"])
dataset["Region"] = encoder.fit_transform(dataset["Region"])
dataset["State"] = encoder.fit_transform(dataset["State"])
dataset["Sub Category"] = encoder.fit_transform(dataset["Sub Category"])
dataset["Order Date"] = encoder.fit_transform(dataset["Order Date"])
dataset[["Sales", "Discount", "profit_margin","Profit"]] = scaler.fit_transform(dataset[["Sales", "Discount", "profit_margin","Profit"]])
class_to_numeric = {'Low': 0, 'Medium': 1, 'High': 2}
dataset['Cluster'] = [class_to_numeric[label] for label in dataset['Cluster']]
dataset.head()
| Customer Name | Category | Sub Category | City | Order Date | Region | Sales | Discount | Profit | State | profit_margin | Cluster | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 12 | 5 | 14 | 21 | 10 | 2 | -0.414559 | -1.430908 | 0.124389 | 0 | 0.595874 | 1 |
| 1 | 37 | 1 | 13 | 8 | 10 | 3 | -1.291968 | -0.627370 | -0.941183 | 0 | -0.416872 | 1 |
| 2 | 14 | 3 | 0 | 13 | 5 | 4 | 1.507054 | -0.225601 | -0.875930 | 0 | -1.514014 | 0 |
| 3 | 15 | 4 | 12 | 4 | 9 | 3 | -1.036563 | 0.310092 | -1.196262 | 0 | -1.260827 | 0 |
| 4 | 28 | 3 | 18 | 12 | 9 | 3 | 1.498367 | 0.444015 | 2.315743 | 0 | 1.186643 | 2 |
X = dataset.drop(['Cluster','Sub Category','State','profit_margin'],axis=1)
y = dataset['Cluster']
X.head()
| Customer Name | Category | City | Order Date | Region | Sales | Discount | Profit | |
|---|---|---|---|---|---|---|---|---|
| 0 | 12 | 5 | 21 | 10 | 2 | -0.414559 | -1.430908 | 0.124389 |
| 1 | 37 | 1 | 8 | 10 | 3 | -1.291968 | -0.627370 | -0.941183 |
| 2 | 14 | 3 | 13 | 5 | 4 | 1.507054 | -0.225601 | -0.875930 |
| 3 | 15 | 4 | 4 | 9 | 3 | -1.036563 | 0.310092 | -1.196262 |
| 4 | 28 | 3 | 12 | 9 | 3 | 1.498367 | 0.444015 | 2.315743 |
X_train, X_temp, y_train, y_temp = train_test_split(X, y, test_size=0.2, random_state=42)
X_val, X_test, y_val, y_test = train_test_split(X_temp, y_temp, test_size=0.5, random_state=42)
heatcol = X.corr()
sns.heatmap(heatcol,cmap="BrBG",annot=True)
<AxesSubplot:>
print("Dimension of Train set",X_train.shape)
print("Dimension of Val set",X_val.shape)
print("Dimension of Test set",X_test.shape,"\n")
num_cols = X_train._get_numeric_data().columns
print("Number of numeric features:",num_cols.size)
Dimension of Train set (7960, 8) Dimension of Val set (995, 8) Dimension of Test set (996, 8) Number of numeric features: 8
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense,SimpleRNN,Dropout
from tensorflow.keras.utils import to_categorical, plot_model
from scikeras.wrappers import KerasClassifier, KerasRegressor
from sklearn.model_selection import GridSearchCV
y_train = to_categorical(y_train, num_classes=3)
y_val = to_categorical(y_val, num_classes=3)
def create_model(units=64,optimizer='adam',loss='categorical_crossentropy'):
model = Sequential()
model.add(SimpleRNN(units, return_sequences=False, input_shape=(X_train.shape[1], 1)))
model.add(Dense(units, activation='relu'))
model.add(Dense(units=3, activation='sigmoid'))
model.compile(loss=loss, optimizer=optimizer, metrics=['categorical_accuracy'])
return model
# model = KerasClassifier(build_fn=create_model, epochs=100, batch_size=32, verbose=0, validation_data=(X_val, y_val))
# model = KerasClassifier(build_fn=create_model, units=32, epochs=100, batch_size=32, verbose=0)
model = KerasClassifier(build_fn=create_model, units=32, epochs=100, batch_size=32, verbose=0)
param_grid = {
'optimizer': ['adam', 'sgd', 'rmsprop'],
'units': [32, 64, 128],
'loss' : ['categorical_crossentropy','binary_crossentropy','hinge','kullback_leibler_divergence']
}
grid = GridSearchCV(estimator=model, param_grid=param_grid, cv=2, scoring='accuracy')
grid_result = grid.fit(X_train, y_train)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\sklearn\model_selection\_validation.py:425: FitFailedWarning:
54 fits failed out of a total of 72.
The score on these train-test partitions for these parameters will be set to nan.
If these failures are not expected, you can try to debug them by setting error_score='raise'.
Below are more details about the failures:
--------------------------------------------------------------------------------
18 fits failed with the following error:
Traceback (most recent call last):
File "c:\Users\Asus\anaconda3\lib\site-packages\sklearn\model_selection\_validation.py", line 729, in _fit_and_score
estimator.fit(X_train, y_train, **fit_params)
File "c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py", line 1491, in fit
super().fit(X=X, y=y, sample_weight=sample_weight, **kwargs)
File "c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py", line 760, in fit
self._fit(
File "c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py", line 926, in _fit
self._check_model_compatibility(y)
File "c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py", line 569, in _check_model_compatibility
raise ValueError(
ValueError: loss=binary_crossentropy but model compiled with categorical_crossentropy. Data may not match loss function!
--------------------------------------------------------------------------------
18 fits failed with the following error:
Traceback (most recent call last):
File "c:\Users\Asus\anaconda3\lib\site-packages\sklearn\model_selection\_validation.py", line 729, in _fit_and_score
estimator.fit(X_train, y_train, **fit_params)
File "c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py", line 1491, in fit
super().fit(X=X, y=y, sample_weight=sample_weight, **kwargs)
File "c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py", line 760, in fit
self._fit(
File "c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py", line 926, in _fit
self._check_model_compatibility(y)
File "c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py", line 569, in _check_model_compatibility
raise ValueError(
ValueError: loss=hinge but model compiled with categorical_crossentropy. Data may not match loss function!
--------------------------------------------------------------------------------
18 fits failed with the following error:
Traceback (most recent call last):
File "c:\Users\Asus\anaconda3\lib\site-packages\sklearn\model_selection\_validation.py", line 729, in _fit_and_score
estimator.fit(X_train, y_train, **fit_params)
File "c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py", line 1491, in fit
super().fit(X=X, y=y, sample_weight=sample_weight, **kwargs)
File "c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py", line 760, in fit
self._fit(
File "c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py", line 926, in _fit
self._check_model_compatibility(y)
File "c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py", line 569, in _check_model_compatibility
raise ValueError(
ValueError: loss=kullback_leibler_divergence but model compiled with categorical_crossentropy. Data may not match loss function!
warnings.warn(some_fits_failed_message, FitFailedWarning)
c:\Users\Asus\anaconda3\lib\site-packages\sklearn\model_selection\_search.py:979: UserWarning: One or more of the test scores are non-finite: [0.9821608 0.97135678 0.95565327 0.98329146 0.96733668 0.94987437
0.97286432 0.9678392 0.95062814 nan nan nan
nan nan nan nan nan nan
nan nan nan nan nan nan
nan nan nan nan nan nan
nan nan nan nan nan nan]
warnings.warn(
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
print("Best: %f using %s" % (grid_result.best_score_, grid_result.best_params_))
# best_model = grid_result.best_estimator_
# accuracy = best_model.score(X_val, y_val)
# print("Validation Accuracy: %.2f%%" % (accuracy * 100))
Best: 0.983291 using {'loss': 'categorical_crossentropy', 'optimizer': 'sgd', 'units': 32}
optimizer = grid_result.best_params_['optimizer']
units = grid_result.best_params_['units']
loss = grid_result.best_params_['loss']
model = Sequential()
model.add(SimpleRNN(64, return_sequences = False, input_shape=(X_train.shape[1], 1)))
model.add(Dense(units, activation='relu'))
model.add(Dense(units=3, activation='sigmoid'))
model.compile(loss=loss, optimizer=optimizer, metrics=['mean_absolute_error'])
history = model.fit(X_train, y_train, epochs=100, batch_size=32, validation_data=(X_val, y_val))
accuracy = model.evaluate(X_val, y_val)
print("Accuracy: ", accuracy)
Epoch 1/100 249/249 [==============================] - 8s 13ms/step - loss: 0.8234 - mean_absolute_error: 0.4445 - val_loss: 0.6038 - val_mean_absolute_error: 0.3896 Epoch 2/100 249/249 [==============================] - 2s 8ms/step - loss: 0.4730 - mean_absolute_error: 0.3500 - val_loss: 0.3911 - val_mean_absolute_error: 0.3204 Epoch 3/100 249/249 [==============================] - 2s 9ms/step - loss: 0.3228 - mean_absolute_error: 0.2984 - val_loss: 0.2747 - val_mean_absolute_error: 0.2817 Epoch 4/100 249/249 [==============================] - 2s 9ms/step - loss: 0.2502 - mean_absolute_error: 0.2735 - val_loss: 0.2468 - val_mean_absolute_error: 0.2685 Epoch 5/100 249/249 [==============================] - 2s 8ms/step - loss: 0.2096 - mean_absolute_error: 0.2594 - val_loss: 0.1874 - val_mean_absolute_error: 0.2532 Epoch 6/100 249/249 [==============================] - 2s 7ms/step - loss: 0.1821 - mean_absolute_error: 0.2506 - val_loss: 0.1691 - val_mean_absolute_error: 0.2457 Epoch 7/100 249/249 [==============================] - 2s 7ms/step - loss: 0.1657 - mean_absolute_error: 0.2446 - val_loss: 0.1516 - val_mean_absolute_error: 0.2414 Epoch 8/100 249/249 [==============================] - 2s 7ms/step - loss: 0.1547 - mean_absolute_error: 0.2398 - val_loss: 0.1535 - val_mean_absolute_error: 0.2366 Epoch 9/100 249/249 [==============================] - 2s 8ms/step - loss: 0.1424 - mean_absolute_error: 0.2355 - val_loss: 0.1291 - val_mean_absolute_error: 0.2335 Epoch 10/100 249/249 [==============================] - 2s 7ms/step - loss: 0.1306 - mean_absolute_error: 0.2314 - val_loss: 0.1478 - val_mean_absolute_error: 0.2325 Epoch 11/100 249/249 [==============================] - 2s 7ms/step - loss: 0.1221 - mean_absolute_error: 0.2284 - val_loss: 0.1193 - val_mean_absolute_error: 0.2266 Epoch 12/100 249/249 [==============================] - 2s 7ms/step - loss: 0.1171 - mean_absolute_error: 0.2256 - val_loss: 0.1140 - val_mean_absolute_error: 0.2203 Epoch 13/100 249/249 [==============================] - 2s 7ms/step - loss: 0.1108 - mean_absolute_error: 0.2229 - val_loss: 0.1263 - val_mean_absolute_error: 0.2237 Epoch 14/100 249/249 [==============================] - 2s 7ms/step - loss: 0.1054 - mean_absolute_error: 0.2202 - val_loss: 0.1068 - val_mean_absolute_error: 0.2177 Epoch 15/100 249/249 [==============================] - 2s 7ms/step - loss: 0.1029 - mean_absolute_error: 0.2188 - val_loss: 0.1485 - val_mean_absolute_error: 0.2263 Epoch 16/100 249/249 [==============================] - 2s 7ms/step - loss: 0.1009 - mean_absolute_error: 0.2172 - val_loss: 0.1079 - val_mean_absolute_error: 0.2196 Epoch 17/100 249/249 [==============================] - 2s 7ms/step - loss: 0.0946 - mean_absolute_error: 0.2157 - val_loss: 0.0888 - val_mean_absolute_error: 0.2133 Epoch 18/100 249/249 [==============================] - 2s 7ms/step - loss: 0.0926 - mean_absolute_error: 0.2135 - val_loss: 0.0973 - val_mean_absolute_error: 0.2152 Epoch 19/100 249/249 [==============================] - 2s 7ms/step - loss: 0.0924 - mean_absolute_error: 0.2117 - val_loss: 0.0989 - val_mean_absolute_error: 0.2107 Epoch 20/100 249/249 [==============================] - 2s 7ms/step - loss: 0.0890 - mean_absolute_error: 0.2110 - val_loss: 0.0791 - val_mean_absolute_error: 0.2065 Epoch 21/100 249/249 [==============================] - 2s 7ms/step - loss: 0.0862 - mean_absolute_error: 0.2086 - val_loss: 0.0744 - val_mean_absolute_error: 0.2078 Epoch 22/100 249/249 [==============================] - 2s 7ms/step - loss: 0.0836 - mean_absolute_error: 0.2076 - val_loss: 0.0761 - val_mean_absolute_error: 0.2068 Epoch 23/100 249/249 [==============================] - 2s 7ms/step - loss: 0.0875 - mean_absolute_error: 0.2069 - val_loss: 0.0816 - val_mean_absolute_error: 0.2056 Epoch 24/100 249/249 [==============================] - 2s 7ms/step - loss: 0.0793 - mean_absolute_error: 0.2054 - val_loss: 0.0677 - val_mean_absolute_error: 0.2031 Epoch 25/100 249/249 [==============================] - 2s 7ms/step - loss: 0.0794 - mean_absolute_error: 0.2046 - val_loss: 0.0665 - val_mean_absolute_error: 0.2013 Epoch 26/100 249/249 [==============================] - 2s 7ms/step - loss: 0.0789 - mean_absolute_error: 0.2028 - val_loss: 0.0650 - val_mean_absolute_error: 0.1977 Epoch 27/100 249/249 [==============================] - 2s 7ms/step - loss: 0.0724 - mean_absolute_error: 0.2010 - val_loss: 0.0748 - val_mean_absolute_error: 0.1948 Epoch 28/100 249/249 [==============================] - 2s 8ms/step - loss: 0.0711 - mean_absolute_error: 0.2000 - val_loss: 0.0790 - val_mean_absolute_error: 0.1988 Epoch 29/100 249/249 [==============================] - 2s 7ms/step - loss: 0.0792 - mean_absolute_error: 0.1997 - val_loss: 0.0685 - val_mean_absolute_error: 0.1943 Epoch 30/100 249/249 [==============================] - 2s 7ms/step - loss: 0.0711 - mean_absolute_error: 0.1976 - val_loss: 0.0759 - val_mean_absolute_error: 0.1917 Epoch 31/100 249/249 [==============================] - 2s 7ms/step - loss: 0.0711 - mean_absolute_error: 0.1972 - val_loss: 0.0668 - val_mean_absolute_error: 0.1936 Epoch 32/100 249/249 [==============================] - 2s 7ms/step - loss: 0.0691 - mean_absolute_error: 0.1963 - val_loss: 0.0653 - val_mean_absolute_error: 0.1902 Epoch 33/100 249/249 [==============================] - 2s 7ms/step - loss: 0.0708 - mean_absolute_error: 0.1954 - val_loss: 0.1067 - val_mean_absolute_error: 0.1988 Epoch 34/100 249/249 [==============================] - 2s 7ms/step - loss: 0.0656 - mean_absolute_error: 0.1950 - val_loss: 0.0621 - val_mean_absolute_error: 0.1884 Epoch 35/100 249/249 [==============================] - 2s 7ms/step - loss: 0.0650 - mean_absolute_error: 0.1931 - val_loss: 0.0613 - val_mean_absolute_error: 0.1862 Epoch 36/100 249/249 [==============================] - 2s 7ms/step - loss: 0.0656 - mean_absolute_error: 0.1924 - val_loss: 0.0801 - val_mean_absolute_error: 0.1921 Epoch 37/100 249/249 [==============================] - 2s 8ms/step - loss: 0.0591 - mean_absolute_error: 0.1910 - val_loss: 0.0658 - val_mean_absolute_error: 0.1883 Epoch 38/100 249/249 [==============================] - 2s 7ms/step - loss: 0.0649 - mean_absolute_error: 0.1904 - val_loss: 0.0996 - val_mean_absolute_error: 0.1835 Epoch 39/100 249/249 [==============================] - 2s 7ms/step - loss: 0.0622 - mean_absolute_error: 0.1900 - val_loss: 0.0553 - val_mean_absolute_error: 0.1863 Epoch 40/100 249/249 [==============================] - 2s 7ms/step - loss: 0.0601 - mean_absolute_error: 0.1895 - val_loss: 0.0523 - val_mean_absolute_error: 0.1855 Epoch 41/100 249/249 [==============================] - 2s 7ms/step - loss: 0.0625 - mean_absolute_error: 0.1882 - val_loss: 0.0825 - val_mean_absolute_error: 0.1869 Epoch 42/100 249/249 [==============================] - 2s 7ms/step - loss: 0.0565 - mean_absolute_error: 0.1878 - val_loss: 0.0621 - val_mean_absolute_error: 0.1838 Epoch 43/100 249/249 [==============================] - 2s 7ms/step - loss: 0.0601 - mean_absolute_error: 0.1860 - val_loss: 0.0701 - val_mean_absolute_error: 0.1806 Epoch 44/100 249/249 [==============================] - 2s 7ms/step - loss: 0.0612 - mean_absolute_error: 0.1856 - val_loss: 0.0566 - val_mean_absolute_error: 0.1824 Epoch 45/100 249/249 [==============================] - 2s 7ms/step - loss: 0.0564 - mean_absolute_error: 0.1849 - val_loss: 0.0698 - val_mean_absolute_error: 0.1864 Epoch 46/100 249/249 [==============================] - 2s 7ms/step - loss: 0.0557 - mean_absolute_error: 0.1839 - val_loss: 0.0577 - val_mean_absolute_error: 0.1783 Epoch 47/100 249/249 [==============================] - 2s 7ms/step - loss: 0.0561 - mean_absolute_error: 0.1833 - val_loss: 0.0533 - val_mean_absolute_error: 0.1818 Epoch 48/100 249/249 [==============================] - 2s 7ms/step - loss: 0.0529 - mean_absolute_error: 0.1836 - val_loss: 0.0589 - val_mean_absolute_error: 0.1800 Epoch 49/100 249/249 [==============================] - 2s 8ms/step - loss: 0.0586 - mean_absolute_error: 0.1826 - val_loss: 0.0942 - val_mean_absolute_error: 0.1777 Epoch 50/100 249/249 [==============================] - 2s 7ms/step - loss: 0.0615 - mean_absolute_error: 0.1817 - val_loss: 0.0574 - val_mean_absolute_error: 0.1771 Epoch 51/100 249/249 [==============================] - 2s 7ms/step - loss: 0.0516 - mean_absolute_error: 0.1812 - val_loss: 0.0584 - val_mean_absolute_error: 0.1813 Epoch 52/100 249/249 [==============================] - 2s 7ms/step - loss: 0.0526 - mean_absolute_error: 0.1805 - val_loss: 0.1105 - val_mean_absolute_error: 0.1847 Epoch 53/100 249/249 [==============================] - 2s 7ms/step - loss: 0.0605 - mean_absolute_error: 0.1799 - val_loss: 0.0443 - val_mean_absolute_error: 0.1781 Epoch 54/100 249/249 [==============================] - 2s 7ms/step - loss: 0.0586 - mean_absolute_error: 0.1797 - val_loss: 0.0744 - val_mean_absolute_error: 0.1791 Epoch 55/100 249/249 [==============================] - 2s 8ms/step - loss: 0.0544 - mean_absolute_error: 0.1792 - val_loss: 0.0840 - val_mean_absolute_error: 0.1792 Epoch 56/100 249/249 [==============================] - 2s 7ms/step - loss: 0.0615 - mean_absolute_error: 0.1781 - val_loss: 0.0782 - val_mean_absolute_error: 0.1775 Epoch 57/100 249/249 [==============================] - 2s 7ms/step - loss: 0.0521 - mean_absolute_error: 0.1782 - val_loss: 0.0452 - val_mean_absolute_error: 0.1783 Epoch 58/100 249/249 [==============================] - 2s 8ms/step - loss: 0.0474 - mean_absolute_error: 0.1776 - val_loss: 0.0634 - val_mean_absolute_error: 0.1773 Epoch 59/100 249/249 [==============================] - 2s 7ms/step - loss: 0.0496 - mean_absolute_error: 0.1771 - val_loss: 0.0426 - val_mean_absolute_error: 0.1762 Epoch 60/100 249/249 [==============================] - 2s 7ms/step - loss: 0.0508 - mean_absolute_error: 0.1770 - val_loss: 0.0844 - val_mean_absolute_error: 0.1811 Epoch 61/100 249/249 [==============================] - 2s 7ms/step - loss: 0.0516 - mean_absolute_error: 0.1755 - val_loss: 0.0481 - val_mean_absolute_error: 0.1698 Epoch 62/100 249/249 [==============================] - 2s 7ms/step - loss: 0.0497 - mean_absolute_error: 0.1753 - val_loss: 0.0468 - val_mean_absolute_error: 0.1692 Epoch 63/100 249/249 [==============================] - 2s 6ms/step - loss: 0.0498 - mean_absolute_error: 0.1746 - val_loss: 0.0526 - val_mean_absolute_error: 0.1754 Epoch 64/100 249/249 [==============================] - 1s 6ms/step - loss: 0.0536 - mean_absolute_error: 0.1741 - val_loss: 0.0490 - val_mean_absolute_error: 0.1690 Epoch 65/100 249/249 [==============================] - 1s 6ms/step - loss: 0.0523 - mean_absolute_error: 0.1735 - val_loss: 0.0579 - val_mean_absolute_error: 0.1747 Epoch 66/100 249/249 [==============================] - 1s 6ms/step - loss: 0.0524 - mean_absolute_error: 0.1732 - val_loss: 0.0490 - val_mean_absolute_error: 0.1710 Epoch 67/100 249/249 [==============================] - 1s 6ms/step - loss: 0.0605 - mean_absolute_error: 0.1737 - val_loss: 0.0442 - val_mean_absolute_error: 0.1678 Epoch 68/100 249/249 [==============================] - 2s 7ms/step - loss: 0.0505 - mean_absolute_error: 0.1726 - val_loss: 0.0580 - val_mean_absolute_error: 0.1681 Epoch 69/100 249/249 [==============================] - 2s 8ms/step - loss: 0.0461 - mean_absolute_error: 0.1715 - val_loss: 0.0486 - val_mean_absolute_error: 0.1704 Epoch 70/100 249/249 [==============================] - 2s 8ms/step - loss: 0.0451 - mean_absolute_error: 0.1718 - val_loss: 0.0567 - val_mean_absolute_error: 0.1740 Epoch 71/100 249/249 [==============================] - 2s 8ms/step - loss: 0.0600 - mean_absolute_error: 0.1720 - val_loss: 0.0725 - val_mean_absolute_error: 0.1740 Epoch 72/100 249/249 [==============================] - 2s 7ms/step - loss: 0.0429 - mean_absolute_error: 0.1703 - val_loss: 0.0511 - val_mean_absolute_error: 0.1669 Epoch 73/100 249/249 [==============================] - 2s 10ms/step - loss: 0.0521 - mean_absolute_error: 0.1697 - val_loss: 0.0522 - val_mean_absolute_error: 0.1637 Epoch 74/100 249/249 [==============================] - 3s 10ms/step - loss: 0.0465 - mean_absolute_error: 0.1696 - val_loss: 0.0598 - val_mean_absolute_error: 0.1681 Epoch 75/100 249/249 [==============================] - 2s 9ms/step - loss: 0.0455 - mean_absolute_error: 0.1690 - val_loss: 0.0921 - val_mean_absolute_error: 0.1700 Epoch 76/100 249/249 [==============================] - 2s 8ms/step - loss: 0.0446 - mean_absolute_error: 0.1687 - val_loss: 0.0460 - val_mean_absolute_error: 0.1699 Epoch 77/100 249/249 [==============================] - 2s 7ms/step - loss: 0.0436 - mean_absolute_error: 0.1686 - val_loss: 0.0389 - val_mean_absolute_error: 0.1660 Epoch 78/100 249/249 [==============================] - 2s 9ms/step - loss: 0.0464 - mean_absolute_error: 0.1685 - val_loss: 0.0448 - val_mean_absolute_error: 0.1709 Epoch 79/100 249/249 [==============================] - 2s 9ms/step - loss: 0.0461 - mean_absolute_error: 0.1686 - val_loss: 0.0668 - val_mean_absolute_error: 0.1704 Epoch 80/100 249/249 [==============================] - 2s 8ms/step - loss: 0.0444 - mean_absolute_error: 0.1675 - val_loss: 0.0921 - val_mean_absolute_error: 0.1681 Epoch 81/100 249/249 [==============================] - 2s 7ms/step - loss: 0.0426 - mean_absolute_error: 0.1675 - val_loss: 0.0489 - val_mean_absolute_error: 0.1662 Epoch 82/100 249/249 [==============================] - 2s 6ms/step - loss: 0.0378 - mean_absolute_error: 0.1680 - val_loss: 0.0384 - val_mean_absolute_error: 0.1663 Epoch 83/100 249/249 [==============================] - 2s 8ms/step - loss: 0.0413 - mean_absolute_error: 0.1668 - val_loss: 0.0398 - val_mean_absolute_error: 0.1618 Epoch 84/100 249/249 [==============================] - 2s 10ms/step - loss: 0.0376 - mean_absolute_error: 0.1666 - val_loss: 0.0480 - val_mean_absolute_error: 0.1665 Epoch 85/100 249/249 [==============================] - 2s 7ms/step - loss: 0.0432 - mean_absolute_error: 0.1667 - val_loss: 0.1177 - val_mean_absolute_error: 0.1659 Epoch 86/100 249/249 [==============================] - 2s 7ms/step - loss: 0.0405 - mean_absolute_error: 0.1662 - val_loss: 0.0574 - val_mean_absolute_error: 0.1644 Epoch 87/100 249/249 [==============================] - 2s 7ms/step - loss: 0.0414 - mean_absolute_error: 0.1648 - val_loss: 0.0460 - val_mean_absolute_error: 0.1613 Epoch 88/100 249/249 [==============================] - 2s 8ms/step - loss: 0.0380 - mean_absolute_error: 0.1645 - val_loss: 0.0375 - val_mean_absolute_error: 0.1619 Epoch 89/100 249/249 [==============================] - 2s 7ms/step - loss: 0.0424 - mean_absolute_error: 0.1643 - val_loss: 0.0622 - val_mean_absolute_error: 0.1640 Epoch 90/100 249/249 [==============================] - 2s 7ms/step - loss: 0.0367 - mean_absolute_error: 0.1634 - val_loss: 0.0475 - val_mean_absolute_error: 0.1599 Epoch 91/100 249/249 [==============================] - 1s 6ms/step - loss: 0.0378 - mean_absolute_error: 0.1644 - val_loss: 0.0351 - val_mean_absolute_error: 0.1600 Epoch 92/100 249/249 [==============================] - 2s 6ms/step - loss: 0.0438 - mean_absolute_error: 0.1635 - val_loss: 0.0621 - val_mean_absolute_error: 0.1621 Epoch 93/100 249/249 [==============================] - 2s 6ms/step - loss: 0.0438 - mean_absolute_error: 0.1626 - val_loss: 0.0354 - val_mean_absolute_error: 0.1631 Epoch 94/100 249/249 [==============================] - 2s 7ms/step - loss: 0.0377 - mean_absolute_error: 0.1626 - val_loss: 0.0431 - val_mean_absolute_error: 0.1642 Epoch 95/100 249/249 [==============================] - 2s 7ms/step - loss: 0.0474 - mean_absolute_error: 0.1631 - val_loss: 0.0546 - val_mean_absolute_error: 0.1582 Epoch 96/100 249/249 [==============================] - 2s 7ms/step - loss: 0.0438 - mean_absolute_error: 0.1618 - val_loss: 0.0354 - val_mean_absolute_error: 0.1603 Epoch 97/100 249/249 [==============================] - 2s 7ms/step - loss: 0.0463 - mean_absolute_error: 0.1628 - val_loss: 0.0440 - val_mean_absolute_error: 0.1585 Epoch 98/100 249/249 [==============================] - 1s 6ms/step - loss: 0.0391 - mean_absolute_error: 0.1616 - val_loss: 0.0471 - val_mean_absolute_error: 0.1605 Epoch 99/100 249/249 [==============================] - 1s 6ms/step - loss: 0.0400 - mean_absolute_error: 0.1611 - val_loss: 0.0448 - val_mean_absolute_error: 0.1568 Epoch 100/100 249/249 [==============================] - 2s 6ms/step - loss: 0.0359 - mean_absolute_error: 0.1617 - val_loss: 0.0400 - val_mean_absolute_error: 0.1567 32/32 [==============================] - 0s 4ms/step - loss: 0.0400 - mean_absolute_error: 0.1567 Accuracy: [0.03999854996800423, 0.15673400461673737]
model.summary()
Model: "sequential_73"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
simple_rnn_73 (SimpleRNN) (None, 64) 4224
dense_146 (Dense) (None, 32) 2080
dense_147 (Dense) (None, 3) 99
=================================================================
Total params: 6403 (25.01 KB)
Trainable params: 6403 (25.01 KB)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________
import joblib
joblib.dump(model, 'model/rnn_model.pkl')
INFO:tensorflow:Assets written to: C:\Users\Asus\AppData\Local\Temp\tmpj5ooq9gc\assets
INFO:tensorflow:Assets written to: C:\Users\Asus\AppData\Local\Temp\tmpj5ooq9gc\assets
['model/rnn_model.pkl']
plot_model(model, to_file='mlp-mnist.png', show_shapes=True)
You must install pydot (`pip install pydot`) and install graphviz (see instructions at https://graphviz.gitlab.io/download/) for plot_model to work.
predictions = model.predict(X_test)
y_pred = np.argmax(predictions, axis=1)
11/32 [=========>....................] - ETA: 0s32/32 [==============================] - 0s 5ms/step
plt.figure(figsize=(12, 5))
plt.subplot(1, 2, 1)
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.title('Loss over epochs')
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.legend(['Train', 'Validation'], loc='best')
plt.subplot(1, 2, 2)
plt.plot(history.history['mean_absolute_error'])
plt.plot(history.history['val_mean_absolute_error'])
plt.title('Accuracy over epochs')
plt.ylabel('Accuracy')
plt.xlabel('Epoch')
plt.legend(['Train', 'Validation'], loc='best')
plt.tight_layout()
plt.show()
cm = confusion_matrix(y_test, y_pred)
disp = ConfusionMatrixDisplay(confusion_matrix=cm)
disp.plot(cmap='Blues', values_format='d')
plt.title("Confusion Matrix: RNN")
plt.show()
print(classification_report(y_test, y_pred))
precision recall f1-score support
0 0.99 0.99 0.99 327
1 0.98 0.96 0.97 352
2 0.97 0.99 0.98 317
accuracy 0.98 996
macro avg 0.98 0.98 0.98 996
weighted avg 0.98 0.98 0.98 996