Advanced: Base Controller

If you’d like to build your own custom controller with the tools in this package, use the BaseController class as a starting point. See ClosedLoopNeuralController and OpenLoopNeuralController for example implementations of the base controller’s methods.

API Documentation

mgrowthctrl.controllers.base.LossFunction

Type alias for loss functions. A LossFunction takes (trajectory, controller, model) and returns a tuple (loss, aux_data), where aux_data can be any auxiliary value.

alias of Callable[[Tensor, BaseController, BaseODEModel], Tuple[Tensor, Any]]
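Below is a hypothetical criterion written with PyTorch that illustrates the LossFunction shape. Several choices are illustrative and not part of the package: the function name, the effort_weight parameter, and the assumption that trajectory is indexed by time along its first dimension. The only package method it relies on is the documented get_input_history.

```python
from typing import Any, Tuple

import torch
from torch import Tensor


def terminal_state_with_effort_penalty(
    trajectory: Tensor,
    controller: Any,
    model: Any,
    effort_weight: float = 0.1,
) -> Tuple[Tensor, Any]:
    """Hypothetical LossFunction: reward a large final state, penalize control effort."""
    # Assumption: time runs along dim 0, so trajectory[-1] is the final state.
    final_state = trajectory[-1]
    # Reconstruct the control schedule u(t) via the documented abstract method.
    u = controller.get_input_history(trajectory, model)
    reward = final_state.sum()
    effort = torch.mean(u ** 2)
    loss = -reward + effort_weight * effort
    # AuxData can be anything; detached tensors are convenient for logging.
    aux = {"reward": reward.detach(), "effort": effort.detach()}
    return loss, aux
```

The function matches the (trajectory, controller, model) → (loss, aux) signature, so it could be passed as the criterion argument of BaseController. To change the weight, bind it first, for example with functools.partial(terminal_state_with_effort_penalty, effort_weight=0.5).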

class mgrowthctrl.controllers.base.BaseController(
criterion: Callable[[Tensor, BaseController, BaseODEModel], Tuple[Tensor, Any]] = None,
)

Bases: Module, ABC

Controller base class.

criterion: Callable[[Tensor, BaseController, BaseODEModel], Tuple[Tensor, Any]]

Used as the loss function. Alternatively, override the compute_loss method and leave this at its default of None.

x_idx: List[int]

Indices of species to control (add/inject). Default is None (control no species).

s_idx: List[int]

Indices of substrates to control (add/inject).

  • If None: controls ALL substrates.

  • If []: controls NO substrates.

  • If [i, j]: controls specific substrates i and j.
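The pure-Python sketch below summarizes the three s_idx cases and the different meaning of None for x_idx. The helpers resolve_x_idx, resolve_s_idx, and control_mask are hypothetical and not part of the package; they only mirror the documented conventions.

```python
from typing import Iterable, List, Optional


def resolve_x_idx(x_idx: Optional[List[int]]) -> List[int]:
    """Species to control: None means NO species (documented default)."""
    return [] if x_idx is None else list(x_idx)


def resolve_s_idx(s_idx: Optional[List[int]], n_resources: int) -> List[int]:
    """Substrates to control: None means ALL, [] means none, [i, j] means only i and j."""
    return list(range(n_resources)) if s_idx is None else list(s_idx)


def control_mask(indices: Iterable[int], size: int) -> List[bool]:
    """Boolean mask marking which of `size` entries receive control input."""
    chosen = set(indices)
    return [i in chosen for i in range(size)]
```

Note the asymmetry: None selects nothing for x_idx but everything for s_idx.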

n_species: int

Number of species (X)

n_resources: int

Number of resources (S)

abstractmethod forward(
model: BaseODEModel,
t: Tensor,
X: Tensor,
S: Tensor,
) → Tuple[Tensor, Tensor]

Define the computation performed at every call.

Must be overridden by all subclasses.

Note

Although the recipe for the forward pass must be defined within this function, you should call the Module instance rather than calling forward directly. Calling the instance runs any registered hooks, whereas calling forward directly silently ignores them.
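The sketch below illustrates the forward(model, t, X, S) contract with a hypothetical constant-rate controller. To stay self-contained it subclasses torch.nn.Module directly; a real controller would subclass BaseController instead. Based on the predict_u description, it assumes that forward returns the pair (dXc, dSc) with the same shapes as X and S. The class name, the raw_rate parameter, and the softplus parameterization are illustrative choices.

```python
from typing import Any, List, Optional, Tuple

import torch
from torch import Tensor, nn


class ConstantRateController(nn.Module):
    """Hypothetical stand-in showing the forward(model, t, X, S) contract.

    A real controller would subclass mgrowthctrl.controllers.base.BaseController.
    """

    def __init__(self, n_resources: int, s_idx: Optional[List[int]] = None):
        super().__init__()
        self.n_resources = n_resources
        # s_idx=None means "all substrates", mirroring the documented convention
        self.s_idx = list(range(n_resources)) if s_idx is None else list(s_idx)
        # One learnable raw rate per controlled substrate
        self.raw_rate = nn.Parameter(torch.zeros(len(self.s_idx)))

    def forward(self, model: Any, t: Tensor, X: Tensor, S: Tensor) -> Tuple[Tensor, Tensor]:
        dXc = torch.zeros_like(X)  # this sketch injects no species
        dSc = torch.zeros_like(S)
        # softplus keeps the injection rate non-negative; other substrates stay at 0
        dSc[..., self.s_idx] = nn.functional.softplus(self.raw_rate)
        return dXc, dSc
```

Calling ctrl(model, t, X, S) rather than ctrl.forward(...) follows the note above, so any registered hooks run. Gradients flow from dSc back to raw_rate, which is what a training loop like fit needs.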

abstractmethod get_input_history(trajectory: Tensor, model: BaseODEModel) → Tensor

Reconstructs the full control schedule u(t). Subclasses must implement this.

predict_u(t, X, S, model) → Tensor

Optional method that returns the raw control signal u instead of dXc and dSc. Useful for logging.

compute_loss(trajectory: Tensor, model: BaseODEModel) → Tuple[Tensor, Any]

Delegates the loss calculation to the user-provided function self.criterion.
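There are two ways to supply a loss: pass criterion, or override compute_loss. Both can be sketched without PyTorch. SketchController and FinalValueController are hypothetical stand-ins. Raising NotImplementedError when no criterion is set is an assumption about what the real class does.

```python
from typing import Any, Callable, Optional, Tuple


class SketchController:
    """Hypothetical, framework-free sketch of compute_loss delegating to criterion."""

    def __init__(self, criterion: Optional[Callable[..., Tuple[Any, Any]]] = None):
        self.criterion = criterion

    def compute_loss(self, trajectory: Any, model: Any) -> Tuple[Any, Any]:
        if self.criterion is None:
            # Assumed behavior; the real BaseController may handle this differently
            raise NotImplementedError("Pass a criterion or override compute_loss")
        # The controller passes itself as the second argument, matching LossFunction
        return self.criterion(trajectory, self, model)


class FinalValueController(SketchController):
    """The alternative route: override compute_loss and leave criterion as None."""

    def compute_loss(self, trajectory: Any, model: Any) -> Tuple[Any, Any]:
        final = trajectory[-1]
        return -final, {"final": final}
```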

simulate(
model: BaseODEModel,
y0: Any,
t_eval: Any,
t_span: Tuple[float, float] | None = None,
**kwargs,
)

Simulate the closed-loop system (Controller + Model).

Parameters:
  • model – The ODE model to simulate the controller against

  • y0 – Initial state

  • t_eval – Array-like of time points at which to evaluate the solution

  • t_span – (start, end) times of the simulation; defaults to the first and last values of t_eval

Returns:

A SimpleNamespace with attributes .t, .Y, .X, and .S, plus .raw, which holds the raw solution returned by the model.
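The documented default for t_span reduces to a one-line rule. resolve_t_span is a hypothetical helper that mirrors that rule; it is not a function exported by the package.

```python
from typing import Optional, Sequence, Tuple


def resolve_t_span(
    t_eval: Sequence[float], t_span: Optional[Tuple[float, float]] = None
) -> Tuple[float, float]:
    """Return t_span if given, otherwise (first, last) value of t_eval."""
    if t_span is not None:
        return t_span
    # float() also accepts 0-d NumPy/torch scalars from array-like t_eval
    return (float(t_eval[0]), float(t_eval[-1]))
```

For example, t_eval = [0.0, 0.5, 1.0] gives the span (0.0, 1.0), while passing t_span explicitly overrides it.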

fit(
model: BaseODEModel,
y0: Tensor,
t_eval: Any,
t_span: Tuple[float, float] | None = None,
epochs: int = 100,
lr: float = 0.01,
log_fn: Callable[[BaseODEModel, int, Tensor, Tensor, Tensor], None] | None = None,
)

Train the controller against the model using a simple Adam optimizer loop.

Parameters:
  • model – The ODE model to train the controller against

  • y0 – Initial state

  • t_eval – Array-like of time points at which to evaluate the solution

  • t_span – (start, end) times to solve over; defaults to the first and last values of t_eval

  • epochs – Number of training epochs

  • lr – Learning rate passed to torch’s Adam optimizer

  • log_fn – Optional logging callback, useful for debugging. Called on every iteration of the training loop with (model, epoch, loss, U_trajectory, sol_tensor).
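A log_fn only needs to accept the five positional arguments listed above and return None. The factory below is a hypothetical example that records and prints the loss every few epochs. The docs do not say whether epoch numbering starts at 0 or 1, so the every check is only approximate.

```python
from typing import Any, Callable, List, Tuple


def make_loss_logger(
    every: int = 10,
) -> Tuple[Callable[[Any, int, Any, Any, Any], None], List[Tuple[int, float]]]:
    """Build a hypothetical log_fn for fit() plus the list it appends to."""
    history: List[Tuple[int, float]] = []

    def log_fn(model: Any, epoch: int, loss: Any, U_trajectory: Any, sol_tensor: Any) -> None:
        if epoch % every == 0:
            # float() works for Python numbers and single-element torch tensors
            value = float(loss)
            history.append((epoch, value))
            print(f"epoch {epoch:4d}  loss {value:.4f}")

    return log_fn, history
```

Typical use would be log_fn, history = make_loss_logger(every=25), then controller.fit(model, y0, t_eval, epochs=200, log_fn=log_fn). Here controller, model, y0, and t_eval are your own objects.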