DEPARTMENT OF COMPUTING

open_data.py [download]


#!/usr/bin/env python3

#
# https://keras.io/api/datasets/fashion_mnist/
#
# 60,000 28x28 grayscale training images
# 10,000 28x28 grayscale testing images
#
# I chose to split the 60,000 training images into 6 batches of 10,000,
# using batches 1-5 to train and batch 6 to validate.
#
#


import numpy as np
import keras

def load_batch_from_keras(number):
    """
    Load one slice of Fashion-MNIST through keras and preprocess it.

    The training set is shuffled with a fixed seed, so the split is the
    same on every call, and is then cut into consecutive batches of
    10,000 images.

    number in [1, 2, 3, 4, 5] -> load training batch
    number == 6 -> load validation batch
    number < 1 -> test batch
    number > 6 -> load training batches 1,2,3,4,5

    Returns a tuple (images, labels):
        images -- float32 array reshaped to (-1, 28, 28, 1), scaled to [0,1]
        labels -- one-hot encoded labels with 10 classes
    """
    (x_train, y_train), (x_test, y_test) = keras.datasets.fashion_mnist.load_data()

    # shuffle training data (always the same seed).
    # A private RandomState is used instead of np.random.seed(42): it yields
    # the same permutation, but does not reset the caller's global numpy
    # random state every time a batch is loaded.
    rng = np.random.RandomState(42)
    p = rng.permutation(x_train.shape[0])
    x_train = x_train[p]
    y_train = y_train[p]

    # select subset
    if number < 1:
        images, labels = x_test, y_test
    elif number > 6:
        # all five training batches at once (validation batch 6 excluded)
        start = 0 * 10000
        end = 5 * 10000
        images, labels = x_train[start:end], y_train[start:end]
    else:
        # a single batch; batch numbers are 1-based
        start = (number - 1) * 10000
        end = (number) * 10000
        images, labels = x_train[start:end], y_train[start:end]

    # one-hot-encode labels
    labels = keras.utils.to_categorical(labels, num_classes=10)
    # reshape the (28,28) images to (28,28,1): Conv2D expects a trailing
    # channel axis, i.e. (height, width, channels)
    images = images.reshape(-1, 28, 28, 1)
    # scale images to [0,1] range
    images = images.astype(np.float32) / 255.0

    return images, labels

def load_batch(number):
    """
    Return the (images, labels) pair for the requested batch.

    Thin front-end that forwards `number` unchanged to
    load_batch_from_keras; see that function for the meaning of `number`.
    """
    batch = load_batch_from_keras(number)
    return batch

Last Updated 03/17/2025