"""Tests for the Larch group converters."""
import numpy as np
import pytest
from ewoksxas.converters.larch import create_groups, get_attribute_values
def test_1d_x_and_1d_y():
    """Check that a 1D x paired with a 1D y of equal length yields one group."""
    k_axis = np.array([1.0, 2.0, 3.0])
    chi_values = np.array([0.1, 0.2, 0.3])

    result = create_groups(k_axis, chi_values, mode="k")

    assert len(result) == 1
    # The single group must carry both inputs unchanged.
    only_group = result[0]
    assert np.array_equal(only_group.k, k_axis)
    assert np.array_equal(only_group.chi, chi_values)
def test_1d_x_and_2d_y_shared_axis():
    """Check that each row of a 2D y gets its own group with the same 1D x."""
    shared_k = np.array([1.0, 2.0, 3.0])
    chi_rows = np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]])

    result = create_groups(shared_k, chi_rows, mode="k")

    assert len(result) == 2
    # One group per row of y; all of them reuse the single x axis.
    for index, expected_chi in enumerate(chi_rows):
        current = result[index]
        assert np.array_equal(current.k, shared_k)
        assert np.array_equal(current.chi, expected_chi)
def test_2d_x_and_2d_y_paired():
    """Check that 2D x and 2D y of the same shape are matched up row by row."""
    k_rows = np.array([[1.0, 2.0, 3.0], [1.5, 2.5, 3.5]])
    chi_rows = np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]])

    result = create_groups(k_rows, chi_rows, mode="k")

    assert len(result) == 2
    # Row i of x and row i of y must end up together in group i.
    expected_rows = zip(k_rows, chi_rows, strict=True)
    for current, expected in zip(result, expected_rows, strict=True):
        expected_k, expected_chi = expected
        assert np.array_equal(current.k, expected_k)
        assert np.array_equal(current.chi, expected_chi)
def test_mode_energy_attribute_names():
    """Check that mode 'energy' stores x as ``energy`` and y as ``mu``."""
    energy_values = np.array([7000.0, 7001.0, 7002.0])
    mu_values = np.array([0.1, 0.2, 0.3])

    first = create_groups(energy_values, mu_values, mode="energy")[0]

    expected = {"energy": energy_values, "mu": mu_values}
    # First make sure both attributes exist, then compare their contents.
    for name in expected:
        assert hasattr(first, name)
    for name, values in expected.items():
        assert np.array_equal(getattr(first, name), values)
def test_mode_r_attribute_names():
    """Check that mode 'r' stores x as ``r`` and y as ``chir``."""
    r_values = np.array([0.0, 0.1, 0.2])
    chir_values = np.array([1.0, 2.0, 3.0])

    first = create_groups(r_values, chir_values, mode="r")[0]

    expected = {"r": r_values, "chir": chir_values}
    # First make sure both attributes exist, then compare their contents.
    for name in expected:
        assert hasattr(first, name)
    for name, values in expected.items():
        assert np.array_equal(getattr(first, name), values)
def test_non_array_x_raises():
    """Check that passing a plain list as x is rejected with a TypeError."""
    list_x = [1.0, 2.0, 3.0]
    array_y = np.array([0.1, 0.2, 0.3])
    expected_error = pytest.raises(TypeError, match="Expected NumPy array for x")

    with expected_error:
        create_groups(list_x, array_y, mode="k")  # type: ignore[arg-type]
def test_non_array_y_raises():
    """Check that passing a plain list as y is rejected with a TypeError."""
    array_x = np.array([1.0, 2.0, 3.0])
    list_y = [0.1, 0.2, 0.3]
    expected_error = pytest.raises(TypeError, match="Expected NumPy array for y")

    with expected_error:
        create_groups(array_x, list_y, mode="k")  # type: ignore[arg-type]
def test_invalid_mode_raises():
    """Check that an unrecognised mode string is rejected with a ValueError."""
    k_axis = np.array([1.0, 2.0, 3.0])
    chi_values = np.array([0.1, 0.2, 0.3])
    expected_error = pytest.raises(ValueError, match="The mode should be one of")

    with expected_error:
        create_groups(k_axis, chi_values, mode="invalid")
def test_shape_mismatch_1d_1d_raises():
    """Check that 1D x and 1D y of different lengths raise a ValueError."""
    three_points = np.array([1.0, 2.0, 3.0])
    two_points = np.array([0.1, 0.2])
    expected_error = pytest.raises(ValueError, match="incompatible")

    with expected_error:
        create_groups(three_points, two_points, mode="k")
def test_shape_mismatch_1d_2d_raises():
    """Check that a 1D x whose length differs from y's row length raises."""
    # x has three points while each row of y only has two.
    three_points = np.array([1.0, 2.0, 3.0])
    two_column_rows = np.array([[0.1, 0.2], [0.3, 0.4]])
    expected_error = pytest.raises(ValueError, match="incompatible")

    with expected_error:
        create_groups(three_points, two_column_rows, mode="k")
def test_shape_mismatch_2d_2d_raises():
    """Check that 2D x and 2D y with different row counts raise a ValueError."""
    # x holds a single row while y holds two.
    one_row = np.array([[1.0, 2.0, 3.0]])
    two_rows = np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]])
    expected_error = pytest.raises(ValueError, match="incompatible")

    with expected_error:
        create_groups(one_row, two_rows, mode="k")
def test_kwargs_forwarded_to_groups():
    """Check that an extra keyword argument ends up on each created group."""
    shared_k = np.array([1.0, 2.0, 3.0])
    chi_rows = np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]])

    result = create_groups(shared_k, chi_rows, mode="k", label="sample")

    assert len(result) == 2
    # Every group, not just the first, must expose the forwarded value.
    assert all(member.label == "sample" for member in result)
def test_energy_kev_converted_to_ev():
    """Check that an ascending keV energy axis is stored multiplied by 1000."""
    kev_axis = np.linspace(7.0, 7.1, 101)
    mu_values = np.linspace(0.0, 1.0, 101)

    first = create_groups(kev_axis, mu_values, mode="energy")[0]

    # The expected axis is derived from the input only after the call.
    expected_ev = kev_axis * 1000
    assert np.allclose(first.energy, expected_ev)
def test_energy_ev_not_converted():
    """Check that an energy axis given in eV is stored exactly as passed in."""
    ev_axis = np.linspace(7000.0, 7100.0, 101)
    mu_values = np.linspace(0.0, 1.0, 101)

    first = create_groups(ev_axis, mu_values, mode="energy")[0]

    # Exact equality: no rescaling may have been applied.
    assert np.array_equal(first.energy, ev_axis)
def test_get_attribute_values_single_attribute():
    """Check that requesting one attribute returns only that key, stacked."""
    shared_k = np.array([1.0, 2.0, 3.0])
    chi_rows = np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]])
    group_list = create_groups(shared_k, chi_rows, mode="k")

    extracted = get_attribute_values(group_list, ["chi"])

    returned_keys = [*extracted.keys()]
    assert returned_keys == ["chi"]
    # The per-group chi arrays must reassemble into the original 2D input.
    assert np.array_equal(extracted["chi"], chi_rows)
def test_get_attribute_values_multiple_attributes():
    """Check that several attributes can be collected in a single call."""
    shared_k = np.array([1.0, 2.0, 3.0])
    chi_rows = np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]])
    group_list = create_groups(shared_k, chi_rows, mode="k")

    extracted = get_attribute_values(group_list, ["k", "chi"])

    # Both groups share the same x axis, so "k" is that axis repeated twice.
    expected_k = np.vstack((shared_k, shared_k))
    assert np.array_equal(extracted["k"], expected_k)
    assert np.array_equal(extracted["chi"], chi_rows)
def test_get_attribute_values_string_attribute():
    """Check that a single attribute name may be passed as a plain string."""
    shared_k = np.array([1.0, 2.0, 3.0])
    chi_rows = np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]])
    group_list = create_groups(shared_k, chi_rows, mode="k")

    extracted = get_attribute_values(group_list, "chi")  # type: ignore[arg-type]

    # The string is expected to be treated as one name, giving one key.
    returned_keys = [*extracted.keys()]
    assert returned_keys == ["chi"]
    assert np.array_equal(extracted["chi"], chi_rows)
def test_get_attribute_values_single_group():
    """Check that one group still produces 2D results with a leading axis of 1."""
    k_axis = np.array([1.0, 2.0, 3.0])
    chi_values = np.array([0.1, 0.2, 0.3])
    group_list = create_groups(k_axis, chi_values, mode="k")

    extracted = get_attribute_values(group_list, ["k", "chi"])

    # Shapes first: a single group of three points gives a (1, 3) array.
    k_matrix = extracted["k"]
    assert k_matrix.shape == (1, 3)
    chi_matrix = extracted["chi"]
    assert chi_matrix.shape == (1, 3)
    # Then contents: the only row must equal the original inputs.
    assert np.array_equal(k_matrix[0], k_axis)
    assert np.array_equal(chi_matrix[0], chi_values)
def test_energy_kev_descending_converted_to_ev():
    """Check that a descending keV energy axis is stored multiplied by 1000."""
    descending_kev = np.linspace(7.1, 7.0, 101)
    mu_values = np.linspace(0.0, 1.0, 101)

    first = create_groups(descending_kev, mu_values, mode="energy")[0]

    # The expected axis is derived from the input only after the call.
    expected_ev = descending_kev * 1000
    assert np.allclose(first.energy, expected_ev)
def test_energy_length_1_no_warning():
    """Test that an energy array of length 1 raises neither a warning nor an error.

    The call is made with every warning escalated to an exception, so a
    warning emitted while the groups are created fails the test instead of
    passing unnoticed. The single energy value must also be stored unchanged.
    """
    # Imported locally so the check does not rely on the module's import block.
    import warnings

    energy_ev = np.array([7000.0])
    mu = np.array([0.5])

    with warnings.catch_warnings():
        # Without this filter the test would only catch errors, never the
        # warnings its name promises to rule out.
        warnings.simplefilter("error")
        groups = create_groups(energy_ev, mu, mode="energy")

    assert np.array_equal(groups[0].energy, energy_ev)