DEPARTMENT OF COMPUTING

display_data.py [download]


#!/usr/bin/env python3

import sys
import argparse
import logging
import os.path

import pandas as pd

import math
import matplotlib.pyplot as plt

def get_data(filename):
    """
    Load the csv file `filename` into a pandas DataFrame.

    Column 0 of the csv file is taken to be the instance index.  If
    the file has no such column, drop the "index_col" entry below.

    The column named "Cabin" is forced to be read as a string,
    because Pandas can't figure that out on its own.

    keep_default_na=False requests that missing values (blank cells)
    be left as empty strings.

    https://pandas.pydata.org/docs/reference/api/pandas.read_csv.html
    """
    read_options = {
        "index_col": 0,
        "dtype": {"Cabin": str},
        "keep_default_na": False,
    }
    return pd.read_csv(filename, **read_options)

def get_basename(filename):
    """
    Return the bare name of `filename`, without directory or extension.

    For example 'path/to/basename.ext' gives 'basename'.
    """
    stem_path, extension = os.path.splitext(filename)
    directory = os.path.dirname(stem_path)
    stem = os.path.basename(stem_path)
    logging.info(
        "root: {}  ext: {}  dirname: {}  basename: {}".format(
            stem_path, extension, directory, stem
        )
    )
    return stem

def get_feature_and_label_names(my_args, data):
    """
    Work out which columns of `data` are features and which is the label.

    The label is my_args.label when that names a column of `data`;
    otherwise it is "" (no label identified).

    The features are the entries of my_args.features that name columns
    of `data`, in the order given.  When that leaves nothing (none were
    requested, or none of the requested names exist), every column
    other than the label is used.

    Returns the pair (features, label).
    """
    requested_label = my_args.label
    requested_features = my_args.features
    logging.info("INPUT: label_column: {}".format(requested_label))
    logging.info("INPUT: feature_columns: {}".format(requested_features))

    label = requested_label if requested_label in data.columns else ""

    # keep only the requested features that really exist in the data
    if requested_features is None:
        features = []
    else:
        features = [name for name in requested_features if name in data.columns]

    # nothing usable was requested, so fall back to all non-label columns
    if not features:
        features = [name for name in data.columns if name != label]

    logging.info("OUTPUT: labels: {}".format(label))
    logging.info("OUTPUT: features: {}".format(features))
    return features, label

def _add_log_histogram(fig, size, position, data, column):
    """
    Draw a 20-bin histogram of data[column], with a log-scale y axis,
    in cell `position` of the size x size subplot grid of `fig`.

    Returns the new axes and the largest bin count.
    """
    ax = fig.add_subplot(size, size, position)
    ax.set_yscale("log")
    n, bins, patches = ax.hist(data[column], bins=20)
    ax.set_xlabel(column)
    # produces a warning for string type data
    ax.locator_params(axis='x', tight=True, nbins=5)
    return ax, max(n)

def display_feature_histograms(my_args, data, figure_number):
    """
    Display a histogram for every feature and the label, if identified.

    One histogram is drawn per column on a square grid of subplots,
    all sharing the same y range.  The figure is saved as
    '<basename>-histogram-<feature names joined by "-">.pdf', where
    <basename> is derived from my_args.data_file, and is then closed.

    my_args       -- parsed arguments; label, features and data_file are used
    data          -- DataFrame holding the columns to plot
    figure_number -- matplotlib figure number to draw on
    """
    feature_columns, label_column = get_feature_and_label_names(my_args, data)

    total_count = len(feature_columns)
    if label_column:
        total_count += 1
    # side of the smallest square grid that holds every plot
    size = int(math.ceil(math.sqrt(total_count)))
    w_inches_per_plot = 2.0
    h_inches_per_plot = 2.5
    fig_width = max(6.5, w_inches_per_plot*size)
    fig_height = max(9.0, h_inches_per_plot*size)

    fig = plt.figure(figure_number, figsize=(fig_width, fig_height))
    fig.suptitle("Feature Histograms")

    n_max = 1
    all_ax = []
    for i, feature_column in enumerate(feature_columns, start=1):
        if feature_column not in data.columns:
            logging.warning("feature_column: '%s' not in data.columns: %s",
                            feature_column, data.columns)
            continue
        # Lazy %-style arguments: the whole column is only rendered to
        # text when INFO logging is actually enabled.
        logging.info("'%s': %s", feature_column, data[feature_column])
        ax, tallest = _add_log_histogram(fig, size, i, data, feature_column)
        n_max = max(n_max, tallest)
        all_ax.append(ax)

    if label_column:
        # the label takes the grid cell right after the last feature
        ax, tallest = _add_log_histogram(fig, size, total_count, data, label_column)
        n_max = max(n_max, tallest)
        all_ax.append(ax)

    # give every plot the same y range so bar heights are comparable
    for ax in all_ax:
        ax.set_ylim(bottom=1.0, top=n_max)

    fig.tight_layout()
    basename = get_basename(my_args.data_file)
    figure_name = "{}-histogram-{}.{}".format(basename, "-".join(feature_columns), "pdf")
    fig.savefig(figure_name)
    plt.close(fig)
    return

