Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,11 @@
NoitaMonitorCallback — episode stats, Telegram alerts, W&B logging, Rich table
"""

from __future__ import annotations

import os
import time
import csv
from collections import deque
from typing import Optional

import glob
import numpy as np
Expand Down
3 changes: 1 addition & 2 deletions config.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from __future__ import annotations

from typing import Optional
from pydantic_settings import BaseSettings, SettingsConfigDict
from pydantic import Field, field_validator
from pydantic import Field


class Config(BaseSettings):
Expand Down
1 change: 0 additions & 1 deletion eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
python eval.py model.zip --slow 0.5 # run at 50% speed (sleep 50ms between steps)
"""

from __future__ import annotations

import argparse
import sys
Expand Down
37 changes: 25 additions & 12 deletions notify.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
• Graceful no-op — if token / chat_id are empty, every method silently returns
"""

from __future__ import annotations

import io
import threading
Expand All @@ -23,6 +22,18 @@
from PIL import Image, ImageDraw, ImageFont



_cached_font = None
def _get_font():
global _cached_font
if _cached_font is None:
try:
from PIL import ImageFont
_cached_font = ImageFont.truetype("arial.ttf", 24)
except IOError:
_cached_font = ImageFont.load_default()
return _cached_font
Comment on lines +26 to +35

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The font loading logic is not thread-safe and contains a redundant import. Since TelegramNotifier uses a background polling thread, concurrent calls to _get_font could lead to race conditions or multiple font loading attempts. Additionally, ImageFont is already imported at the top of the file (line 22), making the local import unnecessary.

Suggested change
_cached_font = None
def _get_font():
global _cached_font
if _cached_font is None:
try:
from PIL import ImageFont
_cached_font = ImageFont.truetype("arial.ttf", 24)
except IOError:
_cached_font = ImageFont.load_default()
return _cached_font
_font_lock = threading.Lock()
_cached_font = None
def _get_font():
global _cached_font
if _cached_font is None:
with _font_lock:
if _cached_font is None:
try:
_cached_font = ImageFont.truetype("arial.ttf", 24)
except IOError:
_cached_font = ImageFont.load_default()
return _cached_font


class TelegramNotifier:
def __init__(self, token: str, chat_id: str):
self._token = token
Expand Down Expand Up @@ -188,10 +199,7 @@ def make_death_postcard(png_bytes: bytes, stats_text: str) -> bytes:
try:
img = Image.open(io.BytesIO(png_bytes)).convert("RGB")
draw = ImageDraw.Draw(img)
try:
font = ImageFont.truetype("arial.ttf", 24)
except IOError:
font = ImageFont.load_default()
font = _get_font()

# Add a semi-transparent black rectangle at the bottom
width, height = img.size
Expand Down Expand Up @@ -337,8 +345,7 @@ def register_stats_provider(self, fn: Callable[[], str]) -> None:
"""fn() should return a formatted string of current training stats."""
self._stats_fn = fn

@staticmethod
def _find_noita_hwnd() -> Optional[int]:
def _find_noita_hwnd(self) -> "int | None":
"""
Return HWND of the Noita game window, matched by process executable
(noita.exe / noita_dev.exe) so browser tabs with 'noita' in their title
Expand All @@ -350,6 +357,12 @@ def _find_noita_hwnd() -> Optional[int]:

k32 = ctypes.windll.kernel32
user32 = ctypes.windll.user32

# Check if cached hwnd is still valid
if hasattr(self, '_cached_hwnd') and self._cached_hwnd is not None:
if user32.IsWindow(self._cached_hwnd):
return self._cached_hwnd
self._cached_hwnd = None
Comment on lines +362 to +365

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Using hasattr repeatedly is less efficient and less idiomatic than using getattr with a default value, especially since _cached_hwnd is not initialized in the constructor. This approach is cleaner and avoids potential AttributeError issues if the attribute is accessed elsewhere.

Suggested change
if hasattr(self, '_cached_hwnd') and self._cached_hwnd is not None:
if user32.IsWindow(self._cached_hwnd):
return self._cached_hwnd
self._cached_hwnd = None
cached_hwnd = getattr(self, '_cached_hwnd', None)
if cached_hwnd is not None:
if user32.IsWindow(cached_hwnd):
return cached_hwnd
self._cached_hwnd = None

