Advanced: Base Controller
If you’d like to build your own custom controller with the tools in this package, use the BaseController class as a starting point. See ClosedLoopNeuralController and OpenLoopNeuralController for example implementations of the base controller’s methods.
API Documentation
- mgrowthctrl.controllers.base.LossFunction
Type alias for loss functions. A “LossFunction” is a function that takes (Trajectory, Controller, Model) and returns (Loss, AuxData).
alias of Callable[[Tensor, BaseController, BaseODEModel], Tuple[Tensor, Any]]
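To illustrate the LossFunction signature, here is a minimal sketch of a criterion that rewards final biomass. It assumes the trajectory tensor has shape (n_times, n_species + n_resources) with the species columns first; that layout and the name final_biomass_loss are hypothetical, not taken from the package.

```python
import torch


def final_biomass_loss(trajectory: torch.Tensor, controller, model):
    """Hypothetical LossFunction: maximise the final abundance of species 0.

    Assumes trajectory has shape (n_times, n_species + n_resources) with the
    species columns first; check this against your model before relying on it.
    """
    final_biomass = trajectory[-1, 0]
    loss = -final_biomass  # minimising the negative maximises the biomass
    aux = {"final_biomass": final_biomass.detach()}
    return loss, aux


traj = torch.tensor([[0.1, 1.0], [0.4, 0.3]])
loss, aux = final_biomass_loss(traj, controller=None, model=None)
```

A function like this would be passed as the criterion argument of a BaseController subclass. The controller and model arguments are unused here, but they are available when the loss needs to depend on them.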
- class mgrowthctrl.controllers.base.BaseController(criterion: Callable[[Tensor, BaseController, BaseODEModel], Tuple[Tensor, Any]] = None)
Bases: Module, ABC
Controller base class.
- criterion: Callable[[Tensor, BaseController, BaseODEModel], Tuple[Tensor, Any]]
The loss function (see LossFunction). Alternatively, override the compute_loss method and leave this at its default of None.
- x_idx: List[int]
Indices of species to control (add/inject). Default is None (control no species).
- s_idx: List[int]
Indices of substrates to control (add/inject).
If None: controls ALL substrates.
If []: controls NO substrates.
If [i, j]: controls specific substrates i and j.
- n_species: int
Number of species (X)
- n_resources: int
Number of resources (S)
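The defaults for x_idx and s_idx are asymmetric (None means no species but all substrates). The following sketch spells that out; resolve_control_indices is a hypothetical helper, not part of the package, and its range checks are an extra assumption rather than documented behaviour.

```python
def resolve_control_indices(x_idx, s_idx, n_species, n_resources):
    """Hypothetical helper spelling out the documented x_idx / s_idx defaults."""
    # x_idx=None means no species are controlled.
    species = [] if x_idx is None else list(x_idx)
    # s_idx=None means ALL substrates are controlled; an empty list means none.
    substrates = list(range(n_resources)) if s_idx is None else list(s_idx)

    # Extra sanity checks (an assumption, not documented package behaviour).
    for i in species:
        if not 0 <= i < n_species:
            raise IndexError(f"species index {i} is out of range for {n_species} species")
    for j in substrates:
        if not 0 <= j < n_resources:
            raise IndexError(f"substrate index {j} is out of range for {n_resources} resources")
    return species, substrates


# Two species, three substrates, all defaults: no species, every substrate.
print(resolve_control_indices(None, None, n_species=2, n_resources=3))
```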
- abstractmethod forward(model: BaseODEModel, t: Tensor, X: Tensor, S: Tensor)
Define the computation performed at every call.
Must be overridden by all subclasses, since it is an abstract method.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- abstractmethod get_input_history(trajectory: Tensor, model: BaseODEModel) → Tensor
Reconstructs the full control schedule u(t). Subclasses must implement this.
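A minimal subclass has to provide the two abstract methods above. The sketch below assumes, based on the description of predict_u, that forward returns the control terms (dXc, dSc) for species and substrates. To keep it runnable without the package, BaseControllerStub is a hypothetical stand-in with the same bases (Module, ABC) and abstract methods; in real code you would subclass mgrowthctrl.controllers.base.BaseController. ConstantFeedController is likewise illustrative.

```python
from abc import ABC, abstractmethod

import torch
from torch import nn


class BaseControllerStub(nn.Module, ABC):
    """Hypothetical stand-in for BaseController (only the abstract methods)."""

    @abstractmethod
    def forward(self, model, t, X, S): ...

    @abstractmethod
    def get_input_history(self, trajectory, model): ...


class ConstantFeedController(BaseControllerStub):
    """Hypothetical controller: feeds every substrate at a learnable constant rate."""

    def __init__(self, n_resources: int):
        super().__init__()
        self.n_resources = n_resources
        self.rate = nn.Parameter(torch.zeros(n_resources))

    def forward(self, model, t, X, S):
        # Assumption: forward returns (dXc, dSc), the control terms added to the
        # species and substrate derivatives. Species are left untouched here.
        dXc = torch.zeros_like(X)
        dSc = self.rate.expand_as(S)
        return dXc, dSc

    def get_input_history(self, trajectory, model):
        # The input is constant, so u(t) is the same row at every time point.
        return self.rate.expand(trajectory.shape[0], self.n_resources)


controller = ConstantFeedController(n_resources=2)
# Call the Module instance (not .forward) so registered hooks run.
dXc, dSc = controller(None, torch.tensor(0.0), torch.ones(3), torch.ones(2))
history = controller.get_input_history(torch.zeros(5, 5), None)
```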
- predict_u(t, X, S, model) → Tensor
Optional method that returns the raw control signal u instead of dXc and dSc. Useful for logging.
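One way to make predict_u useful is to compute u there and have forward only map it onto the state derivatives, so the logged signal is exactly what drives the system. The following is a sketch of that design under stated assumptions: SetpointFeedController, its target and gain parameters, and the proportional feedback law are hypothetical, and forward is assumed to return (dXc, dSc).

```python
import torch
from torch import nn


class SetpointFeedController(nn.Module):
    """Hypothetical feedback controller; not part of mgrowthctrl."""

    def __init__(self, s_idx, target: float = 1.0, gain: float = 2.0):
        super().__init__()
        self.s_idx = list(s_idx)  # substrates this controller feeds
        self.target = target
        self.gain = gain

    def predict_u(self, t, X, S, model):
        # Raw control signal: one non-negative feed rate per controlled substrate,
        # proportional to how far that substrate sits below the target.
        return torch.clamp(self.gain * (self.target - S[..., self.s_idx]), min=0.0)

    def forward(self, model, t, X, S):
        # Scatter u into the substrate derivatives; species are not controlled.
        u = self.predict_u(t, X, S, model)
        dSc = torch.zeros_like(S)
        dSc[..., self.s_idx] = u
        return torch.zeros_like(X), dSc


ctrl = SetpointFeedController(s_idx=[1])
X, S = torch.tensor([0.1]), torch.tensor([0.5, 0.25])
u = ctrl.predict_u(0.0, X, S, None)  # what you would log
dXc, dSc = ctrl(None, 0.0, X, S)      # what enters the dynamics
```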
- compute_loss(trajectory: Tensor, model: BaseODEModel) → Tuple[Tensor, Any]
Delegates the loss calculation to the user-provided function self.criterion.
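The two ways of supplying a loss (passing criterion, or overriding compute_loss) can be contrasted with a small sketch. ControllerSketch is a hypothetical stand-in that reproduces only the documented delegation, calling criterion as (trajectory, controller, model); it is not the package source. TrackingController and its target are illustrative, and treating column 0 as a species is an assumption.

```python
import torch
from torch import nn


class ControllerSketch(nn.Module):
    """Hypothetical stand-in reproducing only the documented compute_loss behaviour."""

    def __init__(self, criterion=None):
        super().__init__()
        self.criterion = criterion

    def compute_loss(self, trajectory, model):
        # Default route: delegate to the user-supplied criterion,
        # called as (trajectory, controller, model).
        return self.criterion(trajectory, self, model)


class TrackingController(ControllerSketch):
    """Alternative route: override compute_loss and leave criterion as None."""

    def __init__(self, target: float):
        super().__init__(criterion=None)
        self.target = target

    def compute_loss(self, trajectory, model):
        # Mean squared deviation of column 0 (assumed to be a species) from a target.
        loss = ((trajectory[:, 0] - self.target) ** 2).mean()
        return loss, {"target": self.target}


traj = torch.tensor([[0.1, 1.0], [0.5, 0.2]])
delegating = ControllerSketch(criterion=lambda tr, ctrl, m: (tr.sum(), ctrl))
loss, aux = delegating.compute_loss(traj, None)
tracking = TrackingController(target=0.5)
loss2, aux2 = tracking.compute_loss(traj, None)
```

Overriding is convenient when the loss needs controller state (here, target) that you would otherwise have to close over in a separate function.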
- simulate(model: BaseODEModel, y0: Any, t_eval: Any, t_span: Tuple[float, float] | None = None, **kwargs)
Simulate the closed-loop system (Controller + Model).
- Parameters:
y0 – Starting state
t_eval – Time values in an array-like input
t_span – Start and end time values to simulate; defaults to the first and last values of t_eval
- Returns:
A SimpleNamespace with .t, .Y, .X and .S attributes, plus a .raw attribute containing the raw solution returned by the model.
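Conceptually, the closed loop adds the controller's terms to the model's own derivatives at every step. The sketch below shows that idea with a fixed-step forward-Euler integrator swapped in for whatever solver the package actually uses. The toy Monod model, the constant-feed controller, the state layout (species first, then substrates) and the omission of .raw are all assumptions; only the attribute names .t, .Y, .X and .S mirror the documented return value.

```python
from types import SimpleNamespace

import torch


def toy_model_rhs(t, X, S):
    # Hypothetical model: one species growing on one substrate (Monod kinetics).
    mu_max, K, yield_coef = 1.0, 0.5, 0.5
    growth = mu_max * S / (K + S) * X
    return growth, -growth / yield_coef


def toy_controller(t, X, S):
    # Hypothetical controller: no species added, substrate fed at a constant rate.
    return torch.zeros_like(X), torch.full_like(S, 0.1)


def simulate_closed_loop(y0, t_eval, n_species=1):
    """Forward-Euler sketch of a closed loop: model derivatives plus control terms."""
    t = torch.as_tensor(t_eval, dtype=torch.float32)
    y = torch.as_tensor(y0, dtype=torch.float32)
    rows = [y]
    for t0, t1 in zip(t[:-1], t[1:]):
        X, S = y[:n_species], y[n_species:]
        dX, dS = toy_model_rhs(t0, X, S)
        dXc, dSc = toy_controller(t0, X, S)
        y = y + (t1 - t0) * torch.cat([dX + dXc, dS + dSc])
        rows.append(y)
    Y = torch.stack(rows)  # shape: (len(t_eval), n_species + n_resources)
    # Same attribute names as the documented return value (.raw omitted here).
    return SimpleNamespace(t=t, Y=Y, X=Y[:, :n_species], S=Y[:, n_species:])


sol = simulate_closed_loop([0.1, 1.0], torch.linspace(0.0, 5.0, 51))
```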
- fit(model: BaseODEModel, y0: Tensor, t_eval: Any, t_span: Tuple[float, float] | None = None, epochs: int = 100, lr: float = 0.01, log_fn: Callable[[BaseODEModel, int, Tensor, Tensor, Tensor], None] | None = None)
Train the controller against the model using a simple Adam optimizer loop.
- Parameters:
y0 – Starting state
t_eval – Time values in an array-like input
t_span – Start and end time values to solve for; defaults to the first and last values of t_eval
epochs – Number of epochs to train
lr – Learning rate passed to torch's Adam optimizer
log_fn – Optional logging function for debugging. Called once per epoch with (model, epoch, loss, U_trajectory, sol_tensor).
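The "simple Adam optimizer loop" can be pictured with the sketch below: simulate, compute a loss, backpropagate, step, then call log_fn with the documented argument order. Everything else is an assumption for illustration: rollout is a hypothetical differentiable Euler simulation standing in for model plus controller, the only trainable parameter is a scalar feed rate, the quadratic feed penalty is made up, and model is passed to log_fn as None.

```python
import torch


def rollout(rate, x0=0.1, s0=1.0, dt=0.1, steps=50):
    # Hypothetical differentiable simulation: Monod growth with substrate fed at `rate`.
    X, S = torch.tensor(x0), torch.tensor(s0)
    rows = []
    for _ in range(steps):
        growth = S / (0.5 + S) * X
        X = X + dt * growth
        S = S + dt * (-2.0 * growth + rate)
        rows.append(torch.stack([X, S]))
    return torch.stack(rows)  # shape: (steps, 2), columns (X, S)


def fit_sketch(epochs=100, lr=0.01, log_fn=None):
    rate = torch.zeros((), requires_grad=True)
    optimizer = torch.optim.Adam([rate], lr=lr)
    for epoch in range(epochs):
        optimizer.zero_grad()
        traj = rollout(rate)
        # Maximise final biomass, with a quadratic penalty on the feed rate.
        loss = -traj[-1, 0] + 0.5 * rate ** 2
        loss.backward()
        optimizer.step()
        if log_fn is not None:
            # Documented order: (model, epoch, loss, U_trajectory, sol_tensor).
            u_traj = rate.detach().expand(traj.shape[0])
            log_fn(None, epoch, loss.detach(), u_traj, traj.detach())
    return rate.detach()


history = []
learned_rate = fit_sketch(epochs=20, log_fn=lambda *args: history.append(args))
```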