This repository was archived by the owner on Dec 11, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathplot.py
More file actions
executable file
·154 lines (132 loc) · 4.43 KB
/
plot.py
File metadata and controls
executable file
·154 lines (132 loc) · 4.43 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
#!/usr/bin/env python
"""Generate plots for comparing schedulers."""
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser, Namespace
from pathlib import Path
from typing import Any, Final, Optional
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from tbparse import SummaryReader
from tqdm import tqdm
from src.config import load_config
from src.utils import AVAIL_TASKS
# Apply seaborn's default theme to matplotlib once, before any figure is drawn.
sns.set()
# Smoothing factor (``alpha``) of the exponentially-weighted moving average
# applied to the logged values before they are summarized and plotted.
_SMOOTH_ALPHA: Final = 0.4

# Task family -> (tag, label) pairs. Each tag is drawn in its own figure and
# its label is used as the y-axis title.
_TAGS_TO_PLOT: Final = dict(
    rls=[
        ("metrics/distance", "Distance"),
        ("metrics/potential", "Potential"),
    ],
    covar=[
        ("metrics/distance", "Distance"),
        ("gradients/x", "Gradients w.r.t. X"),
    ],
    cifar10=[
        ("metrics/fid", "FID"),
        ("metrics/inception_score", "Inception Score"),
    ],
)

# Grouping mode (the CLI ``mode`` argument) -> name of the grouping column,
# which also serves as the legend title.
_MODE_TO_COL: Final = dict(sched="Scheduler", decay="Decay")

# Scheduler identifier from the run config -> label shown in the legend.
_SCHED_TO_NAME: Final = {
    "const": "Constant",
    "step": "Step-decay",
    "var": "Increasing-phase",
    "poly": "Poly-linear",
    "poly-sqrt": "Poly-sqrt",
}
def format_row(row) -> str:
    """Format one row of the final metrics in scientific notation.

    The std is rescaled to share the mean's exponent, e.g.
    ``5.230 ± 0.140 x 10^-3``.

    Args:
        row: A mapping (e.g. one row of the final-metrics DataFrame) with
            ``"mean"`` and ``"std"`` entries.

    Returns:
        The formatted string. A non-finite std (e.g. the NaN std of a single
        run) is printed as ``nan``/``inf`` in place of the rescaled value. If
        the mean itself is not finite there is no exponent to share, so both
        values are printed as-is, e.g. ``nan ± nan``.
    """
    mean, std = row["mean"], row["std"]
    if not np.isfinite(mean):
        # f"{nan:.3e}" is just "nan" (no "e" to split on), so the
        # shared-exponent layout below cannot be built.
        return f"{mean:.3e} ± {std:.3e}"
    mean_mantissa, mean_exp = f"{mean:.3e}".split("e")
    if np.isfinite(std):
        std_mantissa, std_exp = f"{std:.3e}".split("e")
        # Re-express the std with the mean's exponent
        adjusted_std_mantissa = float(std_mantissa) * 10 ** (
            int(std_exp) - int(mean_exp)
        )
    else:
        adjusted_std_mantissa = std
    return (
        f"{mean_mantissa} ± {adjusted_std_mantissa:.3f} x "
        f"10^{int(mean_exp)}"
    )
def _load_scalars(log_dirs: "list[Path]", mode: str) -> pd.DataFrame:
    """Read the TensorBoard scalars of every usable run into one DataFrame.

    Log dirs without an ``hparams.yaml`` are skipped. Two columns are added
    to each run's scalars: the grouping value read from the run's config
    (column named ``_MODE_TO_COL[mode]``), and ``"run"`` (the log dir as a
    string) identifying which run each row came from.

    Raises:
        ValueError: If none of ``log_dirs`` contains an ``hparams.yaml``.
    """
    group_col = _MODE_TO_COL[mode]
    frames = []
    # Sort log dirs for determinism
    for path in tqdm(sorted(log_dirs)):
        config_path = path / "hparams.yaml"
        if not config_path.exists():
            continue
        config = load_config(config_path)
        if mode == "sched":
            name: Any = _SCHED_TO_NAME[config.sched]
        else:  # "decay"; any other mode already failed the lookup above
            name = config.decay
        # Only read event files for runs that are actually kept
        path_data = SummaryReader(path).scalars
        path_data[group_col] = name
        path_data["run"] = str(path)
        frames.append(path_data)
    if not frames:
        raise ValueError(
            "None of the given log directories contains an hparams.yaml file"
        )
    # Concatenate once instead of growing a DataFrame inside the loop
    return pd.concat(frames, ignore_index=True)


