diff --git a/movement/roi/base.py b/movement/roi/base.py index e6d273b5c..d7cc24e1e 100644 --- a/movement/roi/base.py +++ b/movement/roi/base.py @@ -4,13 +4,14 @@ from abc import ABC, abstractmethod from collections.abc import Callable, Hashable, Sequence -from typing import TYPE_CHECKING, Any, Generic, TypeAlias, TypeVar, cast +from typing import TYPE_CHECKING, Any, Generic, Self, TypeAlias, TypeVar, cast import matplotlib.pyplot as plt import numpy as np import shapely from shapely.coords import CoordinateSequence +from movement.transforms import compute_homography_transform from movement.utils.broadcasting import broadcastable_method from movement.utils.vector import compute_signed_angle_2d @@ -560,3 +561,35 @@ def plot( if fig is None or ax is None: fig, ax = plt.subplots(1, 1) return self._plot(fig, ax, **matplotlib_kwargs) + + def get_transform(self, other: Self) -> np.ndarray: + """Compute the homography transformation matrix to align with `other`. + + Parameters + ---------- + other : BaseRegionOfInterest + Another region of interest to which this region will be aligned. + It should be of the same type (line or polygon) and have + the same number of defining points. + + Returns + ------- + np.ndarray + A (3, 3) transformation matrix that aligns this region to + the `other` region. + + Raises + ------ + ValueError + If the number of coordinate points does not match + the number of coordinate points in the other region, + or if there are insufficient points to + compute the transformation, + or if the points are not 2-dimensional, + or if the points are degenerate or collinear, + making it impossible to compute a valid homography. + + """ + return compute_homography_transform( + np.array(self.coords), np.array(other.coords) + ) diff --git a/movement/roi/line.py b/movement/roi/line.py index 58e78b0b6..c20828235 100644 --- a/movement/roi/line.py +++ b/movement/roi/line.py @@ -1,5 +1,7 @@ """1-dimensional lines of interest.""" +from typing import Self + import numpy as np import shapely import xarray as xr @@ -179,3 +181,7 @@ def compute_angle_to_normal( ), in_degrees=in_degrees, ) + + def get_transform(self, other: Self) -> np.ndarray: + """Throw error for transformation matrix for lines.""" + raise NotImplementedError("Homography is undefined for LineOfInterest") diff --git a/tests/test_unit/test_roi/test_transform.py b/tests/test_unit/test_roi/test_transform.py new file mode 100644 index 000000000..00873223f --- /dev/null +++ b/tests/test_unit/test_roi/test_transform.py @@ -0,0 +1,93 @@ +import numpy as np +import pytest + +from movement.roi.line import LineOfInterest +from movement.roi.polygon import PolygonOfInterest +from movement.transforms import compute_homography_transform + + +@pytest.mark.parametrize( + ["coords_a", "coords_b"], + [ + pytest.param( + np.array([[1, 1], [5, 1], [5, 3], [1, 3]], dtype=np.float32), + np.array([[1, 1], [5, 1], [5, 3], [1, 3]], dtype=np.float32), + id="Identical rectangles (identity transform)", + ), + pytest.param( + np.array([[0, 0], [1, 0], [1, 1], [0, 1]], dtype=np.float32), + np.array( + [ + [3, -1], + [4.73205081, 0], + [3.98205081, 1.29903811], + [2.25, 0.2990381], + ], + dtype=np.float32, + ), + id="Rotated and scaled square", + ), + ], +) +def test_get_transform_happy_path( + coords_a: np.ndarray, + coords_b: np.ndarray, +) -> None: + roi_a = PolygonOfInterest(coords_a) + roi_b = PolygonOfInterest(coords_b) + + expected_transform = compute_homography_transform(coords_a, coords_b) + + computed_transform = roi_a.get_transform(roi_b) + + assert computed_transform.shape == (3, 3) + assert np.allclose(computed_transform, expected_transform, atol=1e-6) + + +@pytest.mark.parametrize( + ["coords_a", "coords_b"], + [ + pytest.param( + np.array([[0, 0], [1, 0], [1, 1]], dtype=np.float32), + np.array( + [[0, 0], [1, 0], [2, 0.5], [1, 1], [0, 1]], dtype=np.float32 + ), + id="Triangle vs Pentagon", + ), + pytest.param( + np.array([[0, 0], [1, 0], [1, 1], [0, 1]], dtype=np.float32), + np.array([[0, 0], [1, 0], [1, 1]], dtype=np.float32), + id="Quad vs triangle", + ), + ], +) +def test_get_transform_mismatched_points_raises( + coords_a: np.ndarray, + coords_b: np.ndarray, +) -> None: + roi_a = PolygonOfInterest(coords_a) + roi_b = PolygonOfInterest(coords_b) + + with pytest.raises(ValueError): + roi_a.get_transform(roi_b) + + +@pytest.mark.parametrize( + ["coords_a", "coords_b"], + [ + pytest.param( + np.array([[0, 0], [1, 1]], dtype=np.float32), + np.array([[0, 0], [1, 2]], dtype=np.float32), + id="2 lines", + ) + ], +) +def test_line_of_interest_raises( + coords_a: np.ndarray, + coords_b: np.ndarray, +): + line_1 = LineOfInterest(points=coords_a) + line_2 = LineOfInterest(points=coords_b) + + with pytest.raises(NotImplementedError): + line_1.get_transform(line_2)