DEPARTMENT OF COMPUTING

model_history.py [download]


#!/usr/bin/env python3

#
# Display the model history.
#

import joblib
import pandas as pd
import matplotlib.pyplot as plt

def _learning_curve_styles(line_count):
    """
    Return the matplotlib style strings for a learning-curve plot.

    The first half of the returned styles are red dashed lines and the second
    half are blue solid lines. Because pandas plots columns in the mapping's
    order, this presumably pairs training series (red) with validation series
    (blue) when the history lists every training metric before the validation
    ones (e.g. Keras-style ordering) -- confirm against the producer.

    Raises ValueError if line_count is not 2, 4, 6, 8 or 10.
    """
    styles_by_count = {
        2: ["r--+",
            "b-+"],
        4: ["r--*", "r--+",
            "b-*", "b-+"],
        6: ["r--", "r--*", "r--+",
            "b-", "b-*", "b-+"],
        8: ["r--", "r--*", "r--+", "r--x",
            "b-", "b-*", "b-+", "b-x"],
        10: ["r--", "r--*", "r--+", "r--x", "r--1",
             "b-", "b-*", "b-+", "b-x", "b-1"],
    }
    styles = styles_by_count.get(line_count)
    if styles is None:
        # ValueError is a subclass of Exception, so callers that caught the
        # original generic Exception keep working.
        raise ValueError("Invalid line count: {}".format(line_count))
    return list(styles)


def plot_history(my_args):
    """
    Plot the history of the model training.

    Loads "<my_args.model_file>.history" with joblib and writes the learning
    curves to "<my_args.model_file>.learning_curve.png".

    Assumes the loaded history maps each metric name to a list of per-epoch
    values, that it contains a "loss" entry, and that there are equal numbers
    of training and validation series.

    Raises ValueError if the history does not hold 2, 4, 6, 8 or 10 series.
    """
    history = joblib.load("{}.history".format(my_args.model_file))
    epochs = len(history["loss"])
    learning_curve_filename = "{}.learning_curve.png".format(my_args.model_file)

    #
    # Display the learning curves
    #
    line_style = _learning_curve_styles(len(history))

    # With fewer than two epochs, [0, epochs-1] is degenerate: identical limits
    # for one epoch, inverted for zero. Let matplotlib autoscale instead.
    xlim = [0, epochs - 1] if epochs > 1 else None

    ax = pd.DataFrame(history).plot(
        figsize=(8, 5), xlim=xlim, grid=True, xlabel="Epoch",
        style=line_style)
    fig = ax.get_figure()
    try:
        fig.savefig(learning_curve_filename)
    finally:
        # DataFrame.plot opens a new figure on every call. plt.clf() would only
        # clear it and leave it open, so close it to avoid accumulating
        # figures across repeated calls.
        plt.close(fig)
    return

Last Updated 03/17/2025