Skip to content

Commit ee49834

Browse files
authored
Merge pull request #94 from mdbenito/feature/log_scale
Implement log scaling for TrainingInfo.plot_all
2 parents 6ecc067 + d45a7d0 commit ee49834

File tree

1 file changed

+21
-4
lines changed

1 file changed

+21
-4
lines changed

src/sensai/torch/torch_opt.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import torch.nn as nn
1616
import torch.optim as optim
1717
from matplotlib import pyplot as plt
18+
from matplotlib.ticker import LogLocator, LogFormatter
1819
from torch import cuda as torchcuda
1920

2021
from .torch_data import TensorScaler, DataUtil, TorchDataSet, TorchDataSetProviderFromDataUtil, TorchDataSetProvider, \
@@ -946,23 +947,39 @@ def _create_series_with_one_based_index(self, sequence: Sequence, name: str):
946947
series.index += 1
947948
return series
948949

949-
def plot_all(self) -> matplotlib.figure.Figure:
950+
def plot_all(self, log_scale: bool = False) -> matplotlib.figure.Figure:
950951
"""
951-
Plots both the sequence of training loss values and the sequence of validation metric values
952+
Plots both the sequence of training loss values and the sequence of validation metric values.
953+
954+
:param log_scale: If True, uses a logarithmic y-axis on both axes (requires strictly positive values)
955+
:return: A matplotlib Figure object containing the plot
952956
"""
953957
ts = self.get_training_loss_series()
954958
vs = self.get_validation_metric_series()
955959

960+
if log_scale:
961+
if np.any(np.asarray(ts) <= 0):
962+
raise ValueError("log_scale=True requires all training loss values to be > 0.")
963+
if np.any(np.asarray(vs) <= 0):
964+
raise ValueError("log_scale=True requires all validation metric values to be > 0.")
965+
956966
fig, primary_ax = plt.subplots(1, 1)
957967
secondary_ax = primary_ax.twinx()
958968

959969
training_line = primary_ax.plot(ts, color='blue')
960970
validation_line = secondary_ax.plot(vs, color='orange')
961971
best_epoc_line = primary_ax.axvline(self.best_epoch, color='black', linestyle='dashed')
962972

973+
if log_scale:
974+
for ax in (primary_ax, secondary_ax):
975+
ax.set_yscale("log", base=10)
976+
ax.yaxis.set_major_locator(LogLocator(base=10))
977+
ax.yaxis.set_major_formatter(LogFormatter(base=10))
978+
ax.yaxis.set_minor_locator(LogLocator(base=10, subs=np.arange(2, 10) * 0.1))
979+
963980
primary_ax.set_xlabel("epoch")
964-
primary_ax.set_ylabel(ts.name)
965-
secondary_ax.set_ylabel(vs.name)
981+
primary_ax.set_ylabel(ts.name + (" (log)" if log_scale else ""))
982+
secondary_ax.set_ylabel(vs.name + (" (log)" if log_scale else ""))
966983

967984
primary_ax.legend(training_line + validation_line + [best_epoc_line], [ts.name, vs.name, "best epoch"])
968985
plt.tight_layout()

0 commit comments

Comments
 (0)