Example of Using Both Optimization and RL

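The listing below implements coi.OptEnv, which combines the single-objective optimization interface with gym.Env. The same Parabola environment can therefore be solved either by a numerical optimizer (scipy.optimize.minimize) or by a reinforcement-learning agent (TD3 from Stable Baselines3); a command-line argument selects which mode to run.
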
  1"""An example implementation of the `OptEnv` interface."""
  2
  3from __future__ import annotations
  4
  5import argparse
  6import sys
  7import typing as t
  8
  9import gymnasium as gym
 10import matplotlib.pyplot as plt
 11import numpy as np
 12import scipy.optimize
 13from matplotlib.axes import Axes
 14from matplotlib.figure import Figure
 15from numpy.typing import NDArray
 16from stable_baselines3.common.base_class import BaseAlgorithm
 17from stable_baselines3.td3 import TD3
 18from typing_extensions import override
 19
 20from cernml import coi
 21
 22
 23class Parabola(coi.OptEnv):
 24    """Example implementation of `OptEnv`.
 25
 26    The goal of this environment is to find the center of a 2D parabola.
 27    """
 28
 29    # Domain declarations.
 30    observation_space = gym.spaces.Box(-2.0, 2.0, shape=(2,))
 31    action_space = gym.spaces.Box(-1.0, 1.0, shape=(2,))
 32    optimization_space = gym.spaces.Box(-2.0, 2.0, shape=(2,))
 33    metadata = {
 34        # All `mode` arguments to `self.render()` that we support.
 35        "render_modes": ["ansi", "human", "matplotlib_figures"],
 36        # The example is independent of all CERN accelerators.
 37        "cern.machine": coi.Machine.NO_MACHINE,
 38        # No need for communication with CERN accelerators.
 39        "cern.japc": False,
 40        # Cancellation is important if you communicate with an
 41        # accelerator. There might be a bug in the machine and your
 42        # environment is waiting for data that will never arrive. In
 43        # such situations, it is good when the user gets a chance to
 44        # cleanly shut down your environment. Cancellation tokens solve
 45        # this problem.
 46        # That being said, we don't communicate with an accelerator, so
 47        # we don't need this feature here.
 48        "cern.cancellable": False,
 49    }
 50
    # The reward threshold at which an episode is ended. Because the
    # reward is the negative squared distance from the optimum, this
    # corresponds to a radius of about 0.22. We employ "reward
    # dangling", i.e. we start with a very lenient threshold and tighten
    # it after each successful episode, up to a certain limit. This
    # improves training speed, as the agent gathers more positive
    # feedback early in the training.
    objective = -0.05
    max_objective = -0.003

    def __init__(self, *, render_mode: str | None = None) -> None:
        self.render_mode = render_mode
        self.pos: NDArray[np.double] = np.zeros(2)
        self._train = True
        self.figure: Figure | None = None

    def train(self, train: bool = True) -> None:
        """Turn the environment's training mode on or off.

        If the training mode is on, reward dangling is active and each
        successful end of episode makes the objective stricter. If
        training mode is off, the objective remains constant.
        """
        self._train = train

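    # `reset()` and `step()` below implement the `gym.Env` side of the
    # interface, i.e. the methods that an RL agent interacts with.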
    @override
    def reset(
        self, seed: int | None = None, options: coi.InfoDict | None = None
    ) -> tuple[NDArray[np.double], coi.InfoDict]:
        super().reset(seed=seed)
        if seed is not None:
            # Derive fresh seeds from our own RNG so that the spaces'
            # `sample()` calls are reproducible as well.
            next_seed = self.np_random.bit_generator.random_raw
            self.action_space.seed(next_seed())
            self.observation_space.seed(next_seed())
            self.optimization_space.seed(next_seed())
        # Don't use the full observation space for initial states.
        self.pos = self.action_space.sample()
        return self.pos.copy(), {}

    @override
    def step(
        self, action: NDArray[np.double]
    ) -> tuple[NDArray[np.double], float, bool, bool, coi.InfoDict]:
        next_pos = self.pos + action
        self.pos = np.clip(
            next_pos,
            self.observation_space.low,
            self.observation_space.high,
        )
        reward = -sum(self.pos**2)
        terminated = reward > self.objective
        truncated = next_pos not in self.observation_space
        info = {"objective": self.objective}
        if self._train and terminated and self.objective < self.max_objective:
            self.objective *= 0.95
        return self.pos.copy(), reward, terminated, truncated, info

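    # The two methods below implement the SingleOptimizable side of
    # `OptEnv`: a numerical optimizer calls `get_initial_params()` and
    # `compute_single_objective()` instead of `reset()` and `step()`.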
    @override
    def get_initial_params(
        self, *, seed: int | None = None, options: coi.InfoDict | None = None
    ) -> NDArray[np.double]:
        pos, _ = self.reset(seed=seed, options=options)
        return pos

    @override
    def compute_single_objective(self, params: NDArray[np.double]) -> float:
        self.pos = np.clip(
            params,
            self.observation_space.low,
            self.observation_space.high,
        )
        return sum(self.pos**2)

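    # `render()` serves all three modes declared in
    # `metadata["render_modes"]`: "human" opens an interactive pyplot
    # window, "matplotlib_figures" returns a list of Figure objects, and
    # "ansi" returns a string representation of the current position.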
    @override
    def render(self) -> t.Any:
        if self.render_mode == "human":
            plt.figure()
            plt.scatter(*self.pos)
            plt.show()
            return None
        if self.render_mode == "matplotlib_figures":
            if self.figure is None:
                self.figure = Figure()
                axes = t.cast(Axes, self.figure.subplots())
            else:
                [axes] = self.figure.axes
            axes.scatter(*self.pos)
            return [self.figure]
        if self.render_mode == "ansi":
            return str(self.pos)
        return super().render()


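# Make the environment available under the ID "Parabola-v0" so that
# `coi.make()` below can instantiate it; `max_episode_steps=10` truncates
# every episode after at most ten steps.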
coi.register("Parabola-v0", entry_point=Parabola, max_episode_steps=10)


def run_episode(agent: BaseAlgorithm, env: coi.OptEnv) -> bool:
    """Run one episode of ``env`` and return the success flag."""
    obs, _ = env.reset()
    terminated = truncated = False
    while not (terminated or truncated):
        action, _ = agent.predict(obs)
        obs, _, terminated, truncated, _ = env.step(action)
    # The episode is a success if it terminated (i.e. reached the
    # objective) rather than being truncated by the time limit.
    return terminated


def get_parser() -> argparse.ArgumentParser:
    """Return an `ArgumentParser` instance."""
    description, _, epilog = __doc__.partition("\n\n")
    parser = argparse.ArgumentParser(
        description=description,
        epilog=epilog,
    )
    parser.add_argument(
        "mode",
        choices=("rl", "opt"),
        help="whether to run numerical optimization or reinforcement learning",
    )
    return parser


def main_rl(env: Parabola, num_runs: int) -> list[bool]:
    """Handler for `rl` mode."""
    agent = TD3("MlpPolicy", env, learning_rate=2e-3)
    agent.learn(total_timesteps=300)
    # Evaluate with a fixed objective, i.e. no more reward dangling.
    env.train(False)
    return [run_episode(agent, env) for _ in range(num_runs)]


def main_opt(env: Parabola, num_runs: int) -> list[bool]:
    """Handler for `opt` mode."""
    bounds = scipy.optimize.Bounds(
        env.optimization_space.low,
        env.optimization_space.high,
    )
    return [
        scipy.optimize.minimize(
            fun=env.compute_single_objective,
            x0=env.get_initial_params(),
            bounds=bounds,
        ).success
        for _ in range(num_runs)
    ]


def main(argv: list[str]) -> None:
    """Main function. Should be passed `sys.argv[1:]`."""
    args = get_parser().parse_args(argv)
    env = t.cast(Parabola, coi.make("Parabola-v0"))
    coi.check(env)
    successes = {"rl": main_rl, "opt": main_opt}[args.mode](env, 100)
    print(f"Success rate: {np.mean(successes):.1%}")


if __name__ == "__main__":
    main(sys.argv[1:])
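
To try the example, save the listing as a script and pass either mode on the command line, or call main() directly with a list that stands in for sys.argv[1:]. A minimal usage sketch (these calls are illustrative, not part of the listing):

    main(["opt"])  # numerical optimization with scipy.optimize.minimize
    main(["rl"])   # reinforcement learning with TD3

Both modes run the registered Parabola-v0 environment 100 times and print the resulting success rate.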