DEPARTMENT OF COMPUTING

predict.py [download]


def do_predict(my_args):

    test_file = my_args.test_file
    if not os.path.exists(test_file):
        raise Exception("testing data file: {} does not exist.".format(test_file))
    
    model_file = get_model_filename(my_args.model_file, test_file)
    if not os.path.exists(model_file):
        raise Exception("Model file, '{}', does not exist.".format(model_file))

    X_test, y_test = load_data(my_args, test_file)
    pipeline = joblib.load(model_file)

    y_test_predicted = pipeline.predict(X_test)

    merged = X_test.index.to_frame()
    merged['SalePrice'] = y_test_predicted
    merged.to_csv("predictions.csv", index=False)

    return

Last Updated 02/04/2025