|
15 | 15 | import torch.nn as nn |
16 | 16 | import torch.optim as optim |
17 | 17 | from matplotlib import pyplot as plt |
| 18 | +from matplotlib.ticker import LogLocator, LogFormatter |
18 | 19 | from torch import cuda as torchcuda |
19 | 20 |
|
20 | 21 | from .torch_data import TensorScaler, DataUtil, TorchDataSet, TorchDataSetProviderFromDataUtil, TorchDataSetProvider, \ |
@@ -946,23 +947,39 @@ def _create_series_with_one_based_index(self, sequence: Sequence, name: str): |
946 | 947 | series.index += 1 |
947 | 948 | return series |
948 | 949 |
|
949 | | - def plot_all(self) -> matplotlib.figure.Figure: |
| 950 | + def plot_all(self, log_scale: bool = False) -> matplotlib.figure.Figure: |
950 | 951 | """ |
951 | | - Plots both the sequence of training loss values and the sequence of validation metric values |
| 952 | + Plots both the sequence of training loss values and the sequence of validation metric values. |
| 953 | +
|
| 954 | + :param log_scale: If True, uses a logarithmic y-axis on both axes (requires strictly positive values) |
| 955 | + :return: A matplotlib Figure object containing the plot |
952 | 956 | """ |
953 | 957 | ts = self.get_training_loss_series() |
954 | 958 | vs = self.get_validation_metric_series() |
955 | 959 |
|
| 960 | + if log_scale: |
| 961 | + if np.any(np.asarray(ts) <= 0): |
| 962 | + raise ValueError("log_scale=True requires all training loss values to be > 0.") |
| 963 | + if np.any(np.asarray(vs) <= 0): |
| 964 | + raise ValueError("log_scale=True requires all validation metric values to be > 0.") |
| 965 | + |
956 | 966 | fig, primary_ax = plt.subplots(1, 1) |
957 | 967 | secondary_ax = primary_ax.twinx() |
958 | 968 |
|
959 | 969 | training_line = primary_ax.plot(ts, color='blue') |
960 | 970 | validation_line = secondary_ax.plot(vs, color='orange') |
961 | 971 | best_epoc_line = primary_ax.axvline(self.best_epoch, color='black', linestyle='dashed') |
962 | 972 |
|
| 973 | + if log_scale: |
| 974 | + for ax in (primary_ax, secondary_ax): |
| 975 | + ax.set_yscale("log", base=10) |
| 976 | + ax.yaxis.set_major_locator(LogLocator(base=10)) |
| 977 | + ax.yaxis.set_major_formatter(LogFormatter(base=10)) |
| 978 | + ax.yaxis.set_minor_locator(LogLocator(base=10, subs=np.arange(2, 10) * 0.1)) |
| 979 | + |
963 | 980 | primary_ax.set_xlabel("epoch") |
964 | | - primary_ax.set_ylabel(ts.name) |
965 | | - secondary_ax.set_ylabel(vs.name) |
| 981 | + primary_ax.set_ylabel(ts.name + (" (log)" if log_scale else "")) |
| 982 | + secondary_ax.set_ylabel(vs.name + (" (log)" if log_scale else "")) |
966 | 983 |
|
967 | 984 | primary_ax.legend(training_line + validation_line + [best_epoc_line], [ts.name, vs.name, "best epoch"]) |
968 | 985 | plt.tight_layout() |
|
0 commit comments