laser.cohorts.statearray
statearray
StateArray: a NumPy ndarray subclass with named compartment access.
StateArray
Bases: ndarray
A NumPy ndarray subclass that provides attribute access to state compartments.
This class lets you access state compartments by name (e.g., `states.S`, `states.I`, `states.R`) while retaining full NumPy array functionality and backward compatibility with numeric indexing (e.g., `states[0]`, `states[1]`).
Example
```python
import numpy as np
from laser.cohorts.statearray import StateArray

states = StateArray(source_array=np.zeros((3, 100)), state_names=["S", "I", "R"], state_axis=0)
states.S[0] = 1000                          # Set susceptible population in patch 0
prevalence = states.I / states.sum(axis=0)  # Calculate prevalence
births = np.zeros(100)                      # Placeholder per-patch births (illustrative values)
states[0] += births                         # Numeric indexing still works
N = states.sum(axis=states.state_axis)      # Sum over state axis to get total population per patch
```
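To make the mechanism concrete, here is a minimal sketch of how an ndarray subclass can expose compartments as attributes. It is not the `laser.cohorts` implementation: the class name `NamedStateArray` and its constructor are hypothetical, and the sketch covers only what is needed for name-based reads and writes.

```python
import numpy as np


class NamedStateArray(np.ndarray):
    """Hypothetical ndarray subclass with named compartment access (illustrative only)."""

    def __new__(cls, source_array, state_names, state_axis=0):
        # View the input as this subclass; no data is copied for ndarray input.
        obj = np.asarray(source_array).view(cls)
        if obj.shape[state_axis] != len(state_names):
            raise ValueError("len(state_names) must equal the size of the state axis")
        obj._state_names = tuple(state_names)
        obj._state_axis = state_axis
        return obj

    def __array_finalize__(self, obj):
        # Runs for views, slices and ufunc outputs; copy metadata from the parent if present.
        self._state_names = getattr(obj, "_state_names", None)
        self._state_axis = getattr(obj, "_state_axis", None)

    def __getattr__(self, name):
        # Only called when normal lookup fails, so real ndarray attributes take precedence.
        names = self.__dict__.get("_state_names")
        if names is not None and name in names:
            key = [slice(None)] * self.ndim
            key[self._state_axis] = names.index(name)
            # Basic indexing returns a view, so writes such as arr.S[0] = 1000 reach the parent.
            return np.asarray(self)[tuple(key)]
        raise AttributeError(f"{type(self).__name__} has no attribute {name!r}")


states = NamedStateArray(np.zeros((3, 4)), ["S", "I", "R"], state_axis=0)
states.S[0] = 1000          # writes through to states[0, 0]
total = states.sum(axis=0)  # ordinary NumPy reductions still work
```

Because `__getattr__` runs only after normal attribute lookup fails, a compartment named like an existing ndarray attribute (for example `T`) would be shadowed. The sketch also leaves stale name metadata on results that drop the state axis (such as `total`), which a complete implementation would need to handle.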
state_axis
property
state_axis
Return the axis index along which state compartments are stored.
Returns:
| Type | Description |
|---|---|
| `int` | Zero-based axis index for the state dimension. |
Raises:
| Type | Description |
|---|---|
| `RuntimeError` | If … |
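Knowing the state axis is what lets code stay independent of the array layout. The plain-NumPy sketch below involves no `StateArray`; `compartment` and `totals` are hypothetical helper names. It selects one compartment and computes per-patch totals for both a (states, patches) and a (patches, states) layout of the same data.

```python
import numpy as np


def compartment(arr, state_axis, index):
    """Return compartment `index` as a view, wherever the state axis sits."""
    key = [slice(None)] * arr.ndim
    key[state_axis] = index
    return arr[tuple(key)]  # basic indexing: a view, so in-place updates persist


def totals(arr, state_axis):
    """Total population per patch: sum over the state axis."""
    return arr.sum(axis=state_axis)


by_state = np.array([[10, 20], [1, 2], [0, 5]])  # rows S, I, R; columns are 2 patches
by_patch = by_state.T                            # same data, shape (2, 3)

print(compartment(by_state, 0, 1), compartment(by_patch, 1, 1))  # [1 2] [1 2]
print(totals(by_state, 0), totals(by_patch, 1))                  # [11 27] [11 27]
```

Basic indexing is used rather than `np.take` because `np.take` always returns a copy, so writes to its result would not reach the original array.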
state_names
property
state_names
Return the tuple of registered state compartment names.
Returns:
| Type | Description |
|---|---|
| `tuple[str, ...] \| None` | Compartment names in axis order, or `None`. |
get_state_index
get_state_index(name)
Return the numeric axis index for a named state compartment.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `name` | `str` | State compartment name to look up. | *required* |
Returns:
| Type | Description |
|---|---|
| `int \| None` | Zero-based index of `name` along the state axis, or `None`. |
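A name-to-index lookup of this kind amounts to a search over the names in axis order. The sketch below is an assumption about the semantics, not the library's code: the helper `state_index` is hypothetical, and returning `None` for an unknown name is inferred from the `int | None` return type.

```python
from typing import Optional, Sequence

import numpy as np


def state_index(state_names: Optional[Sequence[str]], name: str) -> Optional[int]:
    """Hypothetical helper: position of `name` in `state_names`, or None if absent."""
    if state_names is None:
        return None  # assumption: no registered names means no index
    try:
        return list(state_names).index(name)
    except ValueError:
        return None  # assumption: unknown names yield None instead of raising


names = ("S", "E", "I", "R")
counts = np.arange(8).reshape(4, 2)  # (states, patches)

idx = state_index(names, "I")
if idx is not None:                  # guard: counts[None] would silently add an axis
    print(counts[idx])               # [4 5]
```

The `None` check matters in NumPy specifically: `arr[None]` is equivalent to `arr[np.newaxis]` and returns a view with an extra leading axis instead of raising an error.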
get_state_mask
get_state_mask(states)
Return a boolean mask selecting the specified state compartments.
The returned array has length equal to the number of registered states and is `True` at each position corresponding to a state named in `states`. This is useful for vectorised operations that apply to a subset of compartments (e.g. mortality restricted to `["S", "I"]`).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `states` | `str \| list[str]` | A single state name or a list of state names to include in the mask. | *required* |
Returns:
| Type | Description |
|---|---|
| `ndarray` | Boolean array whose length equals the number of registered states. |
Raises:
| Type | Description |
|---|---|
| `ValueError` | If … |
| `ValueError` | If any name in `states` is not a registered state. |
Example
```python
>>> sa = StateArray(["S", "I", "R"], 0, shape=(3, 10))
>>> sa.get_state_mask("S")
array([ True, False, False])
>>> sa.get_state_mask(["S", "R"])
array([ True, False,  True])
```
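The mortality use case mentioned in the description can be sketched in plain NumPy. `state_mask` is a hypothetical stand-in that mirrors the documented behaviour (a single name or a list, `ValueError` on unknown names, a boolean array over the registered states); `apply_mortality` and its deterministic survival factor `1 - rate` are illustrative assumptions, not LASER's mortality model.

```python
import numpy as np


def state_mask(state_names, selected):
    """Hypothetical stand-in for get_state_mask: boolean mask over state_names."""
    if isinstance(selected, str):
        selected = [selected]
    # np.isin silently ignores unknown names, so check explicitly.
    unknown = set(selected) - set(state_names)
    if unknown:
        raise ValueError(f"unknown state name(s): {sorted(unknown)}")
    return np.isin(state_names, selected)


def apply_mortality(counts, state_names, state_axis, rate, selected):
    """Deterministically remove a fraction `rate` from the selected compartments, in place."""
    mask = state_mask(state_names, selected)
    # Reshape the 1-D mask's factors so they broadcast along the state axis only.
    shape = [1] * counts.ndim
    shape[state_axis] = len(state_names)
    survival = np.where(mask, 1.0 - rate, 1.0).reshape(shape)
    counts *= survival  # counts must be a float array for this in-place multiply
    return counts


names = ["S", "I", "R"]
counts = np.array([[1000.0, 500.0], [10.0, 0.0], [50.0, 20.0]])  # (states, patches)
apply_mortality(counts, names, state_axis=0, rate=0.1, selected=["S", "I"])
print(counts)  # S and I rows scaled by 0.9; R row unchanged
```

Expressing the mask as a per-state factor reshaped along the state axis keeps the update a single broadcast multiply, so the same code works whether compartments are stored on axis 0 or axis 1.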