diff --git a/callbacks.py b/callbacks.py index 182c3c1..311f094 100644 --- a/callbacks.py +++ b/callbacks.py @@ -4,13 +4,11 @@ NoitaMonitorCallback — episode stats, Telegram alerts, W&B logging, Rich table """ -from __future__ import annotations import os import time import csv from collections import deque -from typing import Optional import glob import numpy as np diff --git a/config.py b/config.py index 90e11cf..622c5e1 100644 --- a/config.py +++ b/config.py @@ -1,8 +1,7 @@ -from __future__ import annotations from typing import Optional from pydantic_settings import BaseSettings, SettingsConfigDict -from pydantic import Field, field_validator +from pydantic import Field class Config(BaseSettings): diff --git a/eval.py b/eval.py index 4fca783..0be0eb9 100644 --- a/eval.py +++ b/eval.py @@ -8,7 +8,6 @@ python eval.py model.zip --slow 0.5 # run at 50% speed (sleep 50ms between steps) """ -from __future__ import annotations import argparse import sys diff --git a/notify.py b/notify.py index 0b1050f..47783ab 100644 --- a/notify.py +++ b/notify.py @@ -8,7 +8,6 @@ • Graceful no-op — if token / chat_id are empty, every method silently returns """ -from __future__ import annotations import io import threading @@ -23,6 +22,18 @@ from PIL import Image, ImageDraw, ImageFont + +_cached_font = None +def _get_font(): + global _cached_font + if _cached_font is None: + try: + from PIL import ImageFont + _cached_font = ImageFont.truetype("arial.ttf", 24) + except IOError: + _cached_font = ImageFont.load_default() + return _cached_font + class TelegramNotifier: def __init__(self, token: str, chat_id: str): self._token = token @@ -188,10 +199,7 @@ def make_death_postcard(png_bytes: bytes, stats_text: str) -> bytes: try: img = Image.open(io.BytesIO(png_bytes)).convert("RGB") draw = ImageDraw.Draw(img) - try: - font = ImageFont.truetype("arial.ttf", 24) - except IOError: - font = ImageFont.load_default() + font = _get_font() # Add a semi-transparent black rectangle at the bottom width, height = img.size @@ -337,8 +345,7 @@ def register_stats_provider(self, fn: Callable[[], str]) -> None: """fn() should return a formatted string of current training stats.""" self._stats_fn = fn - @staticmethod - def _find_noita_hwnd() -> Optional[int]: + def _find_noita_hwnd(self) -> "int | None": """ Return HWND of the Noita game window, matched by process executable (noita.exe / noita_dev.exe) so browser tabs with 'noita' in their title @@ -350,6 +357,12 @@ def _find_noita_hwnd() -> Optional[int]: k32 = ctypes.windll.kernel32 user32 = ctypes.windll.user32 + + # Check if cached hwnd is still valid + if hasattr(self, '_cached_hwnd') and self._cached_hwnd is not None: + if user32.IsWindow(self._cached_hwnd): + return self._cached_hwnd + self._cached_hwnd = None PROCESS_QUERY_LIMITED = 0x1000 found: list = [] @@ -381,7 +394,10 @@ def _cb(hwnd, _): return True user32.EnumWindows(_cb, 0) - return found[0] if found else None + if found: + self._cached_hwnd = found[0] + return found[0] + return None except Exception: return None @@ -466,10 +482,7 @@ def capture_noita_screen(self, overlay_text: str = "") -> bytes: if overlay_text: draw = ImageDraw.Draw(img) - try: - font = ImageFont.truetype("arial.ttf", 24) - except IOError: - font = ImageFont.load_default() + font = _get_font() text_bbox = draw.textbbox((10, 10), overlay_text, font=font) draw.rectangle( [text_bbox[0]-5, text_bbox[1]-5, text_bbox[2]+5, text_bbox[3]+5], diff --git a/tests/test_callbacks.py b/tests/test_callbacks.py new file mode 100644 index 0000000..f097592 --- /dev/null +++ b/tests/test_callbacks.py @@ -0,0 +1,20 @@ +import pytest +from unittest.mock import MagicMock, patch +from callbacks import NoitaMonitorCallback + + +def test_callbacks_init(): + mock_env = MagicMock() + mock_env.envs = [MagicMock()] + mock_env.envs[0].port = 1234 + + with patch("callbacks.TelegramNotifier", create=True) as mock_notifier: + with patch("callbacks.VideoRecorder", create=True) as mock_recorder: + mock_cfg = MagicMock() + cb = NoitaMonitorCallback( + cfg=mock_cfg, + notifier=mock_notifier, + verbose=0, + recorder=mock_recorder + ) + assert cb is not None diff --git a/tests/test_eval.py b/tests/test_eval.py new file mode 100644 index 0000000..d416c4e --- /dev/null +++ b/tests/test_eval.py @@ -0,0 +1,50 @@ +import pytest +import os +import io +import torch +import torch.nn as nn +from unittest.mock import patch, MagicMock + +# Assuming eval.py has calculate_thinking which uses torch +try: + from eval import calculate_thinking +except ImportError: + calculate_thinking = None + +@pytest.mark.skipif(calculate_thinking is None, reason="calculate_thinking not found") +def test_calculate_thinking(): + # Create a mock model with features_extractor that returns some tensor + class MockExtractor(nn.Module): + def forward(self, obs): + # simulate features extraction + return torch.ones((1, 64)) + + class MockDistribution: + def __init__(self, obs): + # simulate 4 discrete actions + self.probs = torch.tensor([[0.1, 0.2, 0.6, 0.1]]) + # Make sure logits require grad and are connected to obs + self.logits = obs.mean() * torch.tensor([[1.0, 2.0, 6.0, 1.0]], requires_grad=True) + + class MockPolicy(nn.Module): + def get_distribution(self, obs): + dist = MagicMock() + dist.distribution = MockDistribution(obs) + return dist + + class MockModel: + def __init__(self): + self.policy = MockPolicy() + self.device = torch.device('cpu') + + model = MockModel() + + # Create a dummy observation + obs = torch.zeros((1, 3, 64, 64)).numpy() + + # Run calculation + # calculate_thinking takes (model, obs) and returns something (probably a tensor or float) + saliency = calculate_thinking(model, obs) + + # Just check it returns something without crashing + assert saliency is not None diff --git a/tests/test_noita_env.py b/tests/test_noita_env.py new file mode 100644 index 0000000..d35946e --- /dev/null +++ b/tests/test_noita_env.py @@ -0,0 +1,48 @@ +import pytest +from unittest.mock import MagicMock, patch +import asyncio + +def make_env(): + import importlib, types, sys + + # Mock problematic modules + sys.modules['mss'] = MagicMock() + sys.modules['pygetwindow'] = MagicMock() + sys.modules['pyrect'] = MagicMock() + + import noita_env + importlib.reload(noita_env) + + env = noita_env.NoitaEnv.__new__(noita_env.NoitaEnv) + env._lock = MagicMock() + env._ws = MagicMock() + env._sct = MagicMock() + env._frame_buf = [] + env.port = 1234 + return env + +class TestNoitaEnvMethods: + def test_render_does_not_crash(self): + env = make_env() + # Should do nothing and not crash + env.render() + + def test_close_clears_ws(self): + env = make_env() + env._ws = MagicMock() + assert env._ws is not None + env.close() + assert env._ws is None + env._lock.__enter__.assert_called() + + def test_handle_exception_path(self): + env = make_env() + env._ws = None + # Create a mock socket + mock_ws = MagicMock() + mock_ws.recv.side_effect = Exception("Connection closed") + + # Should gracefully exit the loop when exception is caught + asyncio.run(env._handle(mock_ws)) + # Should set _ws to None on exit + assert env._ws is None diff --git a/tests/test_notify.py b/tests/test_notify.py index 2d7ff60..26711ac 100644 --- a/tests/test_notify.py +++ b/tests/test_notify.py @@ -1,3 +1,4 @@ +from unittest.mock import patch """ Tests for TelegramNotifier. @@ -185,3 +186,22 @@ def test_register_stats_provider(self, noop): noop.register_stats_provider(lambda: "stats text") assert noop._stats_fn is not None assert noop._stats_fn() == "stats text" + +# --------------------------------------------------------------------------- +# AI Status +# --------------------------------------------------------------------------- + +class TestAIStatus: + def test_empty_groq_key(self, noop): + result = noop.generate_ai_status(groq_key="", stats_context="Test stats") + assert result == "⚠️ Groq API key not set. Please set GROQ_API_KEY in .env." + + def test_empty_groq_key_none(self, noop): + result = noop.generate_ai_status(groq_key=None, stats_context="Test stats") + assert result == "⚠️ Groq API key not set. Please set GROQ_API_KEY in .env." + +class TestSetupBotMenu: + def test_setup_bot_menu_error(self, fake): + with patch('requests.post', side_effect=Exception('Mock API Error')): + # Should not crash + fake.setup_bot_menu() diff --git a/tests/test_offline_analysis.py b/tests/test_offline_analysis.py new file mode 100644 index 0000000..807403b --- /dev/null +++ b/tests/test_offline_analysis.py @@ -0,0 +1,24 @@ +import pytest +import os +import json +from unittest.mock import patch, mock_open + +from offline_analysis import analyze_actions + +def test_analyze_actions(): + mock_data = '{"episode": 1, "action": [1, 0, 0, 0], "reward": 10.0, "ep_reward": 10.0, "total_steps": 1}\n' + + with patch('builtins.open', mock_open(read_data=mock_data)): + with patch('offline_analysis.os.path.exists', return_value=True): + with patch('offline_analysis.plt') as mock_plt: + analyze_actions() + + assert mock_plt.figure.call_count > 0 + assert mock_plt.savefig.call_count > 0 + +def test_analyze_actions_no_file(): + with patch('offline_analysis.os.path.exists', return_value=False): + # Should print and return early + with patch('builtins.print') as mock_print: + analyze_actions() + mock_print.assert_called_once() diff --git a/tests/test_train.py b/tests/test_train.py new file mode 100644 index 0000000..e819b42 --- /dev/null +++ b/tests/test_train.py @@ -0,0 +1,53 @@ +import os +import pytest +from unittest.mock import patch, MagicMock + +from config import Config +from train import setup_logging, setup_wandb + +@pytest.fixture +def mock_config(): + cfg = MagicMock(spec=Config) + cfg.checkpoint_dir = "/tmp/checkpoints" + cfg.tensorboard_dir = "/tmp/tb" + cfg.log_level = "INFO" + cfg.log_dir = "/tmp/logs" + cfg.wandb_project = "noita-rl" + cfg.wandb_entity = "jules" + cfg.wandb_enabled = True + cfg.n_envs = 1 + cfg.total_timesteps = 1000 + cfg.learning_rate = 0.001 + cfg.n_steps = 2048 + cfg.batch_size = 64 + cfg.n_epochs = 10 + cfg.gamma = 0.99 + cfg.ent_coef = 0.01 + return cfg + +def test_setup_logging(tmp_path, mock_config): + mock_config.checkpoint_dir = str(tmp_path / "checkpoints") + mock_config.tensorboard_dir = str(tmp_path / "tb") + + with patch("train.logger") as mock_logger, patch("sys.stderr"): + setup_logging(mock_config, "test_run") + + # Verify logger.add was called + assert mock_logger.add.call_count > 0 + +def test_setup_wandb_enabled(mock_config): + with patch("train.wandb.init", create=True) as mock_init: + setup_wandb(mock_config, "test_run") + mock_init.assert_called_once() + +def test_setup_wandb_disabled(mock_config): + mock_config.wandb_enabled = False + with patch("train.wandb.init", create=True) as mock_init: + setup_wandb(mock_config, "test_run") + mock_init.assert_not_called() + +def test_setup_wandb_error(mock_config): + with patch("train.wandb.init", side_effect=Exception("WandB Error"), create=True) as mock_init, \ + patch("train.logger.warning") as mock_warn: + setup_wandb(mock_config, "test_run") + mock_warn.assert_called_once() diff --git a/tests/test_train_multi.py b/tests/test_train_multi.py new file mode 100644 index 0000000..d0ec9bb --- /dev/null +++ b/tests/test_train_multi.py @@ -0,0 +1,27 @@ +import pytest +from unittest.mock import patch, MagicMock +from train_multi import find_latest_checkpoint + +def test_find_latest_checkpoint_empty_dir(tmp_path): + assert find_latest_checkpoint(str(tmp_path)) is None + +def test_find_latest_checkpoint_no_match(tmp_path): + (tmp_path / "other_run_100_steps.txt").touch() + assert find_latest_checkpoint(str(tmp_path)) is None + +def test_find_latest_checkpoint_single_match(tmp_path): + f = tmp_path / "test_run_100_steps.zip" + f.touch() + assert find_latest_checkpoint(str(tmp_path)) == str(f) + +def test_find_latest_checkpoint_multiple_matches(tmp_path): + f1 = tmp_path / "test_run_100_steps.zip" + f2 = tmp_path / "test_run_200_steps.zip" + f3 = tmp_path / "test_run_50_steps.zip" + import time + f1.touch() + time.sleep(0.01) + f3.touch() + time.sleep(0.01) + f2.touch() + assert find_latest_checkpoint(str(tmp_path)) == str(f2) diff --git a/tests/test_video_recorder.py b/tests/test_video_recorder.py index 0825c62..48d7aaf 100644 --- a/tests/test_video_recorder.py +++ b/tests/test_video_recorder.py @@ -109,6 +109,12 @@ def test_fallback_returns_nonempty_string(self, recorder, event_name): assert isinstance(result, str) assert len(result) > 5 + + @pytest.mark.parametrize("event_name", ALL_EVENTS) + def test_trigger_does_not_crash_on_empty_ctx(self, recorder, event_name): + recorder.trigger_event(event_name, {}) + assert not recorder._event_q.empty() + @pytest.mark.parametrize("event_name", ALL_EVENTS) def test_fallback_does_not_crash_on_empty_ctx(self, recorder, event_name): result = recorder._groq_describe(event_name, {}) @@ -237,3 +243,35 @@ def test_active_event_none_initially(self, recorder): def test_pre_buffer_starts_empty(self, recorder): assert len(recorder._pre_buf) == 0 + +# --------------------------------------------------------------------------- +# State Tests +# --------------------------------------------------------------------------- + +class TestStateProperties: + def test_is_idle(self, recorder): + assert recorder.is_idle is True + recorder._state = "recording" + assert recorder.is_idle is False + recorder._state = "cooldown" + assert recorder.is_idle is False + + def test_status(self, recorder): + assert recorder.status == "idle" + recorder._state = "recording" + assert recorder.status == "recording" + recorder._state = "cooldown" + assert recorder.status == "cooldown" + + def test_force_trigger(self, recorder): + recorder._state = "cooldown" # Set to cooldown + assert recorder.is_idle is False + + # force_trigger should bypass cooldown + recorder.force_trigger("force_event", {"info": "data"}) + + # It should place the event on the queue regardless of cooldown state + assert not recorder._event_q.empty() + item = recorder._event_q.get_nowait() + assert item["name"] == "force_event" + assert item["ctx"] == {"info": "data"} diff --git a/train.py b/train.py index 96fa4cd..baaaf35 100644 --- a/train.py +++ b/train.py @@ -8,11 +8,15 @@ python train.py --name "experiment-01" """ -from __future__ import annotations + import argparse import os import sys +import wandb + + + # Workaround for OpenMP duplicate library error os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE" diff --git a/video_recorder.py b/video_recorder.py index da6d3ce..e3c493e 100644 --- a/video_recorder.py +++ b/video_recorder.py @@ -1,3 +1,4 @@ +from collections import deque """ VideoRecorder — event-triggered highlight recorder for NoitaRL. @@ -29,7 +30,7 @@ wand_kill — wand shot killed an enemy """ -from __future__ import annotations + import collections import io @@ -564,8 +565,8 @@ def _read_recent_logs(n_lines: int = 30) -> str: log_path = os.path.join(os.path.dirname(__file__), "logger.txt") try: with open(log_path, "r", encoding="utf-8", errors="replace") as f: - lines = f.readlines() - return "".join(lines[-n_lines:]).strip() + lines = deque(f, maxlen=n_lines) + return "".join(lines).strip() except Exception: return "" @@ -575,8 +576,8 @@ def _read_recent_actions(n: int = 20) -> str: trace_path = os.path.join(os.path.dirname(__file__), "actions_trace.jsonl") try: with open(trace_path, "r", encoding="utf-8", errors="replace") as f: - lines = f.readlines() - return "".join(lines[-n:]).strip() + lines = deque(f, maxlen=n) + return "".join(lines).strip() except Exception: return ""