def display_label_vs_features(my_args, data, figure_number):
    """
    Display a plot of label vs feature for every feature and the label, if identified.

    One scatter plot (label on the y axis) is drawn per feature on a
    square grid of subplots, followed by a label-vs-label plot.  The
    figure is saved as
    '<basename>-scatter-<feature names joined by "-">.pdf', where
    <basename> is derived from my_args.data_file, and is then closed.

    When no label column is identified nothing is plotted or saved; a
    warning is logged instead.

    my_args       -- parsed arguments; label, features and data_file are used
    data          -- DataFrame holding the columns to plot
    figure_number -- matplotlib figure number to draw on
    """
    feature_columns, label_column = get_feature_and_label_names(my_args, data)

    if not label_column:
        # Without a label there is no y data: scatter() would be handed
        # the literal string "" as y and fail on the size mismatch with x.
        logging.warning("no label column identified; skipping label vs. features plot")
        return

    # every feature plus the label-vs-label plot
    total_count = len(feature_columns) + 1
    # side of the smallest square grid that holds every plot
    size = int(math.ceil(math.sqrt(total_count)))
    w_inches_per_plot = 2.0
    h_inches_per_plot = 2.5
    fig_width = max(6.5, w_inches_per_plot*size)
    fig_height = max(9.0, h_inches_per_plot*size)

    fig = plt.figure(figure_number, figsize=(fig_width, fig_height))
    fig.suptitle("Label vs. Features")

    # features fill cells 1..n; the label itself takes the next cell
    for position, column in enumerate([*feature_columns, label_column], start=1):
        if column not in data.columns:
            logging.warning("feature_column: '%s' not in data.columns: %s",
                            column, data.columns)
            continue
        ax = fig.add_subplot(size, size, position)
        # string x/y arguments are looked up as column names in `data`
        ax.scatter(column, label_column, data=data, s=1)
        ax.set_xlabel(column)
        ax.set_ylabel(label_column)
        ax.locator_params(axis='both', tight=True, nbins=5)

    fig.tight_layout()
    basename = get_basename(my_args.data_file)
    figure_name = "{}-scatter-{}.{}".format(basename, "-".join(feature_columns), "pdf")
    fig.savefig(figure_name)
    plt.close(fig)

    return

def parse_args(argv):
    """
    Build the command line parser and parse `argv`.

    argv[0] is used as the program name and the remaining entries are
    parsed.  Returns the argparse namespace, which carries the
    attributes action, data_file, features and label.
    """
    cli = argparse.ArgumentParser(prog=argv[0], description='Create Data Plots')

    actions = ["label-vs-features", "feature-histograms", "all"]
    cli.add_argument('action', nargs='?', default='all', choices=actions,
                     help="desired action")
    cli.add_argument('--data-file', '-d', type=str, default="",
                     help="csv file of data to display")
    cli.add_argument('--features', '-f', type=str, nargs="+", action="extend",
                     default=None, help="column names for features")
    cli.add_argument('--label', '-l', type=str, default="label",
                     help="column name for label")

    parsed = cli.parse_args(argv[1:])

    # hook: any special fixes/checks on the parsed arguments belong here

    return parsed

def main(argv):
    """
    Program entry point: parse `argv`, load the csv file and produce
    the requested plots.

    When the path given by --data-file does not exist, or is not a
    file, a message is printed and nothing else happens.
    """
    my_args = parse_args(argv)
    # use level=logging.INFO here instead to get a detailed trace
    logging.basicConfig(level=logging.WARN)

    filename = my_args.data_file
    # isfile() is already False for a path that does not exist
    if not os.path.isfile(filename):
        print(filename + " doesn't exist, or is not a file.")
        return

    data = get_data(filename)

    wanted = my_args.action
    if wanted == "all" or wanted == "label-vs-features":
        display_label_vs_features(my_args, data, 1)
    if wanted == "all" or wanted == "feature-histograms":
        display_feature_histograms(my_args, data, 2)

    return

# Run the program only when this file is executed as a script, not when
# it is imported; the full argument vector is handed to main().
if __name__ == "__main__":
    main(sys.argv)
    

Last Updated 02/01/2024