
laser.cohorts.statearray

statearray

StateArray: a NumPy ndarray subclass with named compartment access.

StateArray

Bases: ndarray

A numpy array wrapper that provides attribute access to state compartments.

This class allows accessing state compartments by name (e.g., states.S, states.I, states.R) while maintaining full numpy array functionality and backward compatibility with numeric indexing (e.g., states[0], states[1]).
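The following is a minimal sketch of how named attribute access can be layered onto an ndarray subclass. It assumes the standard NumPy subclassing hooks (__new__ and __array_finalize__) plus Python's __getattr__. MiniStateArray is a hypothetical name, and this is not the laser implementation.

```python
import numpy as np


class MiniStateArray(np.ndarray):
    """Hypothetical minimal ndarray subclass with named compartment access."""

    def __new__(cls, input_array, state_names, state_axis=0):
        # View the input as this subclass, then attach the metadata.
        obj = np.asarray(input_array).view(cls)
        obj._state_names = tuple(state_names)
        obj._state_axis = state_axis
        return obj

    def __array_finalize__(self, obj):
        # Runs for explicit construction, view casting and slicing;
        # inherit metadata from the source array when it has any.
        self._state_names = getattr(obj, "_state_names", None)
        self._state_axis = getattr(obj, "_state_axis", None)

    def __getattr__(self, name):
        # Only reached when normal lookup fails, so ndarray attributes
        # such as .shape and .sum behave exactly as usual.
        names = self.__dict__.get("_state_names")
        if names is not None and name in names:
            key = [slice(None)] * self.ndim
            key[self.__dict__["_state_axis"]] = names.index(name)
            # Basic indexing returns a view, so writes go back to the array.
            return self[tuple(key)]
        raise AttributeError(name)


states = MiniStateArray(np.zeros((3, 4)), ["S", "I", "R"], state_axis=0)
states.S[0] = 1000   # named access writes through to row 0
states[1] += 5       # numeric indexing still works
print(states.sum(axis=0))
```

Python calls __getattr__ only after normal lookup fails. As a result, in this sketch a compartment whose name collides with an existing ndarray attribute (for example "T" or "size") would be shadowed by that attribute.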

Example

import numpy as np
from laser.cohorts.statearray import StateArray

states = StateArray(source_array=np.zeros((3, 100)), state_names=["S", "I", "R"], state_axis=0)
states.S[0] = 1000                            # Set susceptible population in patch 0
prevalence = states.I / states.sum(axis=0)    # Calculate prevalence
states[0] += births                           # Numeric indexing still works (births: per-patch array, not defined here)
N = states.sum(axis=states.state_axis)        # Sum over state axis to get total population per patch

state_axis property

state_axis

Return the axis index along which state compartments are stored.

Returns:

int: Zero-based axis index for the state dimension.

Raises:

RuntimeError: If _state_axis is None, which occurs when the instance was created via view casting without metadata.
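The RuntimeError case arises because view casting (arr.view(cls)) never calls __new__. Only __array_finalize__ runs, and a plain ndarray offers it no metadata to copy. The sketch below reproduces that behaviour with a hypothetical AxisArray class. It mirrors the documented error but is not laser's code.

```python
import numpy as np


class AxisArray(np.ndarray):
    """Hypothetical subclass carrying only a state-axis attribute."""

    def __new__(cls, input_array, state_axis):
        obj = np.asarray(input_array).view(cls)
        obj._state_axis = state_axis
        return obj

    def __array_finalize__(self, obj):
        # View casting a plain ndarray passes an obj without _state_axis,
        # so the attribute falls back to None.
        self._state_axis = getattr(obj, "_state_axis", None)

    @property
    def state_axis(self):
        if self._state_axis is None:
            raise RuntimeError("state_axis is unset (created via view casting without metadata)")
        return self._state_axis


explicit = AxisArray(np.zeros((3, 5)), state_axis=0)
print(explicit.state_axis)  # 0

cast = np.zeros((3, 5)).view(AxisArray)  # bypasses __new__ entirely
try:
    cast.state_axis
except RuntimeError as err:
    print("RuntimeError:", err)
```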

state_names property

state_names

Return the tuple of registered state compartment names.

Returns:

tuple[str, ...] | None: Compartment names in axis order, or None if the instance was created via view casting without metadata.
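Because the names are stored in axis order, position k along the state axis holds the k-th name. The plain-NumPy sketch below pairs them up for a hypothetical layout in which the state axis is 1 (patches by compartments). It does not use the StateArray API.

```python
import numpy as np

# Hypothetical layout: 4 patches x 3 compartments, so the state axis is 1.
state_names = ("S", "I", "R")
state_axis = 1
counts = np.array([
    [90, 5, 5],
    [80, 10, 10],
    [70, 20, 10],
    [60, 30, 10],
])

# Position k along state_axis holds the compartment state_names[k].
totals = {
    name: int(np.take(counts, k, axis=state_axis).sum())
    for k, name in enumerate(state_names)
}
print(totals)  # {'S': 300, 'I': 65, 'R': 35}

# Summing over the state axis gives the total population per patch.
print(counts.sum(axis=state_axis))  # [100 100 100 100]
```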

get_state_index

get_state_index(name)

Return the numeric axis index for a named state compartment.

Parameters:

name (str, required): State compartment name to look up.

Returns:

int | None: Zero-based index of name along the state axis, or None if name is not registered or state metadata is absent.
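The lookup behaviour described above can be sketched as a standalone function. get_state_index here is a hypothetical stand-in that takes the names explicitly; it is not the method's actual source.

```python
from typing import Optional, Sequence


def get_state_index(state_names: Optional[Sequence[str]], name: str) -> Optional[int]:
    """Hypothetical stand-in for StateArray.get_state_index."""
    if state_names is None:
        # No metadata (e.g. the array came from view casting).
        return None
    try:
        return list(state_names).index(name)
    except ValueError:
        # Unregistered name: report None instead of raising.
        return None


print(get_state_index(("S", "I", "R"), "I"))  # 1
```

Index 0 is a valid result but is falsy. Callers should therefore test the result with "is None" rather than by truthiness.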

get_state_mask

get_state_mask(states)

Return a boolean mask selecting the specified state compartments.

The returned array has length equal to the number of registered states and is True at each position corresponding to a named state in states. Useful for vectorised operations that apply to a subset of compartments (e.g. mortality restricted to ["S", "I"]).
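The mortality use case mentioned above can be sketched with plain NumPy. The mask is written out by hand to stand in for what get_state_mask(["S", "I"]) would return for compartments ("S", "I", "R"). The populations and mortality_rate are made-up values, and the state axis is assumed to be 0.

```python
import numpy as np

# Hypothetical 3 compartments (S, I, R) x 2 patches, state axis 0.
states = np.array([
    [1000.0, 500.0],   # S
    [100.0, 50.0],     # I
    [200.0, 20.0],     # R
])
mortality_rate = 0.1  # hypothetical per-step death fraction

# Stand-in for get_state_mask(["S", "I"]) with names ("S", "I", "R").
mask = np.array([True, True, False])

# Boolean indexing on the state axis scales only the S and I rows.
states[mask] *= 1.0 - mortality_rate
print(states)
```

If the state axis were not 0, the mask would need to be applied along that axis instead (for example states[:, mask] when the state axis is 1).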

Parameters:

states (str | list[str], required): A single state name or a list of state names to include in the mask.

Returns:

np.ndarray: Boolean array of length n_states (the size of the state axis), True at each index corresponding to a state in states and False elsewhere.

Raises:

ValueError: If states is neither a string nor a list.

ValueError: If any name in states is not a registered state.
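The documented input handling and error cases can be sketched as a standalone function. get_state_mask here is a hypothetical stand-in that receives the registered names explicitly. Its validation follows the Raises entries above, but the real method may word its errors or order its checks differently.

```python
from typing import List, Sequence, Union

import numpy as np


def get_state_mask(state_names: Sequence[str], states: Union[str, List[str]]) -> np.ndarray:
    """Hypothetical stand-in for StateArray.get_state_mask."""
    # Accept a single name by wrapping it in a list.
    if isinstance(states, str):
        states = [states]
    elif not isinstance(states, list):
        raise ValueError("states must be a string or a list of strings")

    # Reject names that are not registered compartments.
    unknown = [s for s in states if s not in state_names]
    if unknown:
        raise ValueError(f"unregistered state(s): {unknown}")

    # One entry per registered state, True where that name was requested.
    requested = set(states)
    return np.array([name in requested for name in state_names], dtype=bool)


names = ("S", "I", "R")
print(get_state_mask(names, "S"))         # [ True False False]
print(get_state_mask(names, ["S", "R"]))  # [ True False  True]
```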

Example

>>> sa = StateArray(["S", "I", "R"], 0, shape=(3, 10))
>>> sa.get_state_mask("S")
array([ True, False, False])
>>> sa.get_state_mask(["S", "R"])
array([ True, False,  True])