In [None]:
import math

import tensorflow as tf
import tensorflow.keras.layers as layers
import tensorflow_datasets as tfds
from tensorflow.keras.callbacks import LearningRateScheduler
from tensorflow.keras.initializers import Constant
from tensorflow.keras.models import Model

In [None]:
# Notebook-wide configuration.
BATCH_SIZE = 128
SAVED_MODEL_DIR = './saved_model'
SEED = 42

# Seed TensorFlow's global RNG so TF-driven randomness (weight init, dropout, dataset
# shuffling) is repeatable across Restart & Run All. Some GPU kernels may still be
# nondeterministic, so bit-exact reruns are not guaranteed.
tf.random.set_seed(SEED)

In [None]:
# Load MNIST from TFDS: the 'train' split is used for fitting and 'test' for validation.
# as_supervised=True yields (image, label) pairs; with_info=True also returns dataset metadata.
(ds_train_data, ds_val_data), info = tfds.load(
    'mnist',
    split=['train', 'test'],
    as_supervised=True,
    with_info=True,
)

# Width of the softmax output layer, taken from the metadata instead of being hard-coded.
num_classes = info.features['label'].num_classes

In [None]:
def preprocess(image, label):
    """Convert an image to float32 scaled by 1/255 and pass the label through unchanged.

    Presumably the TFDS MNIST pixels are 8-bit (0..255), so the result lies in [0, 1].

    Args:
        image: image tensor from the dataset.
        label: class label; returned as-is.

    Returns:
        A ``(scaled_image, label)`` tuple.
    """
    scaled = tf.cast(image, tf.float32) / 255.0
    return scaled, label

# Let tf.data choose parallelism / prefetch depth at runtime.
AUTOTUNE = tf.data.experimental.AUTOTUNE

# Training stream: normalise, cache the normalised examples, shuffle with a buffer as
# large as the whole training split (i.e. a full shuffle), then batch and prefetch.
ds_train = ds_train_data.map(preprocess, num_parallel_calls=AUTOTUNE)
ds_train = ds_train.cache()
ds_train = ds_train.shuffle(info.splits['train'].num_examples)
ds_train = ds_train.batch(BATCH_SIZE).prefetch(AUTOTUNE)

# Validation stream: no shuffling; batches are cached after the first pass.
ds_val = ds_val_data.map(preprocess, num_parallel_calls=AUTOTUNE)
ds_val = ds_val.batch(BATCH_SIZE).cache().prefetch(AUTOTUNE)

In [None]:
DROPOUT_RATE = 0.25
BN_BETA_INIT = 0.01


def _bn_relu_dropout(x):
    """Apply BatchNorm -> ReLU -> Dropout, the tail shared by every hidden layer below.

    ``scale=False``: per the Keras BatchNormalization docs, the learned scale can be
    dropped when the next layer (here ReLU followed by a linear layer) can absorb it.
    beta starts at a small positive value, presumably to keep ReLU units active at init.

    Args:
        x: output tensor of a Conv2D or Dense layer.

    Returns:
        The tensor after normalisation, activation and dropout.
    """
    x = layers.BatchNormalization(scale=False, beta_initializer=Constant(BN_BETA_INIT))(x)
    x = layers.Activation('relu')(x)
    return layers.Dropout(rate=DROPOUT_RATE)(x)


inputs = layers.Input(shape=(28, 28, 1), name='input')

# Convolutional feature extractor, one (filters, kernel size, stride) tuple per stage.
# use_bias=False: the BatchNormalization that follows subtracts the batch mean, which
# cancels any per-channel bias, and its beta already supplies the learned offset.
x = inputs
for conv_filters, conv_kernel, conv_stride in [(24, 6, 1), (48, 5, 2), (64, 4, 2)]:
    x = layers.Conv2D(
        conv_filters,
        kernel_size=(conv_kernel, conv_kernel),
        strides=conv_stride,
        use_bias=False,
    )(x)
    x = _bn_relu_dropout(x)

# Fully connected head.
x = layers.Flatten()(x)
x = layers.Dense(200, use_bias=False)(x)
x = _bn_relu_dropout(x)

predictions = layers.Dense(num_classes, activation='softmax', name='output')(x)

model = Model(inputs=inputs, outputs=predictions)
# Integer labels (as_supervised MNIST) -> sparse categorical cross-entropy.
model.compile(loss='sparse_categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

In [None]:
# Print a layer-by-layer summary of the model built in the previous cell.
model.summary()

In [None]:
def lr_decay(epoch):
    """Learning-rate schedule for ``LearningRateScheduler``.

    Returns ~0.0201 at epoch 0. The part above the 0.0001 floor shrinks by a factor
    of e every 3 epochs.

    Args:
        epoch: zero-based epoch index supplied by the callback.

    Returns:
        The learning rate to use for that epoch.
    """
    floor_lr, initial_boost, decay_epochs = 0.0001, 0.02, 3.0
    return floor_lr + initial_boost * math.pow(1.0 / math.e, epoch / decay_epochs)
EPOCHS = 20

# Apply the lr_decay schedule at the start of every epoch and log the chosen rate.
decay_callback = LearningRateScheduler(lr_decay, verbose=1)

# Keep the returned History so per-epoch loss/accuracy (history.history) can be
# inspected or plotted later; assigning it also stops the cell echoing the object repr.
history = model.fit(
    ds_train,
    epochs=EPOCHS,
    validation_data=ds_val,
    callbacks=[decay_callback],
    verbose=1,
)

In [None]:
# Export the trained model as a TensorFlow SavedModel under SAVED_MODEL_DIR;
# the TFLite conversion in the next cell reads it back from that directory.
tf.saved_model.save(model, SAVED_MODEL_DIR)

In [None]:
# Opt-in float16 post-training quantization (see the TFLite float16 quantization guide).
# Left False so the exported model stays full float32, as before.
QUANTIZE_FLOAT16 = False

# Convert the SavedModel exported above into a TFLite flatbuffer.
converter = tf.lite.TFLiteConverter.from_saved_model(SAVED_MODEL_DIR)
if QUANTIZE_FLOAT16:
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    converter.target_spec.supported_types = [tf.float16]
tflite_model = converter.convert()

# convert() returns raw bytes, so write in binary mode.
with open('mnist.tflite', 'wb') as f:
    f.write(tflite_model)

In [None]:
# Offer the .tflite file as a browser download when running in Google Colab.
# Outside Colab the google.colab import fails and the download is skipped. Only
# ImportError is caught, so real download errors (and Ctrl-C) are no longer swallowed.
try:
    from google.colab import files
except ImportError:
    print("Skip downloading")
else:
    files.download('mnist.tflite')