#!/usr/bin/env python3
# cnn_classification.py
#
import sys
import argparse
import logging
import os.path
import joblib
import tensorflow as tf
import keras
import open_data
import model_creation
################################################################
#
# CNN functions
#
def do_cnn_fit(my_args):
    """
    Build a brand-new model and train it on one batch of data.

    Reads ``my_args.batch_number`` to pick the batch loaded through
    ``open_data.load_batch`` and passes ``my_args`` plus ``X.shape[1:]``
    (every axis of the loaded data after the first) to
    ``model_creation.create_model``.  After fitting, the model is written
    with joblib to ``my_args.model_file`` and the ``history`` attribute of
    the fit result to ``<model_file>.history``.
    """
    features, labels = open_data.load_batch(my_args.batch_number)
    input_shape = features.shape[1:]
    network = model_creation.create_model(my_args, input_shape)

    # Stop once validation loss has not improved for 5 epochs and roll the
    # weights back to the best epoch seen.
    stop_on_plateau = keras.callbacks.EarlyStopping(
        monitor='val_loss',
        patience=5,
        restore_best_weights=True,
    )
    fit_result = network.fit(
        features,
        labels,
        batch_size=1,
        epochs=10,
        shuffle=True,
        validation_split=0.2,  # hold back 20% of the batch for validation
        callbacks=[stop_on_plateau],
        verbose=1,
    )

    # Persist the model and, next to it, the per-epoch training history.
    destination = my_args.model_file
    joblib.dump(network, destination)
    joblib.dump(fit_result.history, f"{destination}.history")
    return
def do_cnn_refit(my_args):
    """
    Continue training a previously saved model on one batch of data.

    Loads the batch selected by ``my_args.batch_number``, restores the model
    stored at ``my_args.model_file`` with joblib, fits it again with the same
    settings used for a fresh fit, then overwrites both the model file and
    its companion ``<model_file>.history`` file.
    """
    features, labels = open_data.load_batch(my_args.batch_number)

    saved_path = my_args.model_file
    network = joblib.load(saved_path)

    # Stop once validation loss has not improved for 5 epochs and roll the
    # weights back to the best epoch seen.
    stop_on_plateau = keras.callbacks.EarlyStopping(
        monitor='val_loss',
        patience=5,
        restore_best_weights=True,
    )
    fit_result = network.fit(
        features,
        labels,
        batch_size=1,
        epochs=10,
        shuffle=True,
        validation_split=0.2,  # hold back 20% of the batch for validation
        callbacks=[stop_on_plateau],
        verbose=1,
    )

    # Overwrite the stored model; the history file holds only this run.
    joblib.dump(network, saved_path)
    joblib.dump(fit_result.history, f"{saved_path}.history")
    return
#
# CNN functions
#
################################################################
################################################################
#
# Evaluate existing models functions
#
import model_evaluation
import model_history
#
# Evaluate existing models functions
#
################################################################
def parse_args(argv):
    """
    Parse the command line for the CNN image-classification script.

    argv is the full argument vector; argv[0] is used as the program name
    shown in usage/help output and the remaining items are parsed.

    Returns an argparse.Namespace with these attributes:
      action        one of "cnn-fit", "score", "learning-curve", "cnn-refit"
                    (default "cnn-fit")
      batch_number  int, which training batch to use (default 1)
      model_file    str, file the model is stored in (default "model.joblib")
      model_name    str, name of the model create function (default "v")

    argparse exits the process (SystemExit) on invalid arguments or --help.
    """
    parser = argparse.ArgumentParser(
        prog=argv[0],
        description='Image Classification with CNN',
    )
    # The action is optional (nargs='?') so running with no arguments fits
    # a new model.
    parser.add_argument(
        'action',
        default='cnn-fit',
        choices=["cnn-fit", "score", "learning-curve", "cnn-refit"],
        nargs='?',
        help="desired action (default=cnn-fit)",
    )
    parser.add_argument(
        '--batch-number', '-b',
        default=1,
        type=int,
        help="which training batch to use (default=1)",
    )
    # The help text used to claim the default was derived from the train
    # file name; the default is in fact the fixed name below.
    parser.add_argument(
        '--model-file', '-m',
        default="model.joblib",
        type=str,
        help="name of file for the model (default=model.joblib)",
    )
    parser.add_argument(
        '--model-name', '-M',
        default="v",
        type=str,
        help="name of model create function (default=v)",
    )

    my_args = parser.parse_args(argv[1:])

    #
    # Do any special fixes/checks here
    #

    return my_args
def main(argv):
    """
    Script entry point: parse *argv* and run the requested action.

    argv is the full argument vector (program name first).  Logging is
    configured at WARNING level after the arguments have been parsed.
    Raises Exception if the parsed action has no handler.
    """
    options = parse_args(argv)

    # Switch the level to logging.INFO for more verbose output.
    logging.basicConfig(level=logging.WARNING)

    # Each handler is wrapped in a lambda so that the helper-module
    # attributes are only looked up for the action actually selected.
    handlers = {
        'cnn-fit': lambda: do_cnn_fit(options),
        'cnn-refit': lambda: do_cnn_refit(options),
        'score': lambda: model_evaluation.show_score(options),
        'learning-curve': lambda: model_history.plot_history(options),
    }

    run_action = handlers.get(options.action)
    if run_action is None:
        raise Exception("Action: {} is not known.".format(options.action))

    run_action()
    return
# Run the command-line interface only when executed as a script, not when
# this module is imported.
if "__main__" == __name__:
    main(argv=sys.argv)
# Last Updated 03/17/2025