# display_data.py
#!/usr/bin/env python3
import sys
import argparse
import logging
import os.path
import pandas as pd
import math
import matplotlib.pyplot as plt
def get_data(filename, index_col=0):
    """
    Load a CSV file into a pandas DataFrame.

    filename: path (or file-like object) accepted by pandas.read_csv.
    index_col: column holding the instance index. Defaults to 0, i.e. the
        first column of the file; pass None when the file has no such
        column instead of editing this function.

    The column named "Cabin" is forced to be read as str, because pandas
    cannot infer that type on its own. keep_default_na=False leaves missing
    values (blank cells) as empty strings instead of converting them to NaN.

    https://pandas.pydata.org/docs/reference/api/pandas.read_csv.html
    """
    return pd.read_csv(filename, index_col=index_col, dtype={"Cabin": str},
                       keep_default_na=False)
def get_basename(filename):
    """
    Strip the directory and the final extension from *filename*.

    Example: 'path/to/basename.ext' -> 'basename'.
    The intermediate pieces are logged at INFO level.
    """
    stem_path, extension = os.path.splitext(filename)
    parent_dir, stem = os.path.split(stem_path)
    message = "root: {} ext: {} dirname: {} basename: {}".format(
        stem_path, extension, parent_dir, stem)
    logging.info(message)
    return stem
def get_feature_and_label_names(my_args, data):
    """
    Decide which columns of *data* are features and which one is the label.

    Label: my_args.label if that column exists in *data*, otherwise "" (no
    label identified).
    Features: the columns listed in my_args.features that exist in *data*.
    If that leaves nothing (none requested, or none found), every column
    other than the label is used.

    Returns a (features, label) tuple: a list of column names and a str.
    """
    requested_label = my_args.label
    requested_features = my_args.features
    logging.info("INPUT: label_column: {}".format(requested_label))
    logging.info("INPUT: feature_columns: {}".format(requested_features))

    label = requested_label if requested_label in data.columns else ""

    features = [name for name in (requested_features or [])
                if name in data.columns]
    if not features:
        # Nothing usable was requested: fall back to every non-label column.
        features = [name for name in data.columns if name != label]

    logging.info("OUTPUT: labels: {}".format(label))
    logging.info("OUTPUT: features: {}".format(features))
    return features, label
def display_feature_histograms(my_args, data, figure_number):
    """
    Save a grid of histograms: one per feature, plus one for the label if
    one was identified.

    Column selection comes from get_feature_and_label_names(my_args, data).
    Every subplot uses a log-scaled y axis. All subplots share the y range
    1 .. (largest bin count seen in any subplot), so their bar heights can be
    compared.
    The figure is written to "<basename>-histogram-<f1>-<f2>-....pdf" in the
    current working directory, where <basename> comes from my_args.data_file.

    my_args: parsed command-line namespace (reads label, features, data_file).
    data: pandas DataFrame holding the columns to plot.
    figure_number: matplotlib figure number passed to plt.figure.
    """
    feature_columns, label_column = get_feature_and_label_names(my_args, data)

    # (subplot slot, column) pairs: features keep their 1-based position and
    # the label, when present, takes the slot right after the last feature.
    plots = list(enumerate(feature_columns, start=1))
    if label_column:
        plots.append((len(feature_columns) + 1, label_column))
    total_count = len(plots)

    # Smallest square grid that holds every subplot.
    size = int(math.ceil(math.sqrt(total_count)))
    w_inches_per_plot = 2.0
    h_inches_per_plot = 2.5
    fig_width = max(6.5, w_inches_per_plot * size)
    fig_height = max(9.0, h_inches_per_plot * size)
    fig = plt.figure(figure_number, figsize=(fig_width, fig_height))
    try:
        fig.suptitle("Feature Histograms")
        n_max = 1
        all_ax = []
        for slot, column in plots:
            if column not in data.columns:
                logging.warning("feature_column: '%s' not in data.columns: %s",
                                column, data.columns)
                continue
            ax = fig.add_subplot(size, size, slot)
            ax.set_yscale("log")
            # Lazy %-args: the Series is only formatted when INFO is enabled.
            logging.info("'%s': %s", column, data[column])
            counts, _bins, _patches = ax.hist(data[column], bins=20)
            n_max = max(n_max, max(counts))
            ax.set_xlabel(column)
            # produces a warning for string type data
            ax.locator_params(axis='x', tight=True, nbins=5)
            all_ax.append(ax)

        # Common y range so bar heights are comparable across subplots.
        for ax in all_ax:
            ax.set_ylim(bottom=1.0, top=n_max)
        fig.tight_layout()

        basename = get_basename(my_args.data_file)
        figure_name = "{}-histogram-{}.{}".format(basename, "-".join(feature_columns), "pdf")
        fig.savefig(figure_name)
    finally:
        # Release the figure even if plotting or saving raises.
        plt.close(fig)
def display_label_vs_features(my_args, data, figure_number):
    """
    Save a grid of scatter plots of the label against each feature, plus a
    label-vs-label panel in the slot after the last feature.

    Column selection comes from get_feature_and_label_names(my_args, data).
    If no label column was identified there is nothing to plot against, so a
    warning is logged and no file is written.
    Otherwise the figure is written to "<basename>-scatter-<f1>-<f2>-....pdf"
    in the current working directory, where <basename> comes from
    my_args.data_file.

    my_args: parsed command-line namespace (reads label, features, data_file).
    data: pandas DataFrame holding the columns to plot.
    figure_number: matplotlib figure number passed to plt.figure.
    """
    feature_columns, label_column = get_feature_and_label_names(my_args, data)
    if not label_column:
        # Without a label, scatter(feature, "", data=data) would get the
        # literal string "" as its y values (it is not a column), which does
        # not match the feature data. Skip this plot rather than fail.
        logging.warning("no label column identified; skipping label vs. features plot")
        return

    # (subplot slot, x column) pairs: features keep their 1-based position and
    # the label-vs-label panel takes the slot right after the last feature.
    plots = list(enumerate(feature_columns, start=1))
    plots.append((len(feature_columns) + 1, label_column))
    total_count = len(plots)

    # Smallest square grid that holds every subplot.
    size = int(math.ceil(math.sqrt(total_count)))
    w_inches_per_plot = 2.0
    h_inches_per_plot = 2.5
    fig_width = max(6.5, w_inches_per_plot * size)
    fig_height = max(9.0, h_inches_per_plot * size)
    fig = plt.figure(figure_number, figsize=(fig_width, fig_height))
    try:
        fig.suptitle("Label vs. Features")
        for slot, column in plots:
            if column not in data.columns:
                logging.warning("feature_column: '%s' not in data.columns: %s",
                                column, data.columns)
                continue
            ax = fig.add_subplot(size, size, slot)
            # Column names are resolved against data= by matplotlib.
            ax.scatter(column, label_column, data=data, s=1)
            ax.set_xlabel(column)
            ax.set_ylabel(label_column)
            ax.locator_params(axis='both', tight=True, nbins=5)
        fig.tight_layout()

        basename = get_basename(my_args.data_file)
        figure_name = "{}-scatter-{}.{}".format(basename, "-".join(feature_columns), "pdf")
        fig.savefig(figure_name)
    finally:
        # Release the figure even if plotting or saving raises.
        plt.close(fig)
def parse_args(argv):
    """
    Build this script's command-line parser and parse argv[1:].

    argv[0] is used as the program name in usage and help output.
    Returns an argparse.Namespace with the attributes action, data_file,
    features (None, or the list of names accumulated from every -f/--features
    occurrence) and label.
    """
    known_actions = ["label-vs-features", "feature-histograms", "all"]
    cli = argparse.ArgumentParser(prog=argv[0], description='Create Data Plots')
    cli.add_argument('action', nargs='?', default='all', choices=known_actions,
                     help="desired action")
    cli.add_argument('--data-file', '-d', type=str, default="",
                     help="csv file of data to display")
    cli.add_argument('--features', '-f', type=str, nargs="+", action="extend",
                     default=None, help="column names for features")
    cli.add_argument('--label', '-l', type=str, default="label",
                     help="column name for label")
    # Post-parse fixes or validation of the namespace would go here.
    return cli.parse_args(argv[1:])
def main(argv):
    """
    Command-line entry point.

    Parses *argv* (including the program name, e.g. sys.argv), loads the CSV
    named by --data-file, and writes the plot(s) selected by the action.

    Returns 0 on success, or 1 (after printing a message to stderr) when the
    data file does not exist or is not a regular file.
    """
    my_args = parse_args(argv)
    # Raise to logging.INFO to see the helpers' diagnostic output.
    logging.basicConfig(level=logging.WARNING)

    filename = my_args.data_file
    # isfile() is False for missing paths, so it also covers existence.
    if not os.path.isfile(filename):
        print(filename + " doesn't exist, or is not a file.", file=sys.stderr)
        return 1

    data = get_data(filename)
    if my_args.action in ("all", "label-vs-features"):
        display_label_vs_features(my_args, data, 1)
    if my_args.action in ("all", "feature-histograms"):
        display_feature_histograms(my_args, data, 2)
    return 0
if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status
    # (None is treated as 0 by sys.exit).
    sys.exit(main(sys.argv))
# Last Updated 02/01/2024