Closed-Loop Controller

The ClosedLoopNeuralController provides methods to learn state-dependent control inputs \(u(X, S)\).

Example

We first define a one-species, one-resource consumer-resource model and assign its parameters.

from mgrowthctrl.models import CRModel, CRModelParams
from mgrowthctrl.models.base import ModelNames

params = CRModelParams.from_shapes(n=1, m=1)
names = ModelNames(["biomass"], ["substrate"])
model = CRModel(names=names, params=params, backend="torch")

model.r[0, 0] = 0.5
model.a[0, 0] = 1.0
model.k[0] = 0.01

The controller is trained to keep the biomass close to a target value (10 in this example) along the whole trajectory while penalizing control effort.

import numpy as np
import torch

torch.manual_seed(42)
np.random.seed(42)

def regulation_loss(trajectory, controller, model):
    target = torch.tensor(1e1, dtype=torch.float32)

    biomass = trajectory[:, 0]
    tracking_error = (
        torch.log(biomass + 1e-6) - torch.log(target + 1e-6)
    ).pow(2).mean()

    U_tensor = controller.net(torch.log(trajectory + 1e-6))
    effort = 1e-3 * U_tensor.pow(2).mean()

    return tracking_error + effort, U_tensor

We log-transform the input to controller.net() because bacterial counts and metabolite concentrations can span several orders of magnitude.
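To make the effect of the log transform concrete, here is a minimal sketch in plain Python, independent of mgrowthctrl. The helper name log_transform and the sample values are illustrative assumptions; the 1e-6 offset mirrors the one used in regulation_loss.

```python
import math

EPS = 1e-6  # same small offset as in regulation_loss; keeps log finite at zero


def log_transform(values, eps=EPS):
    """Return log(v + eps) for each value (hypothetical helper for illustration)."""
    return [math.log(v + eps) for v in values]


# Illustrative concentrations spanning six orders of magnitude.
raw = [1e-3, 1.0, 1e3]
logged = log_transform(raw)

raw_span = max(raw) - min(raw)        # dominated by the largest value
log_span = max(logged) - min(logged)  # roughly 6 * ln(10)

print(f"raw span: {raw_span:.3f}, log span: {log_span:.3f}")
```

On the raw scale the smallest value is negligible next to the largest, whereas on the log scale each order of magnitude contributes equally, which tends to give a small network better-conditioned inputs.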

Next, we initialize the closed-loop neural controller. We assume that the controller has access to the full state (i.e., all bacterial counts and metabolite concentrations). Partial observability can be specified via the obs_x_idx and obs_s_idx arguments.

from torch import nn
from mgrowthctrl.controllers import ClosedLoopNeuralController

controller = ClosedLoopNeuralController(
    state_dim=model.n + model.m,
    n_species=model.n,
    n_resources=model.m,
    s_idx=[0],
    criterion=regulation_loss,
    hidden_dims=[5, 5, 5],
    activation=nn.ELU(),
)

We define a logging function, set the initial condition, and train the controller.

def log_fn(model, epoch, loss, U_trajectory, sol_tensor):
    if epoch % 1 == 0:
        final_biomass = sol_tensor[-1, 0].item()
        mean_input = U_trajectory.abs().mean().item()
        print(
            f"Epoch {epoch:03d} | Loss: {loss.item():.4f} "
            f"| Final biomass: {final_biomass:.1e} "
            f"| Mean |u|: {mean_input:.1e}"
        )

y0 = torch.tensor([0.1, 10.0], dtype=torch.float32)
t_eval = np.linspace(0, 100, 100)
t_span = (0, 100)

controller.fit(model, y0, t_eval, t_span, epochs=100, lr=3e-2, log_fn=log_fn)

We visually compare the uncontrolled and controlled biomass trajectories.

import matplotlib.pyplot as plt

t_eval = np.linspace(0, 200, 200)
t_span = (0, 200)
uncontrolled_sol = model.predict(y0.numpy(), t_eval)
controlled_sol = controller.simulate(model, y0, t_eval, t_span)

fig, ax = plt.subplots(figsize=(6, 3.2))
ax.plot(uncontrolled_sol.t, uncontrolled_sol.y[0], label="Uncontrolled")
ax.plot(controlled_sol.t, controlled_sol.y[0], label="Controlled")
ax.hlines(1e1, 0, 200, ls="--", color="red", label="Target")
ax.set_xlabel("Time (h)")
ax.set_ylabel("Biomass")
ax.legend()
plt.tight_layout()
plt.savefig("basic_closed_loop_result.png")
plt.show()

We observe that the controlled trajectory reaches the target biomass:

Closed-loop control predictions

Finally, we can inspect the learned control input over time.

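# Cast to float32 so the trajectory matches the dtype of the network weights.
trajectory = torch.as_tensor(controlled_sol.y.T, dtype=torch.float32)
raw_u = controller.get_input_history(trajectory, model).detach()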

fig, ax = plt.subplots(figsize=(6, 3.2))
ax.plot(t_eval, raw_u[:, 0], label="substrate")
ax.set_xlabel("Time (h)")
ax.set_ylabel("Injection rate (mM/h)")
ax.legend()
plt.tight_layout()
plt.savefig("basic_closed_loop_control_inputs.png")
plt.show()

Here’s what the result looks like:

Closed-loop control inputs

For a detailed walkthrough of closed-loop control of a fitted consumer-resource model, see the tutorial Advanced: Closed-Loop Control of Microbiome Dynamics.

API Documentation

class mgrowthctrl.controllers.closed_loop.ClosedLoopNeuralController(
state_dim: int,
n_species: int,
n_resources: int,
*,
hidden_dims: Sequence[int] | None = None,
activation: Module = ReLU(),
criterion: Callable[[Tensor, BaseController, BaseODEModel], Tuple[Tensor, Any]] | None = None,
injection_only: bool = True,
x_idx: Sequence[int] | None = None,
s_idx: Sequence[int] | None = None,
target_substrate_idx: int | None = None,
obs_x_idx: Sequence[int] | None = None,
obs_s_idx: Sequence[int] | None = None,
)

Bases: BaseController

Closed-Loop Neural Controller (State-Dependent Policy).

This controller learns a feedback policy u(X, S) from the current biological state. Unlike the open-loop controller, the output depends on the current observed state rather than time.

The controller outputs positive injection rates (via Softplus when enabled) which are added directly to the ODE derivatives:

  • dX/dt += u_x(X, S) (Probiotic injection)

  • dS/dt += u_s(X, S) (Nutrient/Prebiotic injection)
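The additive structure above can be sketched in plain Python. This is an illustration of the idea only, not the library's implementation: the function names softplus and add_control, and the numeric values, are assumptions made for the example.

```python
import math


def softplus(z):
    """Numerically stable softplus, log(1 + exp(z)); the result is always > 0."""
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))


def add_control(derivs, raw_u, injection_only=True):
    """Add control terms to ODE derivatives (illustrative, not the library code)."""
    u = [softplus(z) if injection_only else z for z in raw_u]
    return [d + ui for d, ui in zip(derivs, u)]


# Hypothetical uncontrolled substrate derivative and raw network output.
dS = [-0.4]
raw_u_s = [-2.0]

controlled = add_control(dS, raw_u_s, injection_only=True)
unconstrained = add_control(dS, raw_u_s, injection_only=False)
print(controlled, unconstrained)
```

With the Softplus in place, even a negative raw network output yields a positive injection rate, so the control can only add material; without it, the same output would remove substrate.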

Parameters:
  • state_dim (int) – Total dimension of the full biological state (n_species + n_resources).

  • n_species (int) – Number of species in the system (dimension of X).

  • n_resources (int) – Number of substrates in the system (dimension of S).

  • hidden_dims (Optional[Sequence[int]]) – Architecture of the neural network (neurons per hidden layer). Default: [16, 16].

  • activation (nn.Module) – Activation function for hidden layers.

  • criterion (Optional[LossFunction]) – The objective function used for training. If not provided, you should override the compute_loss method. Signature: (trajectory, controller, model) -> (loss, aux_data).

  • injection_only (bool) – If True, applies Softplus to the network output so controls are nonnegative.

  • x_idx (Optional[Sequence[int]]) – Indices of species to control (add/inject). Default is None (control no species).

  • s_idx (Optional[Sequence[int]]) –

    Indices of substrates to control (add/inject).

    • If None: controls ALL substrates.

    • If []: controls NO substrates.

    • If [i, j]: controls specific substrates i and j.

  • target_substrate_idx (Optional[int]) – Convenience argument. If s_idx is None and this is set, the controller controls only this substrate index.

  • obs_x_idx (Optional[Sequence[int]]) – Indices of species observed by the controller. Default is None (observe all species).

  • obs_s_idx (Optional[Sequence[int]]) – Indices of substrates observed by the controller. Default is None (observe all substrates).
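The index conventions documented above can be summarized in a short sketch. The function resolve_substrate_indices is a hypothetical helper written for illustration; it encodes only the rules stated in the parameter descriptions and is not taken from the library source.

```python
from typing import List, Optional, Sequence


def resolve_substrate_indices(
    n_resources: int,
    s_idx: Optional[Sequence[int]] = None,
    target_substrate_idx: Optional[int] = None,
) -> List[int]:
    """Return the substrate indices that would be controlled (illustrative)."""
    if s_idx is not None:
        # An explicit list is used as given, including the empty list.
        return list(s_idx)
    if target_substrate_idx is not None:
        # Convenience argument: only consulted when s_idx is None.
        return [target_substrate_idx]
    # s_idx is None and no target is set: control all substrates.
    return list(range(n_resources))


print(resolve_substrate_indices(3))                          # all substrates
print(resolve_substrate_indices(3, s_idx=[]))                # no substrates
print(resolve_substrate_indices(3, s_idx=[0, 2]))            # substrates 0 and 2
print(resolve_substrate_indices(3, target_substrate_idx=1))  # substrate 1 only
```

Note the asymmetry with x_idx, whose documented default of None means that no species are controlled.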

obs_dim

Total observed input dimension.
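As a rough illustration of how the observed dimension relates to obs_x_idx and obs_s_idx, the sketch below selects the observed components from a full state vector. It assumes the state is ordered with species first and substrates second, as in the example above; observed_state is a hypothetical helper, not a library function.

```python
def observed_state(state, n_species, obs_x_idx=None, obs_s_idx=None):
    """Select the observed entries of a state vector (illustrative only)."""
    n_resources = len(state) - n_species
    # None means "observe everything" for both species and substrates.
    x_idx = range(n_species) if obs_x_idx is None else obs_x_idx
    s_idx = range(n_resources) if obs_s_idx is None else obs_s_idx
    X, S = state[:n_species], state[n_species:]
    return [X[i] for i in x_idx] + [S[j] for j in s_idx]


# Hypothetical state with two species and three substrates.
state = [0.1, 0.2, 10.0, 5.0, 1.0]

full = observed_state(state, n_species=2)
partial = observed_state(state, n_species=2, obs_x_idx=[1], obs_s_idx=[0, 2])

print(len(full), len(partial))
```

Under these assumptions, the observed dimension is the number of observed species plus the number of observed substrates.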

net

The neural network mapping observed state -> control signal u.

forward(
model: BaseODEModel,
t: Tensor,
X: Tensor,
S: Tensor,
) -> Tuple[Tensor, Tensor]

Calculates the control signal from the current observed state.

Parameters:
  • model (BaseODEModel) – The metadata wrapper for the system (unused, required by API).

  • t (torch.Tensor) – Current simulation time (unused by this controller).

  • X (torch.Tensor) – Current species abundances (biomass).

  • S (torch.Tensor) – Current substrate concentrations.

Returns:

  • dXc (torch.Tensor) – Additive change to biomass derivatives (same shape as X).

  • dSc (torch.Tensor) – Additive change to substrate derivatives (same shape as S).

get_input_history(trajectory: Tensor, model: BaseODEModel) -> Tensor

Reconstructs the full control trajectory u(X_obs, S_obs) for visualization.

The control signal is evaluated along the provided trajectory using only the observed components of the state.

Parameters:

trajectory (torch.Tensor) – Shape (Steps, State_Dim), where State_Dim = n_species + n_resources.

Returns:

U_tensor – Shape (Steps, Output_Dim). The control signal at each time step.

Return type:

torch.Tensor
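Conceptually, reconstructing the input history amounts to evaluating a state-feedback policy on every row of a trajectory. The sketch below shows this in plain Python with a hand-written toy policy; input_history and toy_policy are hypothetical names, and the policy stands in for the trained network.

```python
def toy_policy(state):
    """Toy feedback rule: inject more substrate when substrate is low."""
    biomass, substrate = state
    return [max(0.0, 1.0 - 0.1 * substrate)]


def input_history(trajectory, policy):
    """Evaluate the policy at each time step; returns one row per step."""
    return [policy(row) for row in trajectory]


# Hypothetical trajectory with columns (biomass, substrate).
trajectory = [[0.1, 10.0], [0.5, 6.0], [2.0, 1.0]]

U = input_history(trajectory, toy_policy)
print(len(U), len(U[0]))  # (steps, output dimension)
```

The result has one row per time step and one column per controlled quantity, matching the (Steps, Output_Dim) shape described above.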