Advanced: Base Model Class

The BaseODEModel class provides the foundation for adding ODE-based models to mGrowthCtrl. It handles numerical integration, backend selection (NumPy or PyTorch), and provides hooks for control inputs. The CRModel is currently the main concrete subclass (see Consumer-Resource Model).

Overview

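BaseODEModel is an abstract base class that:

  • defines the compute_rhs interface that every subclass must implement;

  • integrates the ODE system with SciPy (NumPy backend) or torchdiffeq (PyTorch backend);

  • provides an optional control_rhs hook for exogenous control inputs;

  • offers predict and evaluate helpers for simulating the model and scoring it against time-series data.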

Example

Custom ODE models can be implemented by subclassing BaseODEModel and defining the system dynamics via compute_rhs().

from mgrowthctrl.models.base import BaseODEModel, ModelNames


class MyModel(BaseODEModel):
    def __init__(self, names: ModelNames, mu_max: float = 0.5, K_s: float = 1.0):
        super().__init__(backend="numpy", names=names)
        self.mu_max = mu_max
        self.K_s = K_s

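    def compute_rhs(self, t, X, S):
        """Compute dX/dt and dS/dt for Monod growth on a single substrate."""
        # Monod specific growth rate, driven by the first (and only) substrate
        mu = self.mu_max * S[0] / (self.K_s + S[0])

        # Biomass grows with per-capita rate mu
        dX = X * mu
        # Substrate is consumed at 10 units per unit of biomass formed
        dS = -10.0 * X * mu

        return dX, dS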

    def simulate(self, y0, t):
        """Simulate the model using the NumPy backend."""
        return self.predict(
            y0=y0,
            t_eval=t,
        )
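For reference, the compute_rhs above is the classic Monod model for one biomass X growing on one substrate S. Reading the hard-coded factor 10 as 1/Y, with a biomass yield Y = 0.1, is our interpretation of the example's constant:

```latex
\mu(S) = \mu_{\max}\,\frac{S}{K_s + S}, \qquad
\frac{dX}{dt} = \mu(S)\,X, \qquad
\frac{dS}{dt} = -\frac{1}{Y}\,\mu(S)\,X, \qquad Y = 0.1
```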

We can now instantiate the model and simulate its dynamics.

import numpy as np
from mgrowthctrl.models.base import ModelNames

names = ModelNames(["biomass"], ["substrate"])

model = MyModel(names=names)

y0 = [0.1, 10.0]

t_eval = np.linspace(0, 50, 100)

sim = model.simulate(y0, t_eval)

Finally, we visualize the resulting trajectories.

import matplotlib.pyplot as plt

plt.figure(figsize=(6, 3.2))
plt.plot(sim.t, sim.X[0, :], label="biomass")
plt.plot(sim.t, sim.S[0, :], label="substrate")
plt.xlabel("time (a.u.)")
plt.ylabel("concentration (a.u.)")
plt.legend()
plt.tight_layout()
plt.savefig("base_class_simple_example.png")
plt.show()

Here’s what the simulation result looks like:

Base class example

Backend Selection

The default “numpy” backend stores the model state in NumPy arrays and integrates with SciPy. The “torch” backend stores it in PyTorch tensors and integrates with torchdiffeq, which, for instance, lets you offload computation to a GPU.

You can set the backend when constructing a model, either by passing the string “numpy” or “torch” or by passing an instance of one of the two ArrayBackend subclasses from the mgrowthctrl.backends.array module, documented below.

These classes expose a number of public methods that unify the handling of NumPy- and torch-based implementations, but these methods should be considered internal to the package. If you want to implement your own backend or reuse this logic, we recommend reading the source code directly.

class mgrowthctrl.backends.array.ArrayBackend(*args, **kwargs)

Bases: Protocol

A shared type that unifies the NumPy and torch backends. The xp_name attribute identifies which of the two backends is in use, but ideally, code should rely on the shared methods rather than on xp_name.

class mgrowthctrl.backends.array.NumpyBackend(*args, **kwargs)

Bases: ArrayBackend

The default array backend implementation using numpy ndarrays. Calculations are always performed on the CPU.

class mgrowthctrl.backends.array.TorchBackend(device: str | device | None = None, dtype: dtype | None = None)

Bases: ArrayBackend

Arrays implemented as torch tensors. Data can be stored on the GPU by passing the device argument.
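ArrayBackend is declared as a typing.Protocol, so type checkers match backends structurally: any class that provides the expected members conforms, whether or not it inherits from the protocol (NumpyBackend and TorchBackend happen to inherit explicitly). The sketch below illustrates that mechanism with a hypothetical ToyBackend protocol. Its asarray method is invented for illustration, because the real ArrayBackend methods are not documented on this page, and whether mgrowthctrl accepts a structurally conforming object at runtime depends on its source.

```python
from typing import List, Protocol, runtime_checkable


@runtime_checkable  # enables isinstance() checks against the protocol
class ToyBackend(Protocol):
    """Hypothetical stand-in for ArrayBackend: a name plus one shared method."""

    xp_name: str

    def asarray(self, values: List[float]) -> List[float]:
        ...


class ListBackend:
    """Hypothetical pure-Python backend. It never inherits from ToyBackend,
    yet conforms to it because it provides the same members."""

    xp_name = "list"

    def asarray(self, values: List[float]) -> List[float]:
        return [float(v) for v in values]


def scale(backend: ToyBackend, values: List[float], factor: float) -> List[float]:
    # Generic code talks to the shared method instead of branching on xp_name.
    return [factor * v for v in backend.asarray(values)]


backend = ListBackend()
print(isinstance(backend, ToyBackend))  # True
print(scale(backend, [1, 2, 3], 0.5))   # [0.5, 1.0, 1.5]
```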

API Documentation

class mgrowthctrl.models.base.ModelNames(X: List[str], S: List[str], X_err: List[str] = <factory>, S_err: List[str] = <factory>)

Bases: object

A convenience container for DataFrame column names

X: List[str]

Consumer (strain) column names

S: List[str]

Resource (metabolite) column names

X_err: List[str]

Optional consumer (strain) error column names

S_err: List[str]

Optional resource (metabolite) error column names
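The <factory> defaults in the ModelNames signature are what Python's dataclasses module shows for fields declared with a default_factory. The stand-in below mirrors the documented fields, assuming the factory is list (so X_err and S_err default to empty lists), and shows how the names select columns from a hypothetical table of strain and metabolite measurements. ModelNamesSketch and all column names are illustrative only.

```python
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class ModelNamesSketch:
    """Stand-in mirroring the documented ModelNames fields (empty-list defaults assumed)."""

    X: List[str]
    S: List[str]
    X_err: List[str] = field(default_factory=list)  # a fresh list per instance
    S_err: List[str] = field(default_factory=list)


# Hypothetical data: two strains, one metabolite, error columns for the strains only.
columns: Dict[str, List[float]] = {
    "E_coli": [0.1, 0.4], "B_subtilis": [0.05, 0.2],
    "glucose": [10.0, 8.5],
    "E_coli_sd": [0.01, 0.03], "B_subtilis_sd": [0.01, 0.02],
}
names = ModelNamesSketch(X=["E_coli", "B_subtilis"], S=["glucose"],
                         X_err=["E_coli_sd", "B_subtilis_sd"])

consumers = {c: columns[c] for c in names.X}  # consumer (strain) columns
resources = {c: columns[c] for c in names.S}  # resource (metabolite) columns
print(list(consumers), list(resources), names.S_err)  # ['E_coli', 'B_subtilis'] ['glucose'] []
```

With a pandas DataFrame, the equivalent selections are df[names.X] and df[names.S].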

class mgrowthctrl.models.base.BaseODEModel(backend: Literal['numpy', 'torch'] | ArrayBackend, names: ModelNames)

Bases: ABC

Base class for microbiome ODE models supporting NumPy+SciPy and PyTorch+torchdiffeq.

Subclasses must implement the method compute_rhs:

compute_rhs(self, t, X, S) -> (dX_dt, dS_dt)

Controller support is intentionally NOT implemented here; keep controllers in a separate module. Subclasses may optionally override control_rhs to inject exogenous inputs; by default, no control is applied.

Parameters:
  • backend – The type of arrays the model works with: “numpy”, “torch”, or an ArrayBackend instance

  • names – Consumer (X) and resource (S) column names in the data

abstractmethod compute_rhs(t, X, S)

Compute the right-hand side (RHS) of the ODE, i.e. dX/dt and dS/dt, given X, S and the model's current parameters. Expected to return a tuple (dX_dt, dS_dt) whose shapes match X and S, respectively.

control_rhs(t, X, S)

Optional: override this method to inject control inputs u(t, y) with the same dimensionality as X and S. Expected to return a tuple (u_X, u_S).

The default implementation returns zeros.
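The sketch below illustrates the control_rhs override pattern with a self-contained stand-in for BaseODEModel, so that it runs without mgrowthctrl and with plain lists instead of arrays. SketchODEModel, total_rhs, FedMonod, feed_rate and t_feed are hypothetical; in particular, total_rhs encodes our assumption that the integrator adds the control inputs to the output of compute_rhs. In a real model you would subclass BaseODEModel and override only control_rhs.

```python
from abc import ABC, abstractmethod
from typing import List, Tuple

Vec = List[float]


class SketchODEModel(ABC):
    """Hypothetical stand-in for BaseODEModel, reduced to its two RHS hooks."""

    @abstractmethod
    def compute_rhs(self, t: float, X: Vec, S: Vec) -> Tuple[Vec, Vec]:
        ...

    def control_rhs(self, t: float, X: Vec, S: Vec) -> Tuple[Vec, Vec]:
        # Default: no control, i.e. zeros with the same dimensionality as X and S.
        return [0.0] * len(X), [0.0] * len(S)

    def total_rhs(self, t: float, X: Vec, S: Vec) -> Tuple[Vec, Vec]:
        # ASSUMPTION: the integrator adds the control inputs to the uncontrolled RHS.
        dX, dS = self.compute_rhs(t, X, S)
        uX, uS = self.control_rhs(t, X, S)
        return ([d + u for d, u in zip(dX, uX)],
                [d + u for d, u in zip(dS, uS)])


class FedMonod(SketchODEModel):
    """Monod growth plus a substrate feed switched on at t_feed (hypothetical parameters)."""

    def __init__(self, mu_max: float = 0.5, K_s: float = 1.0,
                 feed_rate: float = 2.0, t_feed: float = 10.0):
        self.mu_max, self.K_s = mu_max, K_s
        self.feed_rate, self.t_feed = feed_rate, t_feed

    def compute_rhs(self, t, X, S):
        mu = self.mu_max * S[0] / (self.K_s + S[0])
        return [X[0] * mu], [-10.0 * X[0] * mu]

    def control_rhs(self, t, X, S):
        # No direct input on biomass; constant substrate inflow once t >= t_feed.
        u_S = self.feed_rate if t >= self.t_feed else 0.0
        return [0.0], [u_S]


model = FedMonod()
print(model.total_rhs(0.0, [1.0], [0.0]))   # ([0.0], [0.0]): starved culture, feed off
print(model.total_rhs(12.0, [1.0], [0.0]))  # ([0.0], [2.0]): feed switched on
```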

solve(
y0: ndarray | Tensor | list | tuple | float | int,
t_span: Tuple[float, float],
t_eval: ndarray | Tensor | list | tuple | float | int,
*,
log_space: bool = False,
return_raw_tensors: bool = False,
method: str | None = None,
solver_options: Dict | None = None,
control_fn: Callable | None = None,
)

Integrate the ODE system, optionally in log-space for X. The return type depends on the backend:

  • NumPy: returns a SciPy OdeResult

  • Torch: returns either a raw (T, state) tensor or a SimpleNamespace with .t and .y as NumPy arrays, depending on return_raw_tensors
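solve and predict accept log_space=True to integrate X in log-space. A common way to do this, and our assumption about what the flag means here, is to integrate z = log X, whose derivative follows from the chain rule: dz/dt = (dX/dt) / X. This keeps X positive and turns exponential growth into a linear trend. The plain-Python sketch below, independent of mgrowthctrl, wraps an RHS in that transform; how the package converts y0 and the outputs is not covered on this page.

```python
import math


def monod_rhs(t: float, X: float, S: float, mu_max: float = 0.5, K_s: float = 1.0):
    """Scalar version of the Monod example: returns (dX/dt, dS/dt)."""
    mu = mu_max * S / (K_s + S)
    return X * mu, -10.0 * X * mu


def log_space_rhs(rhs, t: float, z: float, S: float):
    """Same dynamics with z = log(X) as the state: dz/dt = (dX/dt) / X."""
    X = math.exp(z)  # X recovered from z is always positive
    dX, dS = rhs(t, X, S)
    return dX / X, dS


# With abundant substrate the per-capita growth rate is close to mu_max, so
# dz/dt is nearly constant: exponential growth in X is a straight line in z.
dz, dS = log_space_rhs(monod_rhs, 0.0, math.log(0.1), 1000.0)
print(dz)  # equals mu(S) = 0.5 * 1000 / 1001 up to round-off
```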

predict(
y0: ndarray | Tensor | list | tuple | float | int,
t_eval: ndarray | Tensor | list | tuple | float | int,
*,
log_space: bool = False,
method: str | None = None,
solver_options: Dict | None = None,
) -> SimpleNamespace

Simulate the model at the time points given in t_eval. Returns a SimpleNamespace whose .t, .y, .X and .S attributes are NumPy arrays.
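The plotting example indexes sim.X[0, :], so X is returned with shape (n_X, T). The sketch below shows our reading of how .y relates to .X and .S: state rows stacked as consumers first, then resources, matching the order of y0 = [0.1, 10.0] in the example. split_prediction is a hypothetical helper and the trajectory values are made up.

```python
from types import SimpleNamespace

import numpy as np


def split_prediction(t, y, n_X: int) -> SimpleNamespace:
    """Hypothetical helper: wrap a (state, T) trajectory the way predict() presents it."""
    t = np.asarray(t, dtype=float)
    y = np.asarray(y, dtype=float)
    # ASSUMPTION: the first n_X rows are consumers (X), the remaining rows resources (S).
    return SimpleNamespace(t=t, y=y, X=y[:n_X, :], S=y[n_X:, :])


t = np.linspace(0.0, 1.0, 3)
y = np.array([[0.1, 0.2, 0.4],     # biomass
              [10.0, 9.0, 7.0]])   # substrate
sim = split_prediction(t, y, n_X=1)
print(sim.X.shape, sim.S.shape)    # (1, 3) (1, 3)
```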

evaluate(
y_true: ndarray | Tensor | list | tuple | float | int,
y0: ndarray | Tensor | list | tuple | float | int,
t_eval: ndarray | Tensor | list | tuple | float | int,
*,
metric: Literal['sse', 'mse', 'mae'] | Callable[[ndarray, ndarray], float] = 'sse',
observed: Literal['y', 'X', 'S'] = 'y',
log_space: bool = False,
sample_weight: ndarray | Tensor | list | tuple | float | int | None = None,
) -> float

Evaluate model fit to time-series data.

Parameters:
  • y_true – Ground-truth trajectory. Accepts (state, T) or (T, state).

  • y0 – Initial state used for prediction (state,).

  • t_eval – Times corresponding to y_true samples (T,).

  • metric – Either one of “sse”, “mse” or “mae” (sum of squared errors, mean squared error, mean absolute error), or a callable that takes the y_true and y_pred arrays and returns the error as a float.

  • observed – Which part of the state to compare: the full state (“y”), consumers only (“X”) or resources only (“S”).

  • log_space – Integrate in log-space for X (prediction only).

  • sample_weight – Optional (T,) weights over time points.
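To make the metric options concrete, the sketch below computes them by hand for a single state variable. weighted_error is a hypothetical helper, not part of mgrowthctrl. In particular, the package may normalize weighted mse and mae differently; here they are divided by the total weight times the number of state variables, which is one common convention.

```python
from typing import Callable, Optional, Sequence, Union

import numpy as np

Metric = Union[str, Callable[[np.ndarray, np.ndarray], float]]


def weighted_error(y_true, y_pred, metric: Metric = "sse",
                   sample_weight: Optional[Sequence[float]] = None) -> float:
    """Hypothetical stand-in for the metric step of evaluate(); y_pred is (state, T)."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    # Accept y_true as (T, state) too; a square array is taken as (state, T) as-is.
    if y_true.shape != y_pred.shape and y_true.T.shape == y_pred.shape:
        y_true = y_true.T
    if callable(metric):
        return float(metric(y_true, y_pred))
    resid = y_true - y_pred
    n_state, n_time = resid.shape
    w = np.ones(n_time) if sample_weight is None else np.asarray(sample_weight, dtype=float)
    # w has shape (T,) and broadcasts across the state rows.
    if metric == "sse":
        return float(np.sum(w * resid ** 2))
    if metric == "mse":
        return float(np.sum(w * resid ** 2) / (w.sum() * n_state))
    if metric == "mae":
        return float(np.sum(w * np.abs(resid)) / (w.sum() * n_state))
    raise ValueError(f"unknown metric: {metric!r}")


y_pred = [[1.0, 2.0, 3.0]]        # one state variable, T = 3
y_true = [[1.0], [2.5], [2.0]]    # same data given as (T, state)
print(weighted_error(y_true, y_pred, "sse"))                           # 1.25
print(weighted_error(y_true, y_pred, "mae"))                           # 0.5
print(weighted_error(y_true, y_pred, "sse", sample_weight=[1, 1, 0]))  # 0.25
```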