Source code for ewoksxas.converters.tests.test_larch

"""Tests for the Larch group converters."""

import warnings

import numpy as np
import pytest

from ewoksxas.converters.larch import create_groups, get_attribute_values


def test_1d_x_and_1d_y():
    """Check that one 1D x with an equally long 1D y gives exactly one group."""
    k_axis = np.array([1.0, 2.0, 3.0])
    chi_signal = np.array([0.1, 0.2, 0.3])

    groups = create_groups(k_axis, chi_signal, mode="k")

    assert len(groups) == 1
    first = groups[0]
    assert np.array_equal(first.k, k_axis)
    assert np.array_equal(first.chi, chi_signal)
def test_1d_x_and_2d_y_shared_axis():
    """Check that every row of a 2D y reuses the same 1D x axis."""
    k_axis = np.array([1.0, 2.0, 3.0])
    chi_rows = np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]])

    groups = create_groups(k_axis, chi_rows, mode="k")

    assert len(groups) == 2
    for group, expected_chi in zip(groups, chi_rows, strict=True):
        # The single 1D axis is expected unchanged on every group.
        assert np.array_equal(group.k, k_axis)
        assert np.array_equal(group.chi, expected_chi)
def test_2d_x_and_2d_y_paired():
    """Check that row i of a 2D x is paired with row i of a same-shape 2D y."""
    k_rows = np.array([[1.0, 2.0, 3.0], [1.5, 2.5, 3.5]])
    chi_rows = np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]])

    groups = create_groups(k_rows, chi_rows, mode="k")

    assert len(groups) == 2
    triples = zip(groups, k_rows, chi_rows, strict=True)
    for group, expected_k, expected_chi in triples:
        assert np.array_equal(group.k, expected_k)
        assert np.array_equal(group.chi, expected_chi)
def test_mode_energy_attribute_names():
    """Check that 'energy' mode stores the data as ``energy`` and ``mu``."""
    energy = np.array([7000.0, 7001.0, 7002.0])
    absorption = np.array([0.1, 0.2, 0.3])

    group = create_groups(energy, absorption, mode="energy")[0]

    for name in ("energy", "mu"):
        assert hasattr(group, name)
    assert np.array_equal(group.energy, energy)
    assert np.array_equal(group.mu, absorption)
def test_mode_r_attribute_names():
    """Check that 'r' mode stores the data as ``r`` and ``chir``."""
    r_axis = np.array([0.0, 0.1, 0.2])
    chir_signal = np.array([1.0, 2.0, 3.0])

    group = create_groups(r_axis, chir_signal, mode="r")[0]

    for name in ("r", "chir"):
        assert hasattr(group, name)
    assert np.array_equal(group.r, r_axis)
    assert np.array_equal(group.chir, chir_signal)
def test_non_array_x_raises():
    """Check that passing x as a plain Python list is rejected with TypeError."""
    x_as_list = [1.0, 2.0, 3.0]
    y_array = np.array([0.1, 0.2, 0.3])

    with pytest.raises(TypeError, match="Expected NumPy array for x"):
        create_groups(x_as_list, y_array, mode="k")  # type: ignore[arg-type]
def test_non_array_y_raises():
    """Check that passing y as a plain Python list is rejected with TypeError."""
    x_array = np.array([1.0, 2.0, 3.0])
    y_as_list = [0.1, 0.2, 0.3]

    with pytest.raises(TypeError, match="Expected NumPy array for y"):
        create_groups(x_array, y_as_list, mode="k")  # type: ignore[arg-type]
def test_invalid_mode_raises():
    """Check that an unsupported mode name is rejected with ValueError."""
    x_values = np.array([1.0, 2.0, 3.0])
    y_values = np.array([0.1, 0.2, 0.3])

    with pytest.raises(ValueError, match="The mode should be one of"):
        create_groups(x_values, y_values, mode="invalid")
def test_shape_mismatch_1d_1d_raises():
    """Check that 1D x and y of different lengths are rejected with ValueError."""
    three_points = np.array([1.0, 2.0, 3.0])
    two_points = np.array([0.1, 0.2])

    with pytest.raises(ValueError, match="incompatible"):
        create_groups(three_points, two_points, mode="k")
def test_shape_mismatch_1d_2d_raises():
    """Check that a 1D x not matching the column count of a 2D y raises ValueError."""
    x_axis = np.array([1.0, 2.0, 3.0])  # three points...
    y_rows = np.array([[0.1, 0.2], [0.3, 0.4]])  # ...but only two columns

    with pytest.raises(ValueError, match="incompatible"):
        create_groups(x_axis, y_rows, mode="k")
