1"""An example implementation of the `OptEnv` interface."""
2
3from __future__ import annotations
4
5import argparse
6import sys
7import typing as t
8
9import gymnasium as gym
10import matplotlib.pyplot as plt
11import numpy as np
12import scipy.optimize
13from matplotlib.axes import Axes
14from matplotlib.figure import Figure
15from numpy.typing import NDArray
16from stable_baselines3.common.base_class import BaseAlgorithm
17from stable_baselines3.td3 import TD3
18from typing_extensions import override
19
20from cernml import coi
21
22
class Parabola(coi.OptEnv):
    """Example implementation of `OptEnv`.

    The goal of this environment is to find the center of a 2D parabola.
    """

    # Domain declarations.
    observation_space = gym.spaces.Box(-2.0, 2.0, shape=(2,))
    action_space = gym.spaces.Box(-1.0, 1.0, shape=(2,))
    optimization_space = gym.spaces.Box(-2.0, 2.0, shape=(2,))
    metadata = {
        # All values of `render_mode` that we support.
        "render_modes": ["ansi", "human", "matplotlib_figures"],
        # The example is independent of all CERN accelerators.
        "cern.machine": coi.Machine.NO_MACHINE,
        # No need for communication with CERN accelerators.
        "cern.japc": False,
        # Cancellation is important if you communicate with an
        # accelerator. There might be a bug in the machine and your
        # environment is waiting for data that will never arrive. In
        # such situations, it is good when the user gets a chance to
        # cleanly shut down your environment. Cancellation tokens solve
        # this problem.
        # That being said, we don't communicate with an accelerator, so
        # we don't need this feature here.
        "cern.cancellable": False,
    }
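    # Host applications that load this environment can inspect this
    # metadata to decide, for example, which render modes to offer and
    # whether a cancellation token has to be provided.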

    # The reward threshold at which an episode ends. The reward is the
    # negative squared distance from the parabola's center, so this
    # threshold corresponds to a small radius around it. We employ
    # "reward dangling", i.e. we start with a very wide radius and
    # tighten it after each successful episode, up to a certain limit.
    # This improves training speed, as the agent gathers more positive
    # feedback early in the training.
    objective = -0.05
    max_objective = -0.003
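    # For example, a single successful training episode tightens the
    # threshold from -0.05 to -0.05 * 0.95 = -0.0475; after roughly 55
    # successes it is no longer below `max_objective` and stops
    # changing.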

    def __init__(self, *, render_mode: str | None = None) -> None:
        self.render_mode = render_mode
        self.pos: NDArray[np.double] = np.zeros(2)
        self._train = True
        self.figure: Figure | None = None

    def train(self, train: bool = True) -> None:
        """Turn the environment's training mode on or off.

        If the training mode is on, reward dangling is active and each
        successfully finished episode makes the objective stricter. If
        training mode is off, the objective remains constant.
        """
        self._train = train

    @override
    def reset(
        self, seed: int | None = None, options: coi.InfoDict | None = None
    ) -> tuple[NDArray[np.double], coi.InfoDict]:
        super().reset(seed=seed)
        if seed is not None:
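            # Derive an independent seed for each space from the
            # environment's bit generator so that repeated resets with
            # the same seed reproduce the same sampling behavior.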
            next_seed = self.np_random.bit_generator.random_raw
            self.action_space.seed(next_seed())
            self.observation_space.seed(next_seed())
            self.optimization_space.seed(next_seed())
        # Don't use the full observation space for initial states.
        self.pos = self.action_space.sample()
        return self.pos.copy(), {}

    @override
    def step(
        self, action: NDArray[np.double]
    ) -> tuple[NDArray[np.double], float, bool, bool, coi.InfoDict]:
        next_pos = self.pos + action
        self.pos = np.clip(
            next_pos,
            self.observation_space.low,
            self.observation_space.high,
        )
        reward = -sum(self.pos**2)
        terminated = reward > self.objective
        truncated = next_pos not in self.observation_space
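        # Note the distinction: `terminated` means the objective was
        # reached, while `truncated` means the unclipped step left the
        # observation space. The `TimeLimit` wrapper added during
        # registration truncates on its own after 10 steps.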
        info = {"objective": self.objective}
        if self._train and terminated and self.objective < self.max_objective:
            self.objective *= 0.95
        return self.pos.copy(), reward, terminated, truncated, info

    @override
    def get_initial_params(
        self, *, seed: int | None = None, options: coi.InfoDict | None = None
    ) -> NDArray[np.double]:
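        # Delegate to `reset()` so that numerical optimization and
        # reinforcement learning start from the same distribution of
        # initial positions.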
        pos, _ = self.reset(seed=seed, options=options)
        return pos

    @override
    def compute_single_objective(self, params: NDArray[np.double]) -> float:
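        # Unlike the RL reward, the single objective is a *loss*: the
        # squared distance from the center, to be minimized by a
        # numerical optimizer.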
        self.pos = np.clip(
            params,
            self.observation_space.low,
            self.observation_space.high,
        )
        return sum(self.pos**2)

    @override
    def render(self) -> t.Any:
        if self.render_mode == "human":
            plt.figure()
            plt.scatter(*self.pos)
            plt.show()
            return None
        if self.render_mode == "matplotlib_figures":
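            # Create the figure lazily and reuse it on subsequent calls
            # so that a host application can keep a single embedded
            # figure up to date.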
            if self.figure is None:
                self.figure = Figure()
                axes = t.cast(Axes, self.figure.subplots())
            else:
                [axes] = self.figure.axes
            axes.scatter(*self.pos)
            return [self.figure]
        if self.render_mode == "ansi":
            return str(self.pos)
        return super().render()


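# Registering the environment makes it available via
# `coi.make("Parabola-v0")`. The `max_episode_steps` argument wraps it
# in Gymnasium's `TimeLimit`, which truncates every episode after 10
# steps.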
coi.register("Parabola-v0", entry_point=Parabola, max_episode_steps=10)


def run_episode(agent: BaseAlgorithm, env: coi.OptEnv) -> bool:
    """Run one episode of ``env`` and return the success flag."""
    obs, _ = env.reset()
    done = False
    while not done:
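        # The agent acts until the episode either terminates (objective
        # reached) or is truncated (position left the observation space
        # or the 10-step time limit ran out).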
        action, _ = agent.predict(obs)
        obs, _, terminated, truncated, _ = env.step(action)
        done = terminated or truncated
    # Success means the episode ended because the objective was reached
    # rather than because it was truncated.
    return terminated


def get_parser() -> argparse.ArgumentParser:
    """Return an `ArgumentParser` instance."""
    description, _, epilog = __doc__.partition("\n\n")
    parser = argparse.ArgumentParser(
        description=description,
        epilog=epilog,
    )
    parser.add_argument(
        "mode",
        choices=("rl", "opt"),
        help="whether to run reinforcement learning or numerical optimization",
    )
    return parser


def main_rl(env: Parabola, num_runs: int) -> list[bool]:
    """Handler for `rl` mode."""
    agent = TD3("MlpPolicy", env, learning_rate=2e-3)
    agent.learn(total_timesteps=300)
    env.train(False)
    return [run_episode(agent, env) for _ in range(num_runs)]


def main_opt(env: Parabola, num_runs: int) -> list[bool]:
    """Handler for `opt` mode."""
    bounds = scipy.optimize.Bounds(
        env.optimization_space.low,
        env.optimization_space.high,
    )
    return [
        scipy.optimize.minimize(
            fun=env.compute_single_objective,
            x0=env.get_initial_params(),
            bounds=bounds,
        ).success
        for _ in range(num_runs)
    ]


def main(argv: list[str]) -> None:
    """Main function. Should be passed `sys.argv[1:]`."""
    args = get_parser().parse_args(argv)
    env = t.cast(Parabola, coi.make("Parabola-v0"))
    coi.check(env)
    successes = {"rl": main_rl, "opt": main_opt}[args.mode](env, 100)
    print(f"Success rate: {np.mean(successes):.1%}")


if __name__ == "__main__":
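    # Example invocation (the file name "parabola.py" is only
    # illustrative): `python parabola.py opt` runs 100 rounds of
    # numerical optimization and prints the success rate.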
    main(sys.argv[1:])