Classifying a Fashion Dataset

January 21, 2019

In this notebook, the Fashion-MNIST dataset is classified with a convolutional neural network built in Keras.


Imports

In [2]:
import os
import warnings

import numpy as np
import pandas as pd

import tensorflow as tf
from tensorflow.python.keras.preprocessing.image import ImageDataGenerator
from tensorflow.python.keras.models import Sequential, load_model, Model
from tensorflow.python.keras.layers import Conv2D, MaxPooling2D, Input, Flatten, Dense, Dropout, Concatenate, BatchNormalization, Activation
from tensorflow.python.keras.applications.vgg16 import VGG16
from tensorflow.python.keras import optimizers

import plotly.offline as py
import plotly.graph_objs as go
py.init_notebook_mode(connected=True)

import matplotlib.pyplot as plt
import matplotlib.cm as cm

warnings.filterwarnings('ignore')

Constants

In [3]:
# Image geometry: every sample is a SIZE x SIZE picture stored as a flat row
# of SIZE * SIZE pixel columns.
SIZE = 28
CHANNELS = 1  # presumably the number of colour channels (grayscale) -- confirm where it is used

# Number of distinct clothing categories in the dataset.
LABELS = 10

Helper function for plotting image grids

In [4]:
def digits(data, labels=None, random=False, xy=None):
    """Show a grid of SIZE x SIZE grayscale images with matplotlib.

    Parameters
    ----------
    data : array-like
        Pixel values that can be reshaped to ``(-1, SIZE, SIZE)``, for
        example an ``(n, SIZE * SIZE)`` array of flattened images.
    labels : sequence, optional
        One label per image, addressed by position. When given, each plotted
        image gets a ``"Label: ..."`` title.
    random : bool, default False
        If True, every grid cell shows an image drawn at random (with
        replacement) instead of the first images in order.
    xy : tuple of int, optional
        ``(columns, rows)`` of the grid. Defaults to 20 columns and as many
        rows as needed for the images, capped at 10.
    """
    # Reshape once instead of on every loop iteration.
    images = np.asarray(data).reshape(-1, SIZE, SIZE)
    count = len(images)

    # ``if labels:`` is ambiguous for NumPy arrays / pandas Series and raises;
    # converting to an array also makes ``labels[i]`` positional for a Series
    # whose index is not 0..n-1.
    if labels is not None:
        labels = np.asarray(labels)

    if xy is None:
        x, y = 20, min(10, count // 20 + 1)
    else:
        x, y = xy

    # squeeze=False always yields a 2-D array of Axes, so 1x1, 1xN and Nx1
    # grids are handled by the same code path as full grids.
    fig, ax = plt.subplots(y, x, figsize=(x, y), squeeze=False)
    cells = ax.ravel()  # row-major order: left to right, top to bottom

    for i, cell in enumerate(cells[:count]):
        if random:
            i = np.random.randint(0, count)
        cell.matshow(images[i], cmap=cm.gray_r)
        if labels is not None:
            cell.set_title("Label: {}".format(labels[i]))

    # Hide the axes on every cell, including the ones left empty.
    for cell in cells:
        cell.axis('off')

    plt.subplots_adjust(wspace=0, hspace=0)
    plt.show()

Data Import

In [5]:
# Directory holding the Fashion-MNIST CSV files. Set the FMNIST_DIR environment
# variable to run the notebook on another machine; the default is the path the
# notebook was originally written against.
DATA_DIR = os.environ.get('FMNIST_DIR', '/Users/desiredewaele/drive/AI/Databases/FMnist')

# One row per image: a 'label' column followed by the pixel columns.
# (Pass nrows=... to read_csv for a quick experiment on a subset.)
train = pd.read_csv(os.path.join(DATA_DIR, 'fashion-mnist_train.csv'))

# Split the target off; `train` keeps only the pixel columns afterwards.
trainY = train.pop('label')
In [6]:
# Sanity check: print the shapes of the pixel table and the label vector so
# their row counts can be compared.
print(train.shape, trainY.shape)
(60000, 784) (60000,)
In [9]:
# Human-readable class names, keyed by the integer label used in the CSV.
labelDict = dict(enumerate([
    "T-shirt/top",
    "Trouser",
    "Pullover",
    "Dress",
    "Coat",
    "Sandal",
    "Shirt",
    "Sneaker",
    "Bag",
    "Ankle boot",
]))

# Preview each class: print its name, then plot one row of 15 example images.
pixels = train.values
for label in np.unique(trainY):
    print("Label", labelDict[label])
    belongs_to_class = (trainY == label).values
    digits(pixels[belongs_to_class], xy=(15, 1))
Label T-shirt/top
Label Trouser
Label Pullover
Label Dress
Label Coat
Label Sandal
Label Shirt
Label Sneaker
Label Bag
Label Ankle boot