PROCESS_QUERY_LIMITED = 0x1000

found: list = []
Expand Down Expand Up @@ -381,7 +394,10 @@ def _cb(hwnd, _):
return True

user32.EnumWindows(_cb, 0)
return found[0] if found else None
if found:
self._cached_hwnd = found[0]
return found[0]
return None
except Exception:
return None

Expand Down Expand Up @@ -466,10 +482,7 @@ def capture_noita_screen(self, overlay_text: str = "") -> bytes:

if overlay_text:
draw = ImageDraw.Draw(img)
try:
font = ImageFont.truetype("arial.ttf", 24)
except IOError:
font = ImageFont.load_default()
font = _get_font()
text_bbox = draw.textbbox((10, 10), overlay_text, font=font)
draw.rectangle(
[text_bbox[0]-5, text_bbox[1]-5, text_bbox[2]+5, text_bbox[3]+5],
Expand Down
20 changes: 20 additions & 0 deletions tests/test_callbacks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import pytest
from unittest.mock import MagicMock, patch
from callbacks import NoitaMonitorCallback


def test_callbacks_init():
mock_env = MagicMock()
mock_env.envs = [MagicMock()]
mock_env.envs[0].port = 1234

with patch("callbacks.TelegramNotifier", create=True) as mock_notifier:
with patch("callbacks.VideoRecorder", create=True) as mock_recorder:
mock_cfg = MagicMock()
cb = NoitaMonitorCallback(
cfg=mock_cfg,
notifier=mock_notifier,
verbose=0,
recorder=mock_recorder
)
assert cb is not None
50 changes: 50 additions & 0 deletions tests/test_eval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import pytest
import os
import io
import torch
import torch.nn as nn
from unittest.mock import patch, MagicMock

# Assuming eval.py has calculate_thinking which uses torch
try:
from eval import calculate_thinking
except ImportError:
calculate_thinking = None

@pytest.mark.skipif(calculate_thinking is None, reason="calculate_thinking not found")
def test_calculate_thinking():
Comment on lines +9 to +15

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Don’t skip this test when calculate_thinking is missing.

Current skip logic turns a broken/missing target function into a green test run, which weakens coverage for this path.

Suggested fix
-try:
-    from eval import calculate_thinking
-except ImportError:
-    calculate_thinking = None
-
-@pytest.mark.skipif(calculate_thinking is None, reason="calculate_thinking not found")
+from eval import calculate_thinking
 def test_calculate_thinking():
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
try:
from eval import calculate_thinking
except ImportError:
calculate_thinking = None
@pytest.mark.skipif(calculate_thinking is None, reason="calculate_thinking not found")
def test_calculate_thinking():
from eval import calculate_thinking
def test_calculate_thinking():
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/test_eval.py` around lines 9 - 15, The test currently hides a missing
implementation by skipping when calculate_thinking cannot be imported; instead
make the test fail loudly: remove the try/except import wrapper (or keep it but
add an explicit assertion) so that test_calculate_thinking either raises
ImportError or asserts calculate_thinking is not None with a clear failure
message. Update references to the calculate_thinking symbol in
test_calculate_thinking so the test will fail if the function is absent rather
than being skipped.

# Create a mock model with features_extractor that returns some tensor
class MockExtractor(nn.Module):
def forward(self, obs):
# simulate features extraction
return torch.ones((1, 64))

class MockDistribution:
def __init__(self, obs):
# simulate 4 discrete actions
self.probs = torch.tensor([[0.1, 0.2, 0.6, 0.1]])
# Make sure logits require grad and are connected to obs
self.logits = obs.mean() * torch.tensor([[1.0, 2.0, 6.0, 1.0]], requires_grad=True)

class MockPolicy(nn.Module):
def get_distribution(self, obs):
dist = MagicMock()
dist.distribution = MockDistribution(obs)
return dist

class MockModel:
def __init__(self):
self.policy = MockPolicy()
self.device = torch.device('cpu')

model = MockModel()

# Create a dummy observation
obs = torch.zeros((1, 3, 64, 64)).numpy()

# Run calculation
# calculate_thinking takes (model, obs) and returns something (probably a tensor or float)
saliency = calculate_thinking(model, obs)

# Just check it returns something without crashing
assert saliency is not None
48 changes: 48 additions & 0 deletions tests/test_noita_env.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import pytest
from unittest.mock import MagicMock, patch
import asyncio

def make_env():
import importlib, types, sys

# Mock problematic modules
sys.modules['mss'] = MagicMock()
sys.modules['pygetwindow'] = MagicMock()
sys.modules['pyrect'] = MagicMock()

import noita_env
importlib.reload(noita_env)

Comment on lines +5 to +15

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Avoid leaking mocked modules into global interpreter state.

make_env() writes directly into sys.modules and never restores it. That can contaminate later tests and create order-dependent failures. Wrap this in patch.dict(...): so mocks are scoped.

Suggested fix
 def make_env():
     import importlib, types, sys
 
-    # Mock problematic modules
-    sys.modules['mss'] = MagicMock()
-    sys.modules['pygetwindow'] = MagicMock()
-    sys.modules['pyrect'] = MagicMock()
-
-    import noita_env
-    importlib.reload(noita_env)
+    mocked_modules = {
+        'mss': MagicMock(),
+        'pygetwindow': MagicMock(),
+        'pyrect': MagicMock(),
+    }
+    with patch.dict(sys.modules, mocked_modules):
+        import noita_env
+        importlib.reload(noita_env)
 
     env = noita_env.NoitaEnv.__new__(noita_env.NoitaEnv)
     env._lock = MagicMock()
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/test_noita_env.py` around lines 5 - 15, The test's make_env() currently
writes mocks directly into sys.modules and never restores them; change it to use
unittest.mock.patch.dict to temporarily inject the mocked modules (keys 'mss',
'pygetwindow', 'pyrect') into sys.modules so they are automatically restored,
perform importlib.reload(noita_env) inside that patch.dict context, and ensure
imports for patch (e.g., from unittest.mock import patch) are present so the
mocking is scoped and doesn't leak global interpreter state.

env = noita_env.NoitaEnv.__new__(noita_env.NoitaEnv)
env._lock = MagicMock()
env._ws = MagicMock()
env._sct = MagicMock()
env._frame_buf = []
env.port = 1234
return env

class TestNoitaEnvMethods:
def test_render_does_not_crash(self):
env = make_env()
# Should do nothing and not crash
env.render()

def test_close_clears_ws(self):
env = make_env()
env._ws = MagicMock()
assert env._ws is not None
env.close()
assert env._ws is None
env._lock.__enter__.assert_called()

def test_handle_exception_path(self):
env = make_env()
env._ws = None
# Create a mock socket
mock_ws = MagicMock()
mock_ws.recv.side_effect = Exception("Connection closed")
Comment on lines +42 to +43

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The test uses MagicMock for mock_ws, but env._handle performs an async for iteration over it. MagicMock does not support async iteration by default, which will cause the test to fail with a TypeError. You should use AsyncMock instead to correctly simulate the websocket's behavior.

Suggested change
mock_ws = MagicMock()
mock_ws.recv.side_effect = Exception("Connection closed")
from unittest.mock import AsyncMock
mock_ws = AsyncMock()
mock_ws.__aiter__.return_value = [Exception("Connection closed")] # Simulate failure during iteration


# Should gracefully exit the loop when exception is caught
asyncio.run(env._handle(mock_ws))
# Should set _ws to None on exit
assert env._ws is None
20 changes: 20 additions & 0 deletions tests/test_notify.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from unittest.mock import patch
"""
Tests for TelegramNotifier.

Expand Down Expand Up @@ -185,3 +186,22 @@ def test_register_stats_provider(self, noop):
noop.register_stats_provider(lambda: "stats text")
assert noop._stats_fn is not None
assert noop._stats_fn() == "stats text"

# ---------------------------------------------------------------------------
# AI Status
# ---------------------------------------------------------------------------

class TestAIStatus:
def test_empty_groq_key(self, noop):
result = noop.generate_ai_status(groq_key="", stats_context="Test stats")
assert result == "⚠️ Groq API key not set. Please set GROQ_API_KEY in .env."

def test_empty_groq_key_none(self, noop):
result = noop.generate_ai_status(groq_key=None, stats_context="Test stats")
assert result == "⚠️ Groq API key not set. Please set GROQ_API_KEY in .env."

class TestSetupBotMenu:
def test_setup_bot_menu_error(self, fake):
with patch('requests.post', side_effect=Exception('Mock API Error')):
# Should not crash
fake.setup_bot_menu()
24 changes: 24 additions & 0 deletions tests/test_offline_analysis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import pytest
import os
import json
from unittest.mock import patch, mock_open

from offline_analysis import analyze_actions

def test_analyze_actions():
mock_data = '{"episode": 1, "action": [1, 0, 0, 0], "reward": 10.0, "ep_reward": 10.0, "total_steps": 1}\n'

with patch('builtins.open', mock_open(read_data=mock_data)):
with patch('offline_analysis.os.path.exists', return_value=True):
with patch('offline_analysis.plt') as mock_plt:
analyze_actions()

assert mock_plt.figure.call_count > 0
assert mock_plt.savefig.call_count > 0

def test_analyze_actions_no_file():
with patch('offline_analysis.os.path.exists', return_value=False):
# Should print and return early
with patch('builtins.print') as mock_print:
analyze_actions()
mock_print.assert_called_once()
53 changes: 53 additions & 0 deletions tests/test_train.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import os
import pytest
from unittest.mock import patch, MagicMock

from config import Config
from train import setup_logging, setup_wandb

@pytest.fixture
def mock_config():
cfg = MagicMock(spec=Config)
cfg.checkpoint_dir = "/tmp/checkpoints"
cfg.tensorboard_dir = "/tmp/tb"
cfg.log_level = "INFO"
cfg.log_dir = "/tmp/logs"
cfg.wandb_project = "noita-rl"
cfg.wandb_entity = "jules"
cfg.wandb_enabled = True
cfg.n_envs = 1
cfg.total_timesteps = 1000
cfg.learning_rate = 0.001
cfg.n_steps = 2048
cfg.batch_size = 64
cfg.n_epochs = 10
cfg.gamma = 0.99
cfg.ent_coef = 0.01
return cfg

def test_setup_logging(tmp_path, mock_config):
mock_config.checkpoint_dir = str(tmp_path / "checkpoints")
mock_config.tensorboard_dir = str(tmp_path / "tb")

with patch("train.logger") as mock_logger, patch("sys.stderr"):
setup_logging(mock_config, "test_run")

# Verify logger.add was called
assert mock_logger.add.call_count > 0

def test_setup_wandb_enabled(mock_config):
with patch("train.wandb.init", create=True) as mock_init:
setup_wandb(mock_config, "test_run")
mock_init.assert_called_once()

def test_setup_wandb_disabled(mock_config):
mock_config.wandb_enabled = False
with patch("train.wandb.init", create=True) as mock_init:
setup_wandb(mock_config, "test_run")
mock_init.assert_not_called()

def test_setup_wandb_error(mock_config):
with patch("train.wandb.init", side_effect=Exception("WandB Error"), create=True) as mock_init, \
patch("train.logger.warning") as mock_warn:
setup_wandb(mock_config, "test_run")
mock_warn.assert_called_once()
27 changes: 27 additions & 0 deletions tests/test_train_multi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import pytest
from unittest.mock import patch, MagicMock
from train_multi import find_latest_checkpoint

def test_find_latest_checkpoint_empty_dir(tmp_path):
assert find_latest_checkpoint(str(tmp_path)) is None

def test_find_latest_checkpoint_no_match(tmp_path):
(tmp_path / "other_run_100_steps.txt").touch()
assert find_latest_checkpoint(str(tmp_path)) is None

def test_find_latest_checkpoint_single_match(tmp_path):
f = tmp_path / "test_run_100_steps.zip"
f.touch()
assert find_latest_checkpoint(str(tmp_path)) == str(f)

def test_find_latest_checkpoint_multiple_matches(tmp_path):
f1 = tmp_path / "test_run_100_steps.zip"
f2 = tmp_path / "test_run_200_steps.zip"
f3 = tmp_path / "test_run_50_steps.zip"
import time
f1.touch()
time.sleep(0.01)
f3.touch()
time.sleep(0.01)
f2.touch()
assert find_latest_checkpoint(str(tmp_path)) == str(f2)
Loading
Loading