Advanced: Base Model Class
The BaseODEModel class provides the foundation for adding ODE-based models to mGrowthCtrl. It handles numerical integration and backend selection (NumPy or PyTorch), and provides hooks for control inputs. The CRModel is currently the main concrete subclass (see Consumer-Resource Model).
Overview
BaseODEModel is an abstract base class that:
- Manages species (\(X\)) and metabolite (\(S\)) state vectors via ModelNames
- Supports both NumPy (scipy.integrate) and PyTorch (torchdiffeq) backends
- Provides a unified predict() interface for simulation
- Allows plugging in controllers via the control_rhs() hook (see Introduction)
Example
Custom ODE models can be implemented by subclassing BaseODEModel and defining the system dynamics via compute_rhs().
from mgrowthctrl.models.base import BaseODEModel, ModelNames


class MyModel(BaseODEModel):
    def __init__(self, names: ModelNames, mu_max: float = 0.5, K_s: float = 1.0):
        super().__init__(backend="numpy", names=names)
        self.mu_max = mu_max  # maximum specific growth rate
        self.K_s = K_s  # half-saturation constant

    def compute_rhs(self, t, X, S):
        """Compute dX/dt and dS/dt."""
        # Monod growth rate on the first (here, the only) substrate
        mu = self.mu_max * S[0] / (self.K_s + S[0])
        dX = X * mu
        # Substrate consumed at 10 units per unit of biomass formed
        # (X and S both have length 1 in this example)
        dS = -10.0 * X * mu
        return dX, dS

    def simulate(self, y0, t):
        """Simulate the model using the NumPy backend."""
        return self.predict(
            y0=y0,
            t_eval=t,
        )
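For reference, the dynamics encoded by compute_rhs above, written out for a single species \(X\) and a single substrate \(S\) (the code's S[0]), are Monod growth with a fixed substrate-to-biomass conversion:

```latex
\begin{aligned}
\mu(S) &= \mu_{\max}\,\frac{S}{K_s + S}, \\
\frac{dX}{dt} &= \mu(S)\,X, \\
\frac{dS}{dt} &= -10\,\mu(S)\,X = -10\,\frac{dX}{dt}.
\end{aligned}
```

The constant 10.0 therefore acts as an inverse yield: 10 units of substrate are consumed per unit of biomass formed.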
We can now instantiate the model and simulate its dynamics.
import numpy as np
from mgrowthctrl.models.base import ModelNames
names = ModelNames(["biomass"], ["substrate"])
model = MyModel(names=names)
y0 = [0.1, 10.0]
t_eval = np.linspace(0, 50, 100)
sim = model.simulate(y0, t_eval)
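To sanity-check what this example should produce, the sketch below integrates the same equations directly with scipy.integrate.solve_ivp, without mgrowthctrl. It is not the package's internal code, and it assumes the state vector is ordered [X..., S...], as the y0 above suggests. Because dS/dt = -10 dX/dt, the quantity X + S/10 is conserved, which gives a cheap correctness check.

```python
import numpy as np
from scipy.integrate import solve_ivp

MU_MAX, K_S = 0.5, 1.0  # same defaults as MyModel


def rhs(t, y):
    """Monod growth on one substrate; y = [X, S] (assumed state layout)."""
    X, S = y
    mu = MU_MAX * S / (K_S + S)
    return [mu * X, -10.0 * mu * X]


t_eval = np.linspace(0, 50, 100)
sol = solve_ivp(rhs, (t_eval[0], t_eval[-1]), [0.1, 10.0],
                t_eval=t_eval, rtol=1e-8, atol=1e-10)

X, S = sol.y  # sol.y has shape (n_states, len(t_eval))
invariant = X + S / 10.0  # should stay at 0.1 + 10.0 / 10 = 1.1
print(f"final biomass: {X[-1]:.4f}, final substrate: {S[-1]:.2e}")
print(f"max drift of X + S/10: {np.max(np.abs(invariant - 1.1)):.1e}")
```

Once the substrate is exhausted, biomass should level off near 1.1, which is a useful reference point when reading the plot of the mGrowthCtrl simulation.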
Finally, we visualize the resulting trajectories.
import matplotlib.pyplot as plt
plt.figure(figsize=(6, 3.2))
plt.plot(sim.t, sim.X[0, :], label="biomass")
plt.plot(sim.t, sim.S[0, :], label="substrate")
plt.xlabel("time (a.u.)")
plt.ylabel("concentration (a.u.)")
plt.legend()
plt.tight_layout()
plt.savefig("base_class_simple_example.png")
plt.show()
Here’s what the simulation result looks like:
Backend Selection
The default “numpy” backend stores data in NumPy arrays and integrates with SciPy (scipy.integrate). The “torch” backend stores data in PyTorch tensors and integrates with torchdiffeq, which lets you, for instance, run computations on a GPU.
You can set the backend when constructing a model by providing the string “numpy” or “torch”, or by passing an instance of one of the two ArrayBackend subclasses from the mgrowthctrl.backends.array module, documented below.
These classes have a number of public methods that are used to unify handling between torch-based and numpy-based implementations, but these can be considered internal to the workings of the package. If you want to implement your own backend or to reuse the logic in some way, we recommend reading the source code directly.
- class mgrowthctrl.backends.array.ArrayBackend(*args, **kwargs)
Bases: Protocol
A shared type that unifies the NumPy and torch backends. The xp_name attribute determines which of the two backends is instantiated, but ideally, users should rely on the shared methods.
- class mgrowthctrl.backends.array.NumpyBackend(*args, **kwargs)
Bases: ArrayBackend
The default array backend implementation, using NumPy ndarrays. Calculations are always performed on the CPU.
- class mgrowthctrl.backends.array.TorchBackend(device: str | device | None = None, dtype: dtype | None = None)
Bases: ArrayBackend
Arrays implemented as torch tensors. Allows storing data on the GPU by providing a device argument.
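ArrayBackend is documented as a typing.Protocol, with an xp_name attribute distinguishing the two implementations, while model code is meant to call only shared methods. The sketch below illustrates that pattern with a heavily reduced, hypothetical protocol: the names MiniBackend, MiniNumpyBackend, asarray, exp and growth_factor are illustrative only and are not the package's API (for that, the docs recommend reading the source).

```python
from typing import Protocol, runtime_checkable

import numpy as np


@runtime_checkable
class MiniBackend(Protocol):
    """Hypothetical stand-in for ArrayBackend: the members every backend provides."""

    xp_name: str

    def asarray(self, x): ...

    def exp(self, x): ...


class MiniNumpyBackend(MiniBackend):
    """Hypothetical NumPy implementation, subclassing the protocol explicitly."""

    xp_name = "numpy"

    def asarray(self, x):
        return np.asarray(x, dtype=float)

    def exp(self, x):
        return np.exp(x)


def growth_factor(backend: MiniBackend, rate, t):
    # Only shared methods are used, so any conforming backend works here.
    return backend.exp(backend.asarray(rate) * t)


be = MiniNumpyBackend()
print(isinstance(be, MiniBackend))  # True
print(growth_factor(be, [0.0, 0.5], 2.0))
```

A torch counterpart would implement the same members with torch calls and set xp_name = "torch"; growth_factor would not need to change.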
API Documentation
- class mgrowthctrl.models.base.ModelNames(X: List[str], S: List[str], X_err: List[str] = <factory>, S_err: List[str] = <factory>)
Bases: object
A convenience container for DataFrame column names.
- X: List[str]
Consumer (strain) column names
- S: List[str]
Resource (metabolite) column names
- X_err: List[str]
Optional consumer (strain) error column names
- S_err: List[str]
Optional resource (metabolite) error column names
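The <factory> defaults in the ModelNames signature are how Sphinx renders dataclass fields created with field(default_factory=...). Assuming that is the case and that the factory is list (so the error columns default to empty), a stand-in with the same fields behaves as sketched below. ModelNamesSketch is hypothetical; in practice you would import ModelNames from mgrowthctrl.models.base.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class ModelNamesSketch:
    """Hypothetical mirror of ModelNames; assumes default_factory=list for error fields."""

    X: List[str]
    S: List[str]
    X_err: List[str] = field(default_factory=list)
    S_err: List[str] = field(default_factory=list)


names = ModelNamesSketch(["biomass"], ["substrate"])
print(names.X_err)  # []  (each instance gets its own fresh list)

# Assumed state ordering: consumer columns first, then resource columns.
state_cols = names.X + names.S
print(state_cols)  # ['biomass', 'substrate']
```

Using a factory rather than a shared default list means that appending error columns to one instance never leaks into another.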
- class mgrowthctrl.models.base.BaseODEModel(backend: Literal['numpy', 'torch'] | ArrayBackend, names: ModelNames)
Bases: ABC
Base class for microbiome ODE models supporting NumPy+SciPy and PyTorch+torchdiffeq.
Subclasses must implement the method compute_rhs:
compute_rhs(self, t, X, S) -> (dX_dt, dS_dt)
Controller support is intentionally not implemented here; controllers belong in a separate module. Subclasses may optionally override control_rhs to inject exogenous inputs, but by default no control is applied.
- Parameters:
backend – The type of arrays the model will work with: “numpy”, “torch”, or an ArrayBackend instance
names – Consumer (X) and substrate (S) column names in the data
- abstractmethod compute_rhs(t, X, S)
Compute the right-hand side (RHS) of the ODE: dX/dt, dS/dt given X, S using current parameters. Expected to return a tuple (dX_dt, dS_dt) with the same dimensionality as the inputs.
- control_rhs(t, X, S)
Optional: override this method to inject control inputs u(t, y) with the same dimensionality as X and S. Expected to return a tuple (u_X, u_S).
The default implementation returns zeroes.
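The docs do not spell out how control_rhs is combined with compute_rhs. A natural reading of "inject exogenous inputs" is that the integrator sees the sum of the two, i.e. dX/dt + u_X and dS/dt + u_S. The self-contained sketch below mirrors that hook pattern with hypothetical classes (MiniODEModel, FedBatch) and a constant substrate feed; it is not mgrowthctrl code, and the summation in total_rhs is an assumption.

```python
import numpy as np


class MiniODEModel:
    """Hypothetical reduction of the BaseODEModel hook pattern."""

    def compute_rhs(self, t, X, S):
        raise NotImplementedError

    def control_rhs(self, t, X, S):
        # Default: no control, zeros shaped like the state.
        return np.zeros_like(X), np.zeros_like(S)

    def total_rhs(self, t, X, S):
        # Assumption: integrate uncontrolled dynamics plus control inputs.
        dX, dS = self.compute_rhs(t, X, S)
        uX, uS = self.control_rhs(t, X, S)
        return dX + uX, dS + uS


class FedBatch(MiniODEModel):
    def __init__(self, mu_max=0.5, K_s=1.0, feed=0.2):
        self.mu_max, self.K_s, self.feed = mu_max, K_s, feed

    def compute_rhs(self, t, X, S):
        mu = self.mu_max * S[0] / (self.K_s + S[0])
        return X * mu, -10.0 * X * mu

    def control_rhs(self, t, X, S):
        # Constant substrate feed; no direct input on biomass.
        return np.zeros_like(X), np.full_like(S, self.feed)


X, S = np.array([0.1]), np.array([0.0])
dX, dS = FedBatch().total_rhs(0.0, X, S)
print(dX, dS)  # with no substrate, only the feed changes S
```

Keeping the feed in control_rhs leaves compute_rhs as the pure, uncontrolled biology, which matches the separation the base class describes.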
- solve(y0: ndarray | Tensor | list | tuple | float | int, t_span: Tuple[float, float], t_eval: ndarray | Tensor | list | tuple | float | int, *, log_space: bool = False, return_raw_tensors: bool = False, method: str | None = None, solver_options: Dict | None = None, control_fn: Callable | None = None)
Integrate the ODE system, with X in either linear or log space. The return value depends on the backend:
NumPy: a SciPy OdeResult
Torch: either a raw (T, state) tensor or a SimpleNamespace with .t and .y as NumPy arrays
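A common way to integrate "in log-space for X", and an assumption about what log_space=True does internally, is to integrate z = log X instead of X. By the chain rule dz/dt = (dX/dt)/X, and X = exp(z) is positive by construction. The sketch below uses a hypothetical to_log_space wrapper with scipy.integrate.solve_ivp and checks it against pure exponential growth, where the log-space right-hand side is simply the constant growth rate.

```python
import numpy as np
from scipy.integrate import solve_ivp


def to_log_space(rhs_X):
    """Wrap dX/dt = f(t, X) into dz/dt for z = log X (hypothetical helper)."""
    def rhs_z(t, z):
        X = np.exp(z)
        return rhs_X(t, X) / X  # chain rule: d(log X)/dt = (dX/dt) / X
    return rhs_z


mu = 0.5


def growth(t, X):
    return mu * X  # exponential growth; in log space this is the constant mu


t_eval = np.linspace(0.0, 10.0, 11)
z0 = np.log([0.1])
sol = solve_ivp(to_log_space(growth), (0.0, 10.0), z0,
                t_eval=t_eval, rtol=1e-9, atol=1e-12)
X = np.exp(sol.y[0])  # back-transform to linear scale
print(X[-1], 0.1 * np.exp(mu * 10.0))
```

Besides keeping X positive, log space turns exponential phases into near-linear trajectories, which step-size control tends to handle well.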
- predict(y0: ndarray | Tensor | list | tuple | float | int, t_eval: ndarray | Tensor | list | tuple | float | int, *, log_space: bool = False, method: str | None = None, solver_options: Dict | None = None)
Simulate model predictions for the timepoints given in t_eval. Returns a SimpleNamespace object with .t, .y, .X and .S attributes as NumPy arrays.
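Assuming .y stacks the full state as (state, T) with consumers first, as the example's y0 = [biomass, substrate] ordering and sim.X[0, :] indexing suggest, .X and .S are simply row blocks of .y. The hypothetical helper to_prediction below builds such a namespace from a raw solution array; it illustrates the layout and is not the package's implementation.

```python
from types import SimpleNamespace

import numpy as np


def to_prediction(t, y, n_X):
    """Hypothetical: split a (state, T) array into consumer and resource blocks."""
    y = np.asarray(y)
    return SimpleNamespace(t=np.asarray(t), y=y, X=y[:n_X, :], S=y[n_X:, :])


t = np.array([0.0, 1.0, 2.0])
y = np.arange(9.0).reshape(3, 3)  # toy (state=3, T=3) array: 1 consumer, 2 resources
pred = to_prediction(t, y, n_X=1)
print(pred.X.shape, pred.S.shape)  # (1, 3) (2, 3)
```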
- evaluate(y_true: ndarray | Tensor | list | tuple | float | int, y0: ndarray | Tensor | list | tuple | float | int, t_eval: ndarray | Tensor | list | tuple | float | int, *, metric: Literal['sse', 'mse', 'mae'] | Callable[[ndarray, ndarray], float] = 'sse', observed: Literal['y', 'X', 'S'] = 'y', log_space: bool = False, sample_weight: ndarray | Tensor | list | tuple | float | int | None = None)
Evaluate model fit to time-series data.
- Parameters:
y_true – Ground-truth trajectory. Accepts (state, T) or (T, state).
y0 – Initial state used for prediction (state,).
t_eval – Times corresponding to y_true samples (T,).
metric – Either one of the named metrics ('sse', 'mse', 'mae') or a callable that takes y_true and y_pred values and returns the error.
observed – Which part of the state to compare ('y', 'X' or 'S').
log_space – Integrate in log-space for X (prediction only).
sample_weight – Optional (T,) weights over time.
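The sketch below shows one plausible way to compute the named metrics with optional per-timepoint weights and the documented (state, T) / (T, state) flexibility. The helpers to_state_by_time and weighted_error are hypothetical, and the weighting convention (weights broadcast over states; MSE and MAE normalized by total weight times number of states) is an assumption, so the package's numbers may differ.

```python
import numpy as np


def to_state_by_time(arr, T):
    """Normalize to (state, T); transposes a (T, state) input (hypothetical helper)."""
    arr = np.asarray(arr, dtype=float)
    if arr.ndim == 1:
        return arr[None, :]
    # If state == T the layout is ambiguous; (state, T) is assumed then.
    if arr.shape[1] != T and arr.shape[0] == T:
        arr = arr.T
    return arr


def weighted_error(y_true, y_pred, metric="sse", sample_weight=None):
    """Hypothetical SSE/MSE/MAE; y_pred is assumed to be (state, T)."""
    T = np.asarray(y_pred).shape[-1]
    y_true = to_state_by_time(y_true, T)
    y_pred = to_state_by_time(y_pred, T)
    if callable(metric):
        return float(metric(y_true, y_pred))
    if sample_weight is None:
        w = np.ones(T)
    else:
        w = np.asarray(sample_weight, dtype=float)
    resid = y_true - y_pred  # (state, T); w broadcasts along the time axis
    n_state = resid.shape[0]
    if metric == "sse":
        return float(np.sum(w * resid**2))
    if metric == "mse":
        return float(np.sum(w * resid**2) / (w.sum() * n_state))
    if metric == "mae":
        return float(np.sum(w * np.abs(resid)) / (w.sum() * n_state))
    raise ValueError(f"unknown metric: {metric!r}")


y_pred = np.array([[1.0, 2.0, 3.0]])      # (state=1, T=3)
y_true = np.array([[1.0], [2.0], [5.0]])  # (T=3, state=1), gets transposed
print(weighted_error(y_true, y_pred, "sse"))  # 4.0
print(weighted_error(y_true, y_pred, "mae"))
```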