def main(args: Namespace) -> None:
    """Print final-metric summaries and save one log-scale plot per tag.

    For each tag listed in ``_TAGS_TO_PLOT`` for the task family (the part
    of ``args.task`` before the first ``/``):

    - smooth the values with an EWM, separately for every run;
    - print, per group, the mean ± std across runs of the smoothed value at
      that group's last step;
    - plot the smoothed curves (one line per group, ``ci="sd"`` band) to
      ``f"{args.prefix}{tag}.pdf"`` with ``/`` in the tag replaced by ``-``.

    Raises:
        ValueError: If none of ``args.log_dir`` contains an ``hparams.yaml``.
    """
    group_col = _MODE_TO_COL[args.mode]
    data = _load_scalars(args.log_dir, args.mode)
    for tag, tag_name in _TAGS_TO_PLOT[args.task.split("/")[0]]:
        tag_data = data[data["tag"] == tag].copy()
        if tag_data.empty:
            print(f"{tag}: no data found, skipping")
            continue
        # Smooth per group AND per run. Grouping only by the group column
        # would run the EWM across run boundaries, leaking the end of one
        # run into the start of the next. ``transform`` keeps the result
        # aligned with ``tag_data``'s index; ``apply`` may prepend the group
        # keys to the index, which breaks the column assignment.
        smoother = tag_data.groupby([group_col, "run"])["value"]
        tag_data["smoothed"] = smoother.transform(
            lambda x: x.ewm(alpha=_SMOOTH_ALPHA).mean()
        )
        # Mean and std (across runs) of the smoothed values at the last
        # step of each group
        last_step = tag_data.groupby(group_col)["step"].transform("max")
        final_metrics = (
            tag_data[tag_data["step"] == last_step]
            .groupby(group_col)["smoothed"]
            .agg(["mean", "std"])
        )
        print(f"{tag}:")
        print(final_metrics.apply(format_row, axis=1))
        axes = sns.lineplot(
            data=tag_data,
            x="step",
            y="smoothed",
            hue=group_col,
            ci="sd",  # seaborn >= 0.12 deprecates this for errorbar="sd"
        )
        axes.set_xlabel("Steps")
        axes.set_ylabel(tag_name)
        axes.set_yscale("log")
        tag_trimmed = tag.replace("/", "-")
        fig_path = Path(f"{args.prefix}{tag_trimmed}.pdf")
        fig_path.parent.mkdir(parents=True, exist_ok=True)
        plt.savefig(fig_path, bbox_inches="tight", pad_inches=0)
        # The same axes are reused for the next tag
        axes.clear()
if __name__ == "__main__":
    cli = ArgumentParser(
        description="Generate plots for comparing schedulers",
        formatter_class=ArgumentDefaultsHelpFormatter,
    )
    # Positional arguments, in command-line order: which task's metrics to
    # plot, which hyper-parameter groups the runs, and the run directories.
    cli.add_argument(
        "task",
        help="A string specifying the optimization task",
        choices=AVAIL_TASKS,
    )
    cli.add_argument(
        "mode",
        help="The hyper-parameter to use for grouping",
        choices=["sched", "decay"],
    )
    cli.add_argument(
        "log_dir",
        help="Path to the directories containing training logs",
        type=Path,
        nargs="+",
    )
    # Output location: main() saves each figure as f"{prefix}{tag}.pdf"
    # (with "/" in the tag replaced by "-").
    cli.add_argument(
        "-p",
        "--prefix",
        help="The prefix for naming the figure paths",
        default="plots/",
        type=str,
    )
    cli_args = cli.parse_args()
    main(cli_args)