Skip to content

Commit 8977154

Browse files
committed
Add unit tests for genRuleFile
Signed-off-by: rohithsiddi <rohithsiddi7@gmail.com>
1 parent 7a3f3d5 commit 8977154

1 file changed

Lines changed: 340 additions & 0 deletions

File tree

flow/test/test_genRuleFile.py

Lines changed: 340 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,340 @@
1+
#!/usr/bin/env python3
2+
3+
import unittest
4+
from unittest.mock import patch
5+
from io import StringIO
6+
import sys
7+
import os
8+
import json
9+
import tempfile
10+
11+
sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "util"))
12+
13+
import genRuleFile
14+
15+
16+
def make_base_metrics():
17+
"""Return a minimal but complete metrics dict for testing."""
18+
return {
19+
"constraints__clocks__count": 1,
20+
"constraints__clocks__details": ["core_clock: 2.0"],
21+
"synth__design__instance__area__stdcell": 500.0,
22+
"detailedplace__design__violations": 0,
23+
"placeopt__design__instance__count__stdcell": 400,
24+
"placeopt__design__instance__area": 600,
25+
"cts__design__instance__count__setup_buffer": 20,
26+
"cts__design__instance__count__hold_buffer": 15,
27+
"cts__timing__setup__ws": -0.05,
28+
"cts__timing__setup__tns": -0.3,
29+
"cts__timing__hold__ws": 0.0,
30+
"cts__timing__hold__tns": 0.0,
31+
"globalroute__antenna_diodes_count": 5,
32+
"globalroute__route__net": 500,
33+
"globalroute__timing__setup__ws": -0.06,
34+
"globalroute__timing__setup__tns": -0.5,
35+
"globalroute__timing__hold__ws": 0.0,
36+
"globalroute__timing__hold__tns": 0.0,
37+
"detailedroute__route__wirelength": 3000,
38+
"detailedroute__route__drc_errors": 0,
39+
"detailedroute__antenna__violating__nets": 2,
40+
"detailedroute__antenna_diodes_count": 3,
41+
"detailedroute__route__net": 500,
42+
"finish__timing__setup__ws": -0.1,
43+
"finish__timing__setup__tns": -2.0,
44+
"finish__timing__hold__ws": 0.05,
45+
"finish__timing__hold__tns": 0.0,
46+
"finish__design__instance__area": 800,
47+
}
48+
49+
50+
class TestCommaList(unittest.TestCase):
51+
def test_all_returns_empty(self):
52+
self.assertEqual(genRuleFile.comma_separated_list("all"), [])
53+
54+
def test_none_returns_empty(self):
55+
self.assertEqual(genRuleFile.comma_separated_list(None), [])
56+
57+
def test_csv(self):
58+
self.assertEqual(genRuleFile.comma_separated_list("a,b,c"), ["a", "b", "c"])
59+
60+
def test_whitespace_trimming(self):
61+
self.assertEqual(genRuleFile.comma_separated_list("a , b , c"), ["a", "b", "c"])
62+
63+
64+
class TestGenRuleFile(unittest.TestCase):
65+
def setUp(self):
66+
self.tmp_dir = tempfile.TemporaryDirectory()
67+
self.metrics_file = os.path.join(self.tmp_dir.name, "metrics.json")
68+
self.rules_file = os.path.join(self.tmp_dir.name, "rules.json")
69+
self.new_rules_file = os.path.join(self.tmp_dir.name, "new_rules.json")
70+
71+
def _run(self, metrics, old_rules=None, **kwargs):
72+
with open(self.metrics_file, "w") as f:
73+
json.dump(metrics, f)
74+
if old_rules is not None:
75+
with open(self.rules_file, "w") as f:
76+
json.dump(old_rules, f)
77+
defaults = dict(
78+
rules_file=self.rules_file,
79+
new_rules_file=self.new_rules_file,
80+
update=False,
81+
tighten=False,
82+
failing=False,
83+
variant="base",
84+
metrics_file=self.metrics_file,
85+
metrics_to_consider=[],
86+
)
87+
defaults.update(kwargs)
88+
genRuleFile.gen_rule_file(**defaults)
89+
with open(self.new_rules_file, "r") as f:
90+
return json.load(f)
91+
92+
def test_direct_mode(self):
93+
metrics = make_base_metrics()
94+
rules = self._run(metrics, update=True)
95+
# detailedplace__design__violations uses direct mode
96+
self.assertEqual(rules["detailedplace__design__violations"]["value"], 0)
97+
self.assertEqual(rules["detailedplace__design__violations"]["compare"], "==")
98+
99+
def test_direct_mode_drc_errors(self):
100+
metrics = make_base_metrics()
101+
metrics["detailedroute__route__drc_errors"] = 3
102+
rules = self._run(metrics, update=True)
103+
self.assertEqual(rules["detailedroute__route__drc_errors"]["value"], 3)
104+
105+
def test_padding_mode_area(self):
106+
metrics = make_base_metrics()
107+
rules = self._run(metrics, update=True)
108+
# synth__design__instance__area__stdcell: padding=15, round_value=False
109+
# 500.0 * 1.15 = 575.0 -> 3 sig figs -> 575.0
110+
expected = float(f"{500.0 * 1.15:.3g}")
111+
self.assertEqual(
112+
rules["synth__design__instance__area__stdcell"]["value"], expected
113+
)
114+
115+
def test_padding_mode_rounded(self):
116+
metrics = make_base_metrics()
117+
rules = self._run(metrics, update=True)
118+
# placeopt__design__instance__area: padding=15, round_value=True
119+
# 600 * 1.15 = 690.0 -> round -> 690
120+
self.assertEqual(rules["placeopt__design__instance__area"]["value"], 690)
121+
self.assertIsInstance(rules["placeopt__design__instance__area"]["value"], int)
122+
123+
def test_padding_mode_wirelength(self):
124+
metrics = make_base_metrics()
125+
rules = self._run(metrics, update=True)
126+
# detailedroute__route__wirelength: padding=15, round_value=True
127+
# 3000 * 1.15 = 3450
128+
self.assertEqual(rules["detailedroute__route__wirelength"]["value"], 3450)
129+
130+
def test_period_padding_negative_slack(self):
131+
metrics = make_base_metrics()
132+
# finish__timing__setup__ws = -0.1, period = 2.0, padding = 5
133+
# negative_slack = min(-0.1, 0) = -0.1
134+
# rule = -0.1 - max(-0.1 * 5/100, 2.0 * 5/100)
135+
# = -0.1 - max(0.005, 0.1) = -0.1 - 0.1 = -0.2
136+
rules = self._run(metrics, update=True)
137+
self.assertAlmostEqual(
138+
rules["finish__timing__setup__ws"]["value"], -0.2, places=3
139+
)
140+
self.assertEqual(rules["finish__timing__setup__ws"]["compare"], ">=")
141+
142+
def test_period_padding_zero_slack(self):
143+
metrics = make_base_metrics()
144+
# finish__timing__hold__ws = 0.05 (positive)
145+
# negative_slack = min(0.05, 0) = 0
146+
# rule = 0 - max(0 * 5/100, 2.0 * 5/100) = 0 - 0.1 = -0.1
147+
rules = self._run(metrics, update=True)
148+
self.assertAlmostEqual(
149+
rules["finish__timing__hold__ws"]["value"], -0.1, places=3
150+
)
151+
152+
def test_period_padding_tns(self):
153+
metrics = make_base_metrics()
154+
# finish__timing__setup__tns = -2.0, period = 2.0, padding = 20
155+
# negative_slack = min(-2.0, 0) = -2.0
156+
# rule = -2.0 - max(-2.0 * 20/100, 2.0 * 20/100)
157+
# = -2.0 - max(0.4, 0.4) = -2.0 - 0.4 = -2.4
158+
rules = self._run(metrics, update=True)
159+
self.assertAlmostEqual(
160+
rules["finish__timing__setup__tns"]["value"], -2.4, places=3
161+
)
162+
163+
def test_metric_mode_antenna_diodes(self):
164+
metrics = make_base_metrics()
165+
rules = self._run(metrics, update=True)
166+
# globalroute__antenna_diodes_count: mode=metric, padding=0.1,
167+
# metric=globalroute__route__net (500), min_max=max, min_max_direct=100
168+
# rule = 500 * 0.1 / 100 = 0.5 -> max(0.5, 100) = 100
169+
self.assertEqual(rules["globalroute__antenna_diodes_count"]["value"], 100)
170+
171+
def test_cts_buffer_min_threshold(self):
172+
metrics = make_base_metrics()
173+
rules = self._run(metrics, update=True)
174+
# cts__design__instance__count__setup_buffer: mode=metric, padding=10,
175+
# metric=placeopt__design__instance__count__stdcell (400)
176+
# rule_value = 400 * 10 / 100 = 40
177+
# special: max(40, 20 * 1.1) = max(40, 22) = 40
178+
self.assertEqual(
179+
rules["cts__design__instance__count__setup_buffer"]["value"], 40
180+
)
181+
182+
def test_cts_buffer_uses_metric_times_1_1(self):
183+
metrics = make_base_metrics()
184+
# Set placeopt stdcell count low so metric mode gives small value
185+
metrics["placeopt__design__instance__count__stdcell"] = 10
186+
metrics["cts__design__instance__count__setup_buffer"] = 50
187+
rules = self._run(metrics, update=True)
188+
# rule_value = 10 * 10 / 100 = 1
189+
# special: max(1, 50 * 1.1) = max(1, 55) = 55
190+
self.assertEqual(
191+
rules["cts__design__instance__count__setup_buffer"]["value"], 55
192+
)
193+
194+
def test_round_value_true_produces_int(self):
195+
metrics = make_base_metrics()
196+
rules = self._run(metrics, update=True)
197+
self.assertIsInstance(
198+
rules["placeopt__design__instance__count__stdcell"]["value"], int
199+
)
200+
201+
def test_round_value_false_produces_float(self):
202+
metrics = make_base_metrics()
203+
rules = self._run(metrics, update=True)
204+
self.assertIsInstance(
205+
rules["synth__design__instance__area__stdcell"]["value"], float
206+
)
207+
208+
def test_wildcard_warnings_match(self):
209+
metrics = make_base_metrics()
210+
# The wildcard pattern "*flow__warnings__count:*" matches keys with
211+
# colons. The code then replaces ":" with "__" for the output key but
212+
# looks up the replaced name in metrics, so we provide both forms.
213+
metrics["1_synth__flow__warnings__count:default"] = 5
214+
metrics["1_synth__flow__warnings__count__default"] = 5
215+
rules = self._run(metrics, update=True)
216+
key = "1_synth__flow__warnings__count__default"
217+
self.assertIn(key, rules)
218+
self.assertEqual(rules[key]["value"], 5)
219+
self.assertEqual(rules[key]["compare"], "<=")
220+
self.assertEqual(rules[key].get("level"), "warning")
221+
222+
@patch("sys.stdout", new_callable=StringIO)
223+
def test_no_old_rules_file_warns(self, mock_stdout):
224+
metrics = make_base_metrics()
225+
self._run(metrics, update=True)
226+
self.assertIn("[WARNING] No old rules file found", mock_stdout.getvalue())
227+
228+
def test_tighten_updates_when_tighter(self):
229+
metrics = make_base_metrics()
230+
old_rules = {
231+
"detailedroute__route__wirelength": {"value": 5000, "compare": "<="},
232+
}
233+
rules = self._run(metrics, old_rules=old_rules, tighten=True)
234+
# New value 3450 is tighter than old 5000 for <=
235+
self.assertEqual(rules["detailedroute__route__wirelength"]["value"], 3450)
236+
237+
def test_tighten_keeps_old_when_not_tighter(self):
238+
metrics = make_base_metrics()
239+
old_rules = {
240+
"detailedroute__route__wirelength": {"value": 3000, "compare": "<="},
241+
}
242+
rules = self._run(metrics, old_rules=old_rules, tighten=True)
243+
# New value 3450 is NOT tighter than old 3000 for <=, keep old
244+
self.assertEqual(rules["detailedroute__route__wirelength"]["value"], 3000)
245+
246+
def test_failing_updates_when_metric_fails(self):
247+
metrics = make_base_metrics()
248+
metrics["detailedroute__route__wirelength"] = 6000
249+
old_rules = {
250+
"detailedroute__route__wirelength": {"value": 5000, "compare": "<="},
251+
}
252+
rules = self._run(metrics, old_rules=old_rules, failing=True)
253+
# metric 6000 fails old rule (6000 <= 5000 is false), so update
254+
# new value = 6000 * 1.15 = 6900
255+
self.assertEqual(rules["detailedroute__route__wirelength"]["value"], 6900)
256+
257+
def test_failing_keeps_old_when_passing(self):
258+
metrics = make_base_metrics()
259+
metrics["detailedroute__route__wirelength"] = 4000
260+
old_rules = {
261+
"detailedroute__route__wirelength": {"value": 5000, "compare": "<="},
262+
}
263+
rules = self._run(metrics, old_rules=old_rules, failing=True)
264+
# metric 4000 passes old rule (4000 <= 5000), keep old
265+
self.assertEqual(rules["detailedroute__route__wirelength"]["value"], 5000)
266+
267+
def test_update_always_changes(self):
268+
metrics = make_base_metrics()
269+
old_rules = {
270+
"detailedroute__route__wirelength": {"value": 9999, "compare": "<="},
271+
}
272+
rules = self._run(metrics, old_rules=old_rules, update=True)
273+
self.assertEqual(rules["detailedroute__route__wirelength"]["value"], 3450)
274+
275+
@patch("sys.stdout", new_callable=StringIO)
276+
def test_string_metric_skipped(self, mock_stdout):
277+
metrics = make_base_metrics()
278+
metrics["detailedplace__design__violations"] = "N/A"
279+
rules = self._run(metrics, update=True)
280+
self.assertIn("[WARNING] Skipping string field", mock_stdout.getvalue())
281+
self.assertNotIn("detailedplace__design__violations", rules)
282+
283+
@patch("sys.stdout", new_callable=StringIO)
284+
def test_missing_clocks_details_warns(self, mock_stdout):
285+
metrics = make_base_metrics()
286+
# metrics.get() returns None when key is absent; the code checks
287+
# truthiness so both None and [] trigger the warning.
288+
metrics["constraints__clocks__details"] = []
289+
rules = self._run(metrics, update=True)
290+
self.assertIn(
291+
"'constraints__clocks__details' not found", mock_stdout.getvalue()
292+
)
293+
294+
@patch("sys.stdout", new_callable=StringIO)
295+
def test_multiple_clocks_warns(self, mock_stdout):
296+
metrics = make_base_metrics()
297+
metrics["constraints__clocks__details"] = [
298+
"clk1: 2.0",
299+
"clk2: 5.0",
300+
]
301+
self._run(metrics, update=True)
302+
self.assertIn("Multiple clocks not supported", mock_stdout.getvalue())
303+
304+
def test_metrics_to_consider_preserves_others(self):
305+
metrics = make_base_metrics()
306+
old_rules = {
307+
"detailedroute__route__wirelength": {"value": 9999, "compare": "<="},
308+
"finish__design__instance__area": {"value": 7777, "compare": "<="},
309+
}
310+
rules = self._run(
311+
metrics,
312+
old_rules=old_rules,
313+
update=True,
314+
metrics_to_consider=["detailedroute__route__wirelength"],
315+
)
316+
# wirelength is in the consider list, so it should be updated
317+
self.assertEqual(rules["detailedroute__route__wirelength"]["value"], 3450)
318+
# finish area is NOT in the consider list, so old value preserved
319+
self.assertEqual(rules["finish__design__instance__area"]["value"], 7777)
320+
321+
def test_clocks_count_direct(self):
322+
metrics = make_base_metrics()
323+
rules = self._run(metrics, update=True)
324+
self.assertEqual(rules["constraints__clocks__count"]["value"], 1)
325+
self.assertEqual(rules["constraints__clocks__count"]["compare"], "==")
326+
327+
def test_compare_operators(self):
328+
metrics = make_base_metrics()
329+
rules = self._run(metrics, update=True)
330+
self.assertEqual(
331+
rules["synth__design__instance__area__stdcell"]["compare"], "<="
332+
)
333+
self.assertEqual(rules["finish__timing__setup__ws"]["compare"], ">=")
334+
335+
def tearDown(self):
336+
self.tmp_dir.cleanup()
337+
338+
339+
if __name__ == "__main__":
340+
unittest.main()

0 commit comments

Comments
 (0)