# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Plot the stats contained in the .csv file generated by the
test_kraepelin_algorithm unit test (when STATS_FILE_NAME is defined).

Usage:
    plot_stats.py <infile> [--plot {score_0,score_lf,total}] [--debug]
"""

import argparse
import csv
import datetime
import json
import logging
import os
import struct
import sys

import matplotlib.pyplot as pyplot

# 'epoch_type' column values that plot_stats distinguishes. Rows with any
# other epoch_type are not plotted.
_EPOCH_TYPE_NON_STEPPING = 0
_EPOCH_TYPE_STEPPING = 2


##################################################################################################
def plot_stats(args, rows, stat_name):
    """Scatter-plot one stat column against 'vmc' and display the figure.

    Stepping epochs (epoch_type == 2) are drawn as green circles and
    non-stepping epochs (epoch_type == 0) as red circles. Rows with any other
    epoch_type are ignored.

    Args:
        args: Parsed command-line arguments. Not used here; kept so existing
            callers keep working.
        rows: Iterable of dicts mapping column name -> string value. Each row
            must contain 'epoch_type', 'vmc' and `stat_name` keys whose values
            parse with int().
        stat_name: Name of the column plotted on the y axis.

    Raises:
        KeyError: If a row lacks one of the required columns.
        ValueError: If a required value is not an integer.
    """
    stepping_vmcs, stepping_scores = [], []
    non_stepping_vmcs, non_stepping_scores = [], []

    for row in rows:
        epoch_type = int(row['epoch_type'])
        if epoch_type == _EPOCH_TYPE_NON_STEPPING:
            non_stepping_vmcs.append(int(row['vmc']))
            non_stepping_scores.append(int(row[stat_name]))
        elif epoch_type == _EPOCH_TYPE_STEPPING:
            stepping_vmcs.append(int(row['vmc']))
            stepping_scores.append(int(row[stat_name]))

    # Same colours and draw order as before (green first, then red), but
    # labelled so the figure is self-explanatory.
    pyplot.plot(stepping_vmcs, stepping_scores, 'go', label='stepping')
    pyplot.plot(non_stepping_vmcs, non_stepping_scores, 'ro', label='non-stepping')
    pyplot.xlabel('vmc')
    pyplot.ylabel(stat_name)
    pyplot.legend()
    pyplot.show()


##################################################################################################
def read_stats_csv(path):
    """Read a stats CSV file into a list of row dicts keyed by column name.

    The first record is treated as the header, and surrounding whitespace is
    stripped from each column name. Blank lines are skipped. Values are left
    as strings.

    Args:
        path: Path of the CSV file to read.

    Returns:
        A list of dicts, one per data row. Returns an empty list for an empty
        file.
    """
    # The csv module needs a text-mode file opened with newline='' in
    # Python 3. Binary mode ('rb') makes csv.reader fail on bytes input.
    with open(path, newline='', encoding='utf-8') as csvfile:
        reader = csv.reader(csvfile)
        # Take the first *record* as the header rather than testing
        # reader.line_num == 1. line_num counts physical lines, which differs
        # from the record count when a field contains a quoted newline.
        header = next(reader, None)
        if header is None:
            return []
        col_names = [name.strip() for name in header]
        # csv.reader yields [] for blank lines. Skip them so plot_stats does
        # not hit a KeyError on an empty dict.
        return [dict(zip(col_names, row)) for row in reader if row]


##################################################################################################
def main():
    """Parse command-line arguments, load the CSV and plot the chosen metric."""
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('infile', help="The input csv file")
    parser.add_argument('--plot', choices=['score_0', 'score_lf', 'total'], default='score_0',
                        help="Which metric to plot against vmc")
    parser.add_argument('--debug', action='store_true', help="Turn on debug logging")
    args = parser.parse_args()

    logging.basicConfig(level=logging.DEBUG if args.debug else logging.INFO)

    rows = read_stats_csv(args.infile)
    logging.debug("Read %d rows from %s", len(rows), args.infile)

    plot_stats(args, rows, args.plot)


##################################################################################################
if __name__ == '__main__':
    main()