"""This example builds a convolutional neural network for classifying
Fashion-MNIST images."""
import gzip
import numpy as np
import matplotlib.pyplot as plt
from tensorflow import keras
from tensorflow.keras import layers
# ----------------------------------------------------
# 1. FUNCTIONS TO LOAD IMAGES AND LABELS (YOUR VERSION)
# ----------------------------------------------------
def load_images(path):
    """Load a gzip-compressed IDX3 image file as a uint8 array.

    The IDX3 header is four big-endian uint32 values: the magic number
    (0x00000803 == 2051), the image count, the row count and the column
    count. The header is validated and the pixel payload is reshaped using
    the dimensions it declares, rather than assuming 28x28 images.

    Args:
        path: Path to a ``.gz`` IDX3 file, e.g. ``train-images-idx3-ubyte.gz``.

    Returns:
        A read-only ``np.ndarray`` of dtype ``uint8`` with shape
        ``(count, rows, cols)`` (it is a view over the decompressed bytes).

    Raises:
        ValueError: If the file is too short, the magic number is not the
            IDX3 unsigned-byte magic, or the payload size disagrees with
            the header.
    """
    with gzip.open(path, 'rb') as f:
        raw = f.read()
    if len(raw) < 16:
        raise ValueError(f"{path}: too short to be an IDX3 image file")
    # Header fields are big-endian uint32; .tolist() yields Python ints so
    # the size product below cannot overflow.
    magic, count, rows, cols = np.frombuffer(raw, dtype='>u4', count=4).tolist()
    if magic != 2051:
        raise ValueError(f"{path}: expected IDX3 magic number 2051, got {magic}")
    data = np.frombuffer(raw, np.uint8, offset=16)
    # Reject truncated/corrupt payloads instead of silently reshaping them.
    if data.size != count * rows * cols:
        raise ValueError(
            f"{path}: header declares {count}x{rows}x{cols} pixels, "
            f"found {data.size}"
        )
    return data.reshape(count, rows, cols)
def load_labels(path):
    """Load a gzip-compressed IDX1 label file as a uint8 array.

    The IDX1 header is two big-endian uint32 values: the magic number
    (0x00000801 == 2049) and the label count. Both are validated against
    the payload.

    Args:
        path: Path to a ``.gz`` IDX1 file, e.g. ``train-labels-idx1-ubyte.gz``.

    Returns:
        A read-only 1-D ``np.ndarray`` of dtype ``uint8`` with one entry per
        label (it is a view over the decompressed bytes).

    Raises:
        ValueError: If the file is too short, the magic number is not the
            IDX1 unsigned-byte magic, or the label count disagrees with the
            header.
    """
    with gzip.open(path, 'rb') as f:
        raw = f.read()
    if len(raw) < 8:
        raise ValueError(f"{path}: too short to be an IDX1 label file")
    magic, count = np.frombuffer(raw, dtype='>u4', count=2).tolist()
    if magic != 2049:
        raise ValueError(f"{path}: expected IDX1 magic number 2049, got {magic}")
    labels = np.frombuffer(raw, np.uint8, offset=8)
    # A truncated file would otherwise yield fewer labels than images.
    if labels.size != count:
        raise ValueError(
            f"{path}: header declares {count} labels, found {labels.size}"
        )
    return labels
# ----------------------------------------------------
# 2. LOAD TRAINING AND TEST DATA
# ----------------------------------------------------
# "train-*" files hold the training split and "t10k-*" files the test split.
# Paths are relative to the current working directory.
x_train, y_train = (
    load_images('data/train-images-idx3-ubyte.gz'),
    load_labels('data/train-labels-idx1-ubyte.gz'),
)
x_test, y_test = (
    load_images('data/t10k-images-idx3-ubyte.gz'),
    load_labels('data/t10k-labels-idx1-ubyte.gz'),
)

# Quick sanity check: the image and label counts of each split should match.
print("Train:", x_train.shape, y_train.shape)
print("Test: ", x_test.shape, y_test.shape)
# ----------------------------------------------------
# 3. PREPARE DATA FOR CNN
# CNN needs shape: (batch, height, width, channels)
# ----------------------------------------------------
# Append a single-channel axis and rescale uint8 pixels from [0, 255] to
# float32 values in [0, 1]. Both splits get the identical transform.
x_train_cnn, x_test_cnn = (
    split.reshape(-1, 28, 28, 1).astype("float32") / 255.0
    for split in (x_train, x_test)
)
# ----------------------------------------------------
# 4. BUILD THE CNN MODEL
# ----------------------------------------------------
# Three 3x3 convolutions (32 -> 64 -> 64 filters) with two 2x2 max-pools,
# followed by a 128-unit dense layer with dropout and a 10-way softmax.
# The input spec is declared with keras.Input instead of passing
# `input_shape=` to the first Conv2D: current Keras warns about the latter
# and recommends an explicit Input as the first entry of a Sequential model.
model = keras.Sequential([
    keras.Input(shape=(28, 28, 1)),
    # 1st convolution block: 28x28x1 -> 26x26x32 -> 13x13x32
    layers.Conv2D(32, (3, 3), activation='relu'),
    layers.MaxPooling2D((2, 2)),
    # 2nd convolution block: 13x13x32 -> 11x11x64 -> 5x5x64
    layers.Conv2D(64, (3, 3), activation='relu'),
    layers.MaxPooling2D((2, 2)),
    # 3rd convolution layer: 5x5x64 -> 3x3x64
    layers.Conv2D(64, (3, 3), activation='relu'),
    # Flatten to feed Dense layers: 3*3*64 = 576 features
    layers.Flatten(),
    # Dense classifier
    layers.Dense(128, activation='relu'),
    # Randomly zero 20% of activations during training to reduce overfitting.
    layers.Dropout(0.2),
    # Output layer: 10 classes, softmax probabilities
    layers.Dense(10, activation='softmax'),
])
# ----------------------------------------------------
# 5. COMPILE THE MODEL
# ----------------------------------------------------
# Labels are plain integer class ids (uint8 from load_labels), not one-hot
# vectors, which is why the "sparse" cross-entropy variant is used.
model.compile(loss='sparse_categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
# ----------------------------------------------------
# 6. PRINT WEIGHTS BEFORE TRAINING
# ----------------------------------------------------
# Dump each layer's freshly initialized parameter arrays. Parameter-free
# layers (pooling, flatten, dropout) print only their name and the separator.
print("\n=== Weights BEFORE training ===")
for lyr in model.layers:
    params = lyr.get_weights()
    print(f"Layer: {lyr.name}")
    for arr in params:
        print(arr)
    print("-" * 40)
# ----------------------------------------------------
# 7. TRAIN THE MODEL
# ----------------------------------------------------
# 10% of the training arrays are held out for validation after every epoch;
# the per-epoch metrics are kept in `history` for the report and plot below.
history = model.fit(
    x_train_cnn,
    y_train,
    batch_size=32,
    epochs=10,
    verbose=1,
    validation_split=0.1,
)
# ----------------------------------------------------
# 8. EVALUATE ON TEST DATA
# ----------------------------------------------------
# evaluate() returns [loss, accuracy] because accuracy is the only metric
# configured at compile time.
test_loss, test_acc = model.evaluate(x=x_test_cnn, y=y_test, verbose=2)
print("\n✅ Test accuracy: " + format(test_acc, ".3f"))
# ----------------------------------------------------
# 9. PRINT WEIGHTS AFTER TRAINING
# ----------------------------------------------------
# Same dump as before training, so the two outputs can be compared.
# NumPy abbreviates large arrays with "..." under its default print options.
print("\n=== Weights AFTER training ===")
for lyr in model.layers:
    params = lyr.get_weights()
    print(f"Layer: {lyr.name}")
    for arr in params:
        print(arr)
    print("-" * 40)
# ----------------------------------------------------
# 10. SAVE MODEL
# ----------------------------------------------------
# A single name for the output file keeps the save call and the message in sync.
model_path = "fashion_mnist_cnn_model.keras"
model.save(model_path)
print(f"\nModel saved as {model_path}")
# ----------------------------------------------------
# 11. PRINT ACCURACY PROGRESS
# ----------------------------------------------------
# history.history maps each metric name to its list of per-epoch values.
metrics_log = history.history
print("\nTraining accuracy:\n", metrics_log['accuracy'])
print("\nValidation accuracy:\n", metrics_log['val_accuracy'])
# ----------------------------------------------------
# 12. PLOT TRAINING AND VALIDATION ACCURACY
# ----------------------------------------------------
# Plot against explicit epoch numbers 1..N. Plotting the bare lists would
# put the first epoch at x=0, while the Keras progress log (verbose=1 in
# fit above) labels it "Epoch 1".
train_acc = history.history['accuracy']
val_acc = history.history['val_accuracy']
epoch_numbers = list(range(1, len(train_acc) + 1))

plt.figure(figsize=(10, 4))
plt.plot(epoch_numbers, train_acc, label='Training Accuracy')
plt.plot(epoch_numbers, val_acc, label='Validation Accuracy')
plt.title("CNN Accuracy over Epochs")
plt.xlabel("Epoch")
plt.ylabel("Accuracy")
plt.legend()
plt.show()