import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder, OneHotEncoder, StandardScaler,MinMaxScaler
from sklearn.model_selection import cross_val_score, GridSearchCV
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix,classification_report, ConfusionMatrixDisplay
dataset = pd.read_csv('data/market_cluster.csv', encoding='latin1')
dataset.head()
| Order ID | Customer Name | Category | Sub Category | City | Order Date | Region | Sales | Discount | Profit | State | profit_margin | Cluster | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | OD1 | Harish | Oil & Masala | Masalas | Vellore | 11-08-2017 | North | 1254 | 0.12 | 401.28 | Tamil Nadu | 0.32 | Medium |
| 1 | OD2 | Sudha | Beverages | Health Drinks | Krishnagiri | 11-08-2017 | South | 749 | 0.18 | 149.80 | Tamil Nadu | 0.20 | Medium |
| 2 | OD3 | Hussain | Food Grains | Atta & Flour | Perambalur | 06-12-2017 | West | 2360 | 0.21 | 165.20 | Tamil Nadu | 0.07 | Low |
| 3 | OD4 | Jackson | Fruits & Veggies | Fresh Vegetables | Dharmapuri | 10-11-2016 | South | 896 | 0.25 | 89.60 | Tamil Nadu | 0.10 | Low |
| 4 | OD5 | Ridhesh | Food Grains | Organic Staples | Ooty | 10-11-2016 | South | 2355 | 0.26 | 918.45 | Tamil Nadu | 0.39 | High |
dataset.info()
<class 'pandas.core.frame.DataFrame'> RangeIndex: 9994 entries, 0 to 9993 Data columns (total 13 columns): # Column Non-Null Count Dtype --- ------ -------------- ----- 0 Order ID 9994 non-null object 1 Customer Name 9994 non-null object 2 Category 9994 non-null object 3 Sub Category 9994 non-null object 4 City 9994 non-null object 5 Order Date 9994 non-null object 6 Region 9994 non-null object 7 Sales 9994 non-null int64 8 Discount 9994 non-null float64 9 Profit 9994 non-null float64 10 State 9994 non-null object 11 profit_margin 9994 non-null float64 12 Cluster 9994 non-null object dtypes: float64(3), int64(1), object(9) memory usage: 1015.1+ KB
dataset.describe()
| Sales | Discount | Profit | profit_margin | |
|---|---|---|---|---|
| count | 9994.000000 | 9994.000000 | 9994.000000 | 9994.000000 |
| mean | 1496.596158 | 0.226817 | 374.937082 | 0.250228 |
| std | 577.559036 | 0.074636 | 239.932881 | 0.118919 |
| min | 500.000000 | 0.100000 | 25.250000 | 0.050000 |
| 25% | 1000.000000 | 0.160000 | 180.022500 | 0.150000 |
| 50% | 1498.000000 | 0.230000 | 320.780000 | 0.250000 |
| 75% | 1994.750000 | 0.290000 | 525.627500 | 0.350000 |
| max | 2500.000000 | 0.350000 | 1120.950000 | 0.450000 |
dataset.info()
<class 'pandas.core.frame.DataFrame'> RangeIndex: 9994 entries, 0 to 9993 Data columns (total 13 columns): # Column Non-Null Count Dtype --- ------ -------------- ----- 0 Order ID 9994 non-null object 1 Customer Name 9994 non-null object 2 Category 9994 non-null object 3 Sub Category 9994 non-null object 4 City 9994 non-null object 5 Order Date 9994 non-null object 6 Region 9994 non-null object 7 Sales 9994 non-null int64 8 Discount 9994 non-null float64 9 Profit 9994 non-null float64 10 State 9994 non-null object 11 profit_margin 9994 non-null float64 12 Cluster 9994 non-null object dtypes: float64(3), int64(1), object(9) memory usage: 1015.1+ KB
dataset.drop(['Order ID'], axis=1, inplace=True)
dataset.isna().sum()
Customer Name 0 Category 0 Sub Category 0 City 0 Order Date 0 Region 0 Sales 0 Discount 0 Profit 0 State 0 profit_margin 0 Cluster 0 dtype: int64
dataset.dropna(inplace=True)
def remove_outliers(data: pd.DataFrame, column: str) -> pd.Series:
q3, q1 = np.nanpercentile(data[column], [75, 25])
iqr = q3 - q1
upper_bound = q3 + 1.5 * iqr
lower_bound = q1 - 1.5 * iqr
data = data[(data[column] > lower_bound) & (data[column] < upper_bound)]
return data
dataset = remove_outliers(dataset, 'Discount')
dataset = remove_outliers(dataset, 'Sales')
dataset = remove_outliers(dataset, 'Profit')
dataset.head()
| Customer Name | Category | Sub Category | City | Order Date | Region | Sales | Discount | Profit | State | profit_margin | Cluster | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | Harish | Oil & Masala | Masalas | Vellore | 11-08-2017 | North | 1254 | 0.12 | 401.28 | Tamil Nadu | 0.32 | Medium |
| 1 | Sudha | Beverages | Health Drinks | Krishnagiri | 11-08-2017 | South | 749 | 0.18 | 149.80 | Tamil Nadu | 0.20 | Medium |
| 2 | Hussain | Food Grains | Atta & Flour | Perambalur | 06-12-2017 | West | 2360 | 0.21 | 165.20 | Tamil Nadu | 0.07 | Low |
| 3 | Jackson | Fruits & Veggies | Fresh Vegetables | Dharmapuri | 10-11-2016 | South | 896 | 0.25 | 89.60 | Tamil Nadu | 0.10 | Low |
| 4 | Ridhesh | Food Grains | Organic Staples | Ooty | 10-11-2016 | South | 2355 | 0.26 | 918.45 | Tamil Nadu | 0.39 | High |
sns.histplot(dataset['Cluster'])
<AxesSubplot:xlabel='Cluster', ylabel='Count'>
encoder = LabelEncoder()
scaler = StandardScaler()
onehot = OneHotEncoder()
minmaxscaler = MinMaxScaler()
dataset["Order Date"] = pd.to_datetime(dataset["Order Date"])
dataset["Order Date"] = dataset["Order Date"].dt.month
dataset["Customer Name"] = encoder.fit_transform(dataset["Customer Name"])
dataset["Category"] = encoder.fit_transform(dataset["Category"])
dataset["City"] = encoder.fit_transform(dataset["City"])
dataset["Region"] = encoder.fit_transform(dataset["Region"])
dataset["State"] = encoder.fit_transform(dataset["State"])
dataset["Sub Category"] = encoder.fit_transform(dataset["Sub Category"])
# dataset["Customer Name"] = onehot.fit_transform(dataset["Customer Name"].values.reshape(-1, 1)).toarray()
# dataset["Category"] = onehot.fit_transform(dataset["Category"].values.reshape(-1, 1)).toarray()
# dataset["City"] = onehot.fit_transform(dataset["City"].values.reshape(-1, 1)).toarray()
# dataset["Region"] = onehot.fit_transform(dataset["Region"].values.reshape(-1, 1)).toarray()
# dataset["State"] = onehot.fit_transform(dataset["State"].values.reshape(-1, 1)).toarray()
# dataset["Sub Category"] = onehot.fit_transform(dataset["Sub Category"].values.reshape(-1, 1)).toarray()
dataset["Order Date"] = encoder.fit_transform(dataset["Order Date"])
dataset[["Sales", "Discount", "profit_margin"]] = scaler.fit_transform(dataset[["Sales", "Discount", "profit_margin"]])
dataset["Profit"] = minmaxscaler.fit_transform(dataset["Profit"].values.reshape(-1, 1))
class_to_numeric = {'Low': 0, 'Medium': 1, 'High': 2}
dataset['Cluster'] = [class_to_numeric[label] for label in dataset['Cluster']]
dataset.head()
| Customer Name | Category | Sub Category | City | Order Date | Region | Sales | Discount | Profit | State | profit_margin | Cluster | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 12 | 5 | 14 | 21 | 10 | 2 | -0.414559 | -1.430908 | 0.369225 | 0 | 0.595874 | 1 |
| 1 | 37 | 1 | 13 | 8 | 10 | 3 | -1.291968 | -0.627370 | 0.122296 | 0 | -0.416872 | 1 |
| 2 | 14 | 3 | 0 | 13 | 5 | 4 | 1.507054 | -0.225601 | 0.137417 | 0 | -1.514014 | 0 |
| 3 | 15 | 4 | 12 | 4 | 9 | 3 | -1.036563 | 0.310092 | 0.063185 | 0 | -1.260827 | 0 |
| 4 | 28 | 3 | 18 | 12 | 9 | 3 | 1.498367 | 0.444015 | 0.877036 | 0 | 1.186643 | 2 |
X = dataset.drop(['Cluster','Sub Category','State','profit_margin'],axis=1)
y = dataset['Cluster']
df_new_ma = pd.DataFrame(minmaxscaler.fit_transform(X), columns=X.columns)
df_new_ma.head()
| Customer Name | Category | City | Order Date | Region | Sales | Discount | Profit | |
|---|---|---|---|---|---|---|---|---|
| 0 | 0.244898 | 0.833333 | 0.913043 | 0.909091 | 0.50 | 0.3770 | 0.08 | 0.369225 |
| 1 | 0.755102 | 0.166667 | 0.347826 | 0.909091 | 0.75 | 0.1245 | 0.32 | 0.122296 |
| 2 | 0.285714 | 0.500000 | 0.565217 | 0.454545 | 1.00 | 0.9300 | 0.44 | 0.137417 |
| 3 | 0.306122 | 0.666667 | 0.173913 | 0.818182 | 0.75 | 0.1980 | 0.60 | 0.063185 |
| 4 | 0.571429 | 0.500000 | 0.521739 | 0.818182 | 0.75 | 0.9275 | 0.64 | 0.877036 |
subset_of_data = X.sample(frac=0.05, random_state=42)
plt.figure(figsize=(14,7))
plt.subplot(1,2,2)
plt.title("Scatterplot Scaling", fontsize=18)
sns.scatterplot(data = subset_of_data, color="red")
plt.tight_layout()
plt.show()
plt.figure(figsize=(14,7))
plt.subplot(1,2,2)
plt.title("Data Scaling", fontsize=18)
sns.kdeplot(data = X, color="red")
plt.tight_layout()
plt.show()
X.head()
| Customer Name | Category | City | Order Date | Region | Sales | Discount | Profit | |
|---|---|---|---|---|---|---|---|---|
| 0 | 12 | 5 | 21 | 10 | 2 | -0.414559 | -1.430908 | 0.369225 |
| 1 | 37 | 1 | 8 | 10 | 3 | -1.291968 | -0.627370 | 0.122296 |
| 2 | 14 | 3 | 13 | 5 | 4 | 1.507054 | -0.225601 | 0.137417 |
| 3 | 15 | 4 | 4 | 9 | 3 | -1.036563 | 0.310092 | 0.063185 |
| 4 | 28 | 3 | 12 | 9 | 3 | 1.498367 | 0.444015 | 0.877036 |
X_train, X_temp, y_train, y_temp = train_test_split(X, y, test_size=0.2, random_state=42)
X_val, X_test, y_val, y_test = train_test_split(X_temp, y_temp, test_size=0.5, random_state=42)
heatcol = X.corr()
sns.heatmap(heatcol,cmap="BrBG",annot=True)
<AxesSubplot:>
print("Dimension of Train set",X_train.shape)
print("Dimension of Val set",X_val.shape)
print("Dimension of Test set",X_test.shape,"\n")
num_cols = X_train._get_numeric_data().columns
print("Number of numeric features:",num_cols.size)
Dimension of Train set (7960, 8) Dimension of Val set (995, 8) Dimension of Test set (996, 8) Number of numeric features: 8
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense,Flatten,Dropout
from tensorflow.keras.utils import to_categorical
from scikeras.wrappers import KerasClassifier
from sklearn.model_selection import GridSearchCV
y_train = to_categorical(y_train, num_classes=3)
y_val = to_categorical(y_val, num_classes=3)
def create_model(units=64,optimizer='adam',loss='categorical_crossentropy'):
model = Sequential()
model.add(Flatten(input_shape=(X_train.shape[1], 1)))
model.add(Dense(units, activation='relu'))
model.add(Dense(units=3, activation='sigmoid'))
model.compile(loss=loss, optimizer=optimizer, metrics=['categorical_accuracy'])
return model
model = KerasClassifier(build_fn=create_model, units=32, epochs=100, batch_size=32, verbose=0)
param_grid = {
'optimizer': ['adam', 'sgd', 'rmsprop'],
'units': [32, 64, 128],
'loss' : ['categorical_crossentropy','binary_crossentropy','hinge','kullback_leibler_divergence']
}
grid = GridSearchCV(estimator=model, param_grid=param_grid, cv=2, scoring='accuracy')
grid_result = grid.fit(X_train, y_train)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
c:\Users\Asus\anaconda3\lib\site-packages\sklearn\model_selection\_validation.py:425: FitFailedWarning:
54 fits failed out of a total of 72.
The score on these train-test partitions for these parameters will be set to nan.
If these failures are not expected, you can try to debug them by setting error_score='raise'.
Below are more details about the failures:
--------------------------------------------------------------------------------
18 fits failed with the following error:
Traceback (most recent call last):
File "c:\Users\Asus\anaconda3\lib\site-packages\sklearn\model_selection\_validation.py", line 729, in _fit_and_score
estimator.fit(X_train, y_train, **fit_params)
File "c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py", line 1491, in fit
super().fit(X=X, y=y, sample_weight=sample_weight, **kwargs)
File "c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py", line 760, in fit
self._fit(
File "c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py", line 926, in _fit
self._check_model_compatibility(y)
File "c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py", line 569, in _check_model_compatibility
raise ValueError(
ValueError: loss=binary_crossentropy but model compiled with categorical_crossentropy. Data may not match loss function!
--------------------------------------------------------------------------------
18 fits failed with the following error:
Traceback (most recent call last):
File "c:\Users\Asus\anaconda3\lib\site-packages\sklearn\model_selection\_validation.py", line 729, in _fit_and_score
estimator.fit(X_train, y_train, **fit_params)
File "c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py", line 1491, in fit
super().fit(X=X, y=y, sample_weight=sample_weight, **kwargs)
File "c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py", line 760, in fit
self._fit(
File "c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py", line 926, in _fit
self._check_model_compatibility(y)
File "c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py", line 569, in _check_model_compatibility
raise ValueError(
ValueError: loss=hinge but model compiled with categorical_crossentropy. Data may not match loss function!
--------------------------------------------------------------------------------
18 fits failed with the following error:
Traceback (most recent call last):
File "c:\Users\Asus\anaconda3\lib\site-packages\sklearn\model_selection\_validation.py", line 729, in _fit_and_score
estimator.fit(X_train, y_train, **fit_params)
File "c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py", line 1491, in fit
super().fit(X=X, y=y, sample_weight=sample_weight, **kwargs)
File "c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py", line 760, in fit
self._fit(
File "c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py", line 926, in _fit
self._check_model_compatibility(y)
File "c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py", line 569, in _check_model_compatibility
raise ValueError(
ValueError: loss=kullback_leibler_divergence but model compiled with categorical_crossentropy. Data may not match loss function!
warnings.warn(some_fits_failed_message, FitFailedWarning)
c:\Users\Asus\anaconda3\lib\site-packages\sklearn\model_selection\_search.py:979: UserWarning: One or more of the test scores are non-finite: [0.975 0.96331658 0.96105528 0.96218593 0.97072864 0.96055276
0.96683417 0.97487437 0.96494975 nan nan nan
nan nan nan nan nan nan
nan nan nan nan nan nan
nan nan nan nan nan nan
nan nan nan nan nan nan]
warnings.warn(
c:\Users\Asus\anaconda3\lib\site-packages\scikeras\wrappers.py:915: UserWarning: ``build_fn`` will be renamed to ``model`` in a future release, at which point use of ``build_fn`` will raise an Error instead.
X, y = self._initialize(X, y)
print("Best: %f using %s" % (grid_result.best_score_, grid_result.best_params_))
Best: 0.975000 using {'loss': 'categorical_crossentropy', 'optimizer': 'adam', 'units': 32}
optimizer = grid_result.best_params_['optimizer']
units = grid_result.best_params_['units']
loss = grid_result.best_params_['loss']
model = Sequential()
model.add(Flatten(input_shape=(X_train.shape[1], 1)))
model.add(Dense(units, activation='relu'))
model.add(Dense(units=3, activation='sigmoid'))
model.compile(loss=loss, optimizer=optimizer, metrics=['mean_absolute_error'])
history = model.fit(X_train, y_train, epochs=100, batch_size=32, validation_data=(X_val, y_val))
accuracy = model.evaluate(X_val, y_val)
print("Accuracy: ", accuracy)
Epoch 1/100 249/249 [==============================] - 2s 4ms/step - loss: 1.1910 - mean_absolute_error: 0.6411 - val_loss: 1.1244 - val_mean_absolute_error: 0.6395 Epoch 2/100 249/249 [==============================] - 1s 3ms/step - loss: 1.0659 - mean_absolute_error: 0.6380 - val_loss: 1.0774 - val_mean_absolute_error: 0.6335 Epoch 3/100 249/249 [==============================] - 1s 3ms/step - loss: 0.9923 - mean_absolute_error: 0.6326 - val_loss: 0.9391 - val_mean_absolute_error: 0.6281 Epoch 4/100 249/249 [==============================] - 1s 3ms/step - loss: 0.8947 - mean_absolute_error: 0.6261 - val_loss: 0.8354 - val_mean_absolute_error: 0.6206 Epoch 5/100 249/249 [==============================] - 1s 4ms/step - loss: 0.7923 - mean_absolute_error: 0.6163 - val_loss: 0.7401 - val_mean_absolute_error: 0.6092 Epoch 6/100 249/249 [==============================] - 1s 3ms/step - loss: 0.7080 - mean_absolute_error: 0.6052 - val_loss: 0.6866 - val_mean_absolute_error: 0.5997 Epoch 7/100 249/249 [==============================] - 1s 4ms/step - loss: 0.6435 - mean_absolute_error: 0.5925 - val_loss: 0.6067 - val_mean_absolute_error: 0.5837 Epoch 8/100 249/249 [==============================] - 1s 3ms/step - loss: 0.5873 - mean_absolute_error: 0.5801 - val_loss: 0.5607 - val_mean_absolute_error: 0.5696 Epoch 9/100 249/249 [==============================] - 1s 2ms/step - loss: 0.5372 - mean_absolute_error: 0.5675 - val_loss: 0.5247 - val_mean_absolute_error: 0.5587 Epoch 10/100 249/249 [==============================] - 1s 2ms/step - loss: 0.4932 - mean_absolute_error: 0.5509 - val_loss: 0.4803 - val_mean_absolute_error: 0.5343 Epoch 11/100 249/249 [==============================] - 0s 2ms/step - loss: 0.4585 - mean_absolute_error: 0.5335 - val_loss: 0.4467 - val_mean_absolute_error: 0.5177 Epoch 12/100 249/249 [==============================] - 0s 2ms/step - loss: 0.4255 - mean_absolute_error: 0.5139 - val_loss: 0.4227 - val_mean_absolute_error: 0.4969 Epoch 13/100 249/249 [==============================] - 1s 2ms/step - loss: 0.3996 - mean_absolute_error: 0.4973 - val_loss: 0.4061 - val_mean_absolute_error: 0.4880 Epoch 14/100 249/249 [==============================] - 1s 2ms/step - loss: 0.3785 - mean_absolute_error: 0.4817 - val_loss: 0.3621 - val_mean_absolute_error: 0.4704 Epoch 15/100 249/249 [==============================] - 0s 2ms/step - loss: 0.3542 - mean_absolute_error: 0.4687 - val_loss: 0.3536 - val_mean_absolute_error: 0.4550 Epoch 16/100 249/249 [==============================] - 0s 2ms/step - loss: 0.3334 - mean_absolute_error: 0.4537 - val_loss: 0.3299 - val_mean_absolute_error: 0.4426 Epoch 17/100 249/249 [==============================] - 0s 2ms/step - loss: 0.3154 - mean_absolute_error: 0.4416 - val_loss: 0.3357 - val_mean_absolute_error: 0.4417 Epoch 18/100 249/249 [==============================] - 0s 2ms/step - loss: 0.3021 - mean_absolute_error: 0.4304 - val_loss: 0.2929 - val_mean_absolute_error: 0.4229 Epoch 19/100 249/249 [==============================] - 1s 2ms/step - loss: 0.2862 - mean_absolute_error: 0.4196 - val_loss: 0.2929 - val_mean_absolute_error: 0.4146 Epoch 20/100 249/249 [==============================] - 0s 2ms/step - loss: 0.2757 - mean_absolute_error: 0.4095 - val_loss: 0.2736 - val_mean_absolute_error: 0.4053 Epoch 21/100 249/249 [==============================] - 1s 2ms/step - loss: 0.2629 - mean_absolute_error: 0.3994 - val_loss: 0.2760 - val_mean_absolute_error: 0.3953 Epoch 22/100 249/249 [==============================] - 0s 2ms/step - loss: 0.2531 - mean_absolute_error: 0.3916 - val_loss: 0.2753 - val_mean_absolute_error: 0.3941 Epoch 23/100 249/249 [==============================] - 0s 2ms/step - loss: 0.2426 - mean_absolute_error: 0.3848 - val_loss: 0.2376 - val_mean_absolute_error: 0.3791 Epoch 24/100 249/249 [==============================] - 0s 2ms/step - loss: 0.2357 - mean_absolute_error: 0.3771 - val_loss: 0.2394 - val_mean_absolute_error: 0.3702 Epoch 25/100 249/249 [==============================] - 0s 2ms/step - loss: 0.2279 - mean_absolute_error: 0.3707 - val_loss: 0.2248 - val_mean_absolute_error: 0.3674 Epoch 26/100 249/249 [==============================] - 0s 2ms/step - loss: 0.2193 - mean_absolute_error: 0.3640 - val_loss: 0.2354 - val_mean_absolute_error: 0.3640 Epoch 27/100 249/249 [==============================] - 1s 2ms/step - loss: 0.2089 - mean_absolute_error: 0.3579 - val_loss: 0.2131 - val_mean_absolute_error: 0.3570 Epoch 28/100 249/249 [==============================] - 0s 2ms/step - loss: 0.2069 - mean_absolute_error: 0.3544 - val_loss: 0.2008 - val_mean_absolute_error: 0.3480 Epoch 29/100 249/249 [==============================] - 0s 2ms/step - loss: 0.2021 - mean_absolute_error: 0.3487 - val_loss: 0.1961 - val_mean_absolute_error: 0.3450 Epoch 30/100 249/249 [==============================] - 0s 2ms/step - loss: 0.1916 - mean_absolute_error: 0.3440 - val_loss: 0.1964 - val_mean_absolute_error: 0.3376 Epoch 31/100 249/249 [==============================] - 0s 2ms/step - loss: 0.1885 - mean_absolute_error: 0.3389 - val_loss: 0.1858 - val_mean_absolute_error: 0.3388 Epoch 32/100 249/249 [==============================] - 0s 2ms/step - loss: 0.1851 - mean_absolute_error: 0.3363 - val_loss: 0.1802 - val_mean_absolute_error: 0.3323 Epoch 33/100 249/249 [==============================] - 0s 2ms/step - loss: 0.1796 - mean_absolute_error: 0.3333 - val_loss: 0.1808 - val_mean_absolute_error: 0.3329 Epoch 34/100 249/249 [==============================] - 0s 2ms/step - loss: 0.1722 - mean_absolute_error: 0.3303 - val_loss: 0.1789 - val_mean_absolute_error: 0.3264 Epoch 35/100 249/249 [==============================] - 1s 2ms/step - loss: 0.1708 - mean_absolute_error: 0.3281 - val_loss: 0.1660 - val_mean_absolute_error: 0.3260 Epoch 36/100 249/249 [==============================] - 0s 2ms/step - loss: 0.1629 - mean_absolute_error: 0.3245 - val_loss: 0.1664 - val_mean_absolute_error: 0.3233 Epoch 37/100 249/249 [==============================] - 0s 2ms/step - loss: 0.1624 - mean_absolute_error: 0.3226 - val_loss: 0.1608 - val_mean_absolute_error: 0.3180 Epoch 38/100 249/249 [==============================] - 1s 3ms/step - loss: 0.1589 - mean_absolute_error: 0.3204 - val_loss: 0.1551 - val_mean_absolute_error: 0.3176 Epoch 39/100 249/249 [==============================] - 0s 2ms/step - loss: 0.1577 - mean_absolute_error: 0.3183 - val_loss: 0.1614 - val_mean_absolute_error: 0.3207 Epoch 40/100 249/249 [==============================] - 0s 2ms/step - loss: 0.1519 - mean_absolute_error: 0.3168 - val_loss: 0.1600 - val_mean_absolute_error: 0.3190 Epoch 41/100 249/249 [==============================] - 1s 2ms/step - loss: 0.1492 - mean_absolute_error: 0.3146 - val_loss: 0.1483 - val_mean_absolute_error: 0.3111 Epoch 42/100 249/249 [==============================] - 0s 2ms/step - loss: 0.1460 - mean_absolute_error: 0.3126 - val_loss: 0.1433 - val_mean_absolute_error: 0.3117 Epoch 43/100 249/249 [==============================] - 0s 2ms/step - loss: 0.1449 - mean_absolute_error: 0.3104 - val_loss: 0.1429 - val_mean_absolute_error: 0.3103 Epoch 44/100 249/249 [==============================] - 0s 2ms/step - loss: 0.1395 - mean_absolute_error: 0.3088 - val_loss: 0.1551 - val_mean_absolute_error: 0.3091 Epoch 45/100 249/249 [==============================] - 0s 2ms/step - loss: 0.1398 - mean_absolute_error: 0.3070 - val_loss: 0.1453 - val_mean_absolute_error: 0.3085 Epoch 46/100 249/249 [==============================] - 0s 2ms/step - loss: 0.1347 - mean_absolute_error: 0.3065 - val_loss: 0.1335 - val_mean_absolute_error: 0.3057 Epoch 47/100 249/249 [==============================] - 0s 2ms/step - loss: 0.1339 - mean_absolute_error: 0.3058 - val_loss: 0.1429 - val_mean_absolute_error: 0.3065 Epoch 48/100 249/249 [==============================] - 1s 2ms/step - loss: 0.1293 - mean_absolute_error: 0.3045 - val_loss: 0.1582 - val_mean_absolute_error: 0.3063 Epoch 49/100 249/249 [==============================] - 0s 2ms/step - loss: 0.1302 - mean_absolute_error: 0.3038 - val_loss: 0.1316 - val_mean_absolute_error: 0.2995 Epoch 50/100 249/249 [==============================] - 0s 2ms/step - loss: 0.1270 - mean_absolute_error: 0.3020 - val_loss: 0.1248 - val_mean_absolute_error: 0.3014 Epoch 51/100 249/249 [==============================] - 1s 2ms/step - loss: 0.1254 - mean_absolute_error: 0.3007 - val_loss: 0.1246 - val_mean_absolute_error: 0.2976 Epoch 52/100 249/249 [==============================] - 0s 2ms/step - loss: 0.1237 - mean_absolute_error: 0.3001 - val_loss: 0.1268 - val_mean_absolute_error: 0.2993 Epoch 53/100 249/249 [==============================] - 0s 2ms/step - loss: 0.1224 - mean_absolute_error: 0.2992 - val_loss: 0.1205 - val_mean_absolute_error: 0.2961 Epoch 54/100 249/249 [==============================] - 1s 2ms/step - loss: 0.1204 - mean_absolute_error: 0.2985 - val_loss: 0.1195 - val_mean_absolute_error: 0.2967 Epoch 55/100 249/249 [==============================] - 1s 2ms/step - loss: 0.1161 - mean_absolute_error: 0.2968 - val_loss: 0.1323 - val_mean_absolute_error: 0.2987 Epoch 56/100 249/249 [==============================] - 0s 2ms/step - loss: 0.1159 - mean_absolute_error: 0.2967 - val_loss: 0.1243 - val_mean_absolute_error: 0.2918 Epoch 57/100 249/249 [==============================] - 0s 2ms/step - loss: 0.1170 - mean_absolute_error: 0.2957 - val_loss: 0.1168 - val_mean_absolute_error: 0.2954 Epoch 58/100 249/249 [==============================] - 0s 2ms/step - loss: 0.1118 - mean_absolute_error: 0.2945 - val_loss: 0.1122 - val_mean_absolute_error: 0.2922 Epoch 59/100 249/249 [==============================] - 0s 2ms/step - loss: 0.1104 - mean_absolute_error: 0.2941 - val_loss: 0.1136 - val_mean_absolute_error: 0.2926 Epoch 60/100 249/249 [==============================] - 0s 2ms/step - loss: 0.1114 - mean_absolute_error: 0.2934 - val_loss: 0.1122 - val_mean_absolute_error: 0.2906 Epoch 61/100 249/249 [==============================] - 1s 2ms/step - loss: 0.1084 - mean_absolute_error: 0.2917 - val_loss: 0.1300 - val_mean_absolute_error: 0.2949 Epoch 62/100 249/249 [==============================] - 0s 2ms/step - loss: 0.1058 - mean_absolute_error: 0.2919 - val_loss: 0.1068 - val_mean_absolute_error: 0.2885 Epoch 63/100 249/249 [==============================] - 0s 2ms/step - loss: 0.1048 - mean_absolute_error: 0.2905 - val_loss: 0.1115 - val_mean_absolute_error: 0.2907 Epoch 64/100 249/249 [==============================] - 1s 2ms/step - loss: 0.1066 - mean_absolute_error: 0.2890 - val_loss: 0.1111 - val_mean_absolute_error: 0.2885 Epoch 65/100 249/249 [==============================] - 0s 2ms/step - loss: 0.1070 - mean_absolute_error: 0.2881 - val_loss: 0.1144 - val_mean_absolute_error: 0.2862 Epoch 66/100 249/249 [==============================] - 0s 2ms/step - loss: 0.1058 - mean_absolute_error: 0.2870 - val_loss: 0.1062 - val_mean_absolute_error: 0.2849 Epoch 67/100 249/249 [==============================] - 0s 2ms/step - loss: 0.1033 - mean_absolute_error: 0.2869 - val_loss: 0.1075 - val_mean_absolute_error: 0.2859 Epoch 68/100 249/249 [==============================] - 0s 2ms/step - loss: 0.1009 - mean_absolute_error: 0.2852 - val_loss: 0.0992 - val_mean_absolute_error: 0.2831 Epoch 69/100 249/249 [==============================] - 0s 2ms/step - loss: 0.0994 - mean_absolute_error: 0.2851 - val_loss: 0.1087 - val_mean_absolute_error: 0.2850 Epoch 70/100 249/249 [==============================] - 0s 2ms/step - loss: 0.0962 - mean_absolute_error: 0.2848 - val_loss: 0.0973 - val_mean_absolute_error: 0.2813 Epoch 71/100 249/249 [==============================] - 1s 3ms/step - loss: 0.0970 - mean_absolute_error: 0.2841 - val_loss: 0.1170 - val_mean_absolute_error: 0.2797 Epoch 72/100 249/249 [==============================] - 1s 2ms/step - loss: 0.0955 - mean_absolute_error: 0.2836 - val_loss: 0.1019 - val_mean_absolute_error: 0.2807 Epoch 73/100 249/249 [==============================] - 0s 2ms/step - loss: 0.0954 - mean_absolute_error: 0.2828 - val_loss: 0.1007 - val_mean_absolute_error: 0.2790 Epoch 74/100 249/249 [==============================] - 0s 2ms/step - loss: 0.0931 - mean_absolute_error: 0.2827 - val_loss: 0.0917 - val_mean_absolute_error: 0.2806 Epoch 75/100 249/249 [==============================] - 0s 2ms/step - loss: 0.0924 - mean_absolute_error: 0.2824 - val_loss: 0.0965 - val_mean_absolute_error: 0.2787 Epoch 76/100 249/249 [==============================] - 1s 2ms/step - loss: 0.0928 - mean_absolute_error: 0.2820 - val_loss: 0.1156 - val_mean_absolute_error: 0.2824 Epoch 77/100 249/249 [==============================] - 0s 2ms/step - loss: 0.0905 - mean_absolute_error: 0.2808 - val_loss: 0.1023 - val_mean_absolute_error: 0.2814 Epoch 78/100 249/249 [==============================] - 1s 2ms/step - loss: 0.0889 - mean_absolute_error: 0.2812 - val_loss: 0.0889 - val_mean_absolute_error: 0.2774 Epoch 79/100 249/249 [==============================] - 0s 2ms/step - loss: 0.0880 - mean_absolute_error: 0.2802 - val_loss: 0.0892 - val_mean_absolute_error: 0.2780 Epoch 80/100 249/249 [==============================] - 0s 2ms/step - loss: 0.0898 - mean_absolute_error: 0.2798 - val_loss: 0.1034 - val_mean_absolute_error: 0.2787 Epoch 81/100 249/249 [==============================] - 0s 2ms/step - loss: 0.0915 - mean_absolute_error: 0.2800 - val_loss: 0.0916 - val_mean_absolute_error: 0.2780 Epoch 82/100 249/249 [==============================] - 1s 2ms/step - loss: 0.0889 - mean_absolute_error: 0.2797 - val_loss: 0.0969 - val_mean_absolute_error: 0.2779 Epoch 83/100 249/249 [==============================] - 0s 2ms/step - loss: 0.0841 - mean_absolute_error: 0.2784 - val_loss: 0.0923 - val_mean_absolute_error: 0.2779 Epoch 84/100 249/249 [==============================] - 0s 2ms/step - loss: 0.0842 - mean_absolute_error: 0.2782 - val_loss: 0.0923 - val_mean_absolute_error: 0.2762 Epoch 85/100 249/249 [==============================] - 1s 2ms/step - loss: 0.0851 - mean_absolute_error: 0.2783 - val_loss: 0.0852 - val_mean_absolute_error: 0.2763 Epoch 86/100 249/249 [==============================] - 0s 2ms/step - loss: 0.0833 - mean_absolute_error: 0.2779 - val_loss: 0.0912 - val_mean_absolute_error: 0.2762 Epoch 87/100 249/249 [==============================] - 0s 2ms/step - loss: 0.0815 - mean_absolute_error: 0.2769 - val_loss: 0.0947 - val_mean_absolute_error: 0.2749 Epoch 88/100 249/249 [==============================] - 0s 2ms/step - loss: 0.0846 - mean_absolute_error: 0.2762 - val_loss: 0.0921 - val_mean_absolute_error: 0.2733 Epoch 89/100 249/249 [==============================] - 0s 2ms/step - loss: 0.0798 - mean_absolute_error: 0.2758 - val_loss: 0.0836 - val_mean_absolute_error: 0.2742 Epoch 90/100 249/249 [==============================] - 0s 2ms/step - loss: 0.0797 - mean_absolute_error: 0.2752 - val_loss: 0.0820 - val_mean_absolute_error: 0.2739 Epoch 91/100 249/249 [==============================] - 0s 2ms/step - loss: 0.0796 - mean_absolute_error: 0.2753 - val_loss: 0.0805 - val_mean_absolute_error: 0.2736 Epoch 92/100 249/249 [==============================] - 0s 2ms/step - loss: 0.0802 - mean_absolute_error: 0.2752 - val_loss: 0.0806 - val_mean_absolute_error: 0.2716 Epoch 93/100 249/249 [==============================] - 0s 2ms/step - loss: 0.0802 - mean_absolute_error: 0.2744 - val_loss: 0.0779 - val_mean_absolute_error: 0.2721 Epoch 94/100 249/249 [==============================] - 0s 2ms/step - loss: 0.0818 - mean_absolute_error: 0.2745 - val_loss: 0.0827 - val_mean_absolute_error: 0.2721 Epoch 95/100 249/249 [==============================] - 1s 2ms/step - loss: 0.0843 - mean_absolute_error: 0.2748 - val_loss: 0.0791 - val_mean_absolute_error: 0.2720 Epoch 96/100 249/249 [==============================] - 0s 2ms/step - loss: 0.0771 - mean_absolute_error: 0.2740 - val_loss: 0.0778 - val_mean_absolute_error: 0.2706 Epoch 97/100 249/249 [==============================] - 1s 2ms/step - loss: 0.0810 - mean_absolute_error: 0.2735 - val_loss: 0.0923 - val_mean_absolute_error: 0.2728 Epoch 98/100 249/249 [==============================] - 0s 2ms/step - loss: 0.0744 - mean_absolute_error: 0.2732 - val_loss: 0.0820 - val_mean_absolute_error: 0.2712 Epoch 99/100 249/249 [==============================] - 0s 2ms/step - loss: 0.0750 - mean_absolute_error: 0.2728 - val_loss: 0.0775 - val_mean_absolute_error: 0.2699 Epoch 100/100 249/249 [==============================] - 0s 2ms/step - loss: 0.0776 - mean_absolute_error: 0.2729 - val_loss: 0.0769 - val_mean_absolute_error: 0.2707 32/32 [==============================] - 0s 1ms/step - loss: 0.0769 - mean_absolute_error: 0.2707 Accuracy: [0.07687386870384216, 0.2706981301307678]
from tensorflow.keras.utils import plot_model
model.summary()
plot_model(model, to_file='mlp-mnist.png', show_shapes=True)
Model: "sequential_147"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
flatten_147 (Flatten) (None, 8) 0
dense_294 (Dense) (None, 32) 288
dense_295 (Dense) (None, 3) 99
=================================================================
Total params: 387 (1.51 KB)
Trainable params: 387 (1.51 KB)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________
You must install pydot (`pip install pydot`) and install graphviz (see instructions at https://graphviz.gitlab.io/download/) for plot_model to work.
import joblib
joblib.dump(model, 'model/ann_model.pkl')
INFO:tensorflow:Assets written to: C:\Users\Asus\AppData\Local\Temp\tmpfg28j20q\assets
INFO:tensorflow:Assets written to: C:\Users\Asus\AppData\Local\Temp\tmpfg28j20q\assets
['model/ann_model.pkl']
predictions = model.predict(X_test)
y_pred = np.argmax(predictions, axis=1)
32/32 [==============================] - 0s 1ms/step
plt.figure(figsize=(12, 5))
plt.subplot(1, 2, 1)
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.title('Loss over epochs')
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.legend(['Train', 'Validation'], loc='best')
plt.subplot(1, 2, 2)
plt.plot(history.history['mean_absolute_error'])
plt.plot(history.history['val_mean_absolute_error'])
plt.title('Accuracy over epochs')
plt.ylabel('Accuracy')
plt.xlabel('Epoch')
plt.legend(['Train', 'Validation'], loc='best')
plt.tight_layout()
plt.show()
cm = confusion_matrix(y_test, y_pred)
disp = ConfusionMatrixDisplay(confusion_matrix=cm)
disp.plot(cmap='Blues', values_format='d')
plt.title("Confusion Matrix: RNN")
plt.show()
print(classification_report(y_test, y_pred))
precision recall f1-score support
0 0.99 1.00 0.99 327
1 0.99 0.98 0.99 352
2 0.99 0.99 0.99 317
accuracy 0.99 996
macro avg 0.99 0.99 0.99 996
weighted avg 0.99 0.99 0.99 996