Source code for pumaguard.stats

"""
Module for statistics and plotting.
"""

import matplotlib.pyplot as plt


[docs] def plot_training_progress(filename, full_history): """ Plot the training progress and store in file. """ plt.figure(figsize=(18, 10)) plt.subplot(1, 2, 1) plt.plot(full_history.history["accuracy"], label="Training Accuracy") plt.plot(full_history.history["val_accuracy"], label="Validation Accuracy") plt.legend(loc="lower right") plt.ylabel("Accuracy") plt.ylim([min(plt.ylim()), 1]) plt.title("Training and Validation Accuracy") plt.subplot(1, 2, 2) plt.plot(full_history.history["loss"], label="Training Loss") plt.plot(full_history.history["val_loss"], label="Validation Loss") plt.legend(loc="upper right") plt.ylabel("Cross Entropy") plt.ylim([0, 1.0]) plt.title("Training and Validation Loss") print("Created plot of learning history") plt.savefig(filename)