def test_shape_mismatch_2d_2d_raises():
    """Check that 2D x and y with different row counts are rejected with ValueError."""
    x_rows = np.array([[1.0, 2.0, 3.0]])  # one row...
    y_rows = np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]])  # ...versus two

    with pytest.raises(ValueError, match="incompatible"):
        create_groups(x_rows, y_rows, mode="k")
def test_kwargs_forwarded_to_groups():
    """Check that additional keyword arguments are set on every created group."""
    k_axis = np.array([1.0, 2.0, 3.0])
    chi_rows = np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]])
    extra = {"label": "sample"}

    groups = create_groups(k_axis, chi_rows, mode="k", **extra)

    assert len(groups) == 2
    for group in groups:
        assert group.label == extra["label"]
def test_energy_kev_converted_to_ev():
    """Check that an energy grid given in keV comes back expressed in eV."""
    n_points = 101
    energy_kev = np.linspace(7.0, 7.1, n_points)
    mu = np.linspace(0.0, 1.0, n_points)

    groups = create_groups(energy_kev, mu, mode="energy")

    # Computed after create_groups, exactly as the original assertion did.
    expected_ev = energy_kev * 1000
    assert np.allclose(groups[0].energy, expected_ev)
def test_energy_ev_not_converted():
    """Check that an energy grid already in eV is passed through unchanged."""
    n_points = 101
    energy_ev = np.linspace(7000.0, 7100.0, n_points)
    mu = np.linspace(0.0, 1.0, n_points)

    first = create_groups(energy_ev, mu, mode="energy")[0]

    assert np.array_equal(first.energy, energy_ev)
def test_get_attribute_values_single_attribute():
    """Check that asking for one attribute yields a mapping with only that key."""
    k_axis = np.array([1.0, 2.0, 3.0])
    chi_rows = np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]])
    groups = create_groups(k_axis, chi_rows, mode="k")

    values = get_attribute_values(groups, ["chi"])

    assert list(values.keys()) == ["chi"]
    # One row per group, in the same order as the input rows.
    assert np.array_equal(values["chi"], chi_rows)
def test_get_attribute_values_multiple_attributes():
    """Check that several attributes can be collected from all groups in one call."""
    k_axis = np.array([1.0, 2.0, 3.0])
    chi_rows = np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]])
    groups = create_groups(k_axis, chi_rows, mode="k")

    values = get_attribute_values(groups, ["k", "chi"])

    # Both groups share the same 1D k axis, so it appears once per group.
    expected_k = np.array([k_axis, k_axis])
    assert np.array_equal(values["k"], expected_k)
    assert np.array_equal(values["chi"], chi_rows)
def test_get_attribute_values_string_attribute():
    """Check that a single attribute name may be passed as a bare string."""
    k_axis = np.array([1.0, 2.0, 3.0])
    chi_rows = np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]])
    groups = create_groups(k_axis, chi_rows, mode="k")

    values = get_attribute_values(groups, "chi")  # type: ignore[arg-type]

    assert list(values.keys()) == ["chi"]
    assert np.array_equal(values["chi"], chi_rows)
def test_get_attribute_values_single_group():
    """Check that a one-group list still yields 2D arrays of shape (1, n)."""
    k_axis = np.array([1.0, 2.0, 3.0])
    chi_signal = np.array([0.1, 0.2, 0.3])
    groups = create_groups(k_axis, chi_signal, mode="k")

    values = get_attribute_values(groups, ["k", "chi"])

    for name in ("k", "chi"):
        assert values[name].shape == (1, 3)
    assert np.array_equal(values["k"][0], k_axis)
    assert np.array_equal(values["chi"][0], chi_signal)
def test_energy_kev_descending_converted_to_ev():
    """Check that keV detection also works for a decreasing energy grid."""
    n_points = 101
    energy_kev = np.linspace(7.1, 7.0, n_points)  # descending
    mu = np.linspace(0.0, 1.0, n_points)

    groups = create_groups(energy_kev, mu, mode="energy")

    # Computed after create_groups, exactly as the original assertion did.
    expected_ev = energy_kev * 1000
    assert np.allclose(groups[0].energy, expected_ev)
def test_energy_length_1_no_warning():
    """Test that an energy array of length 1 doesn't raise a warning/error.

    Any warning emitted while the group is created is escalated to an
    exception inside this test. Without that, the "no warning" part of the
    test would only be enforced if pytest happened to be configured with
    ``-W error``; otherwise a warning would be recorded but the test would
    still pass.
    """
    energy_ev = np.array([7000.0])
    mu = np.array([0.5])

    with warnings.catch_warnings():
        # Turn every warning (e.g. a NumPy RuntimeWarning) into an error.
        warnings.simplefilter("error")
        groups = create_groups(energy_ev, mu, mode="energy")

    assert len(groups) == 1
    assert np.array_equal(groups[0].energy, energy_ev)
    assert np.array_equal(groups[0].mu, mu)