1"""An example of how to use the `Configurable` interface."""
2
3from __future__ import annotations
4
5import sys
6import typing as t
7
8import gymnasium as gym
9import matplotlib.pyplot as plt
10import numpy as np
11import scipy.optimize
12from matplotlib.axes import Axes
13from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
14from matplotlib.backends.backend_qt5agg import NavigationToolbar2QT as NavigationToolbar
15from matplotlib.figure import Figure
16from numpy.typing import NDArray
17from PyQt5 import QtCore, QtGui, QtWidgets
18from typing_extensions import override
19
20from cernml import coi
21from cernml.coi import cancellation
22
23
class ConfParabola(
    coi.OptEnv[NDArray[np.double], NDArray[np.double], NDArray[np.double]],
    coi.Configurable,
):
    """Example implementation of `OptEnv`.

    The goal of this environment is to find the center of a parabola.

    The environment keeps a position vector `pos`. Its distance from the
    origin (using the configured norm) is the loss for numerical
    optimization, and the negative distance is the reward for
    reinforcement learning. The norm, the number of dimensions, the box
    width and the objective "dangling" are configurable at runtime via
    `get_config()` and `apply_config()`.

    Args:
        cancellation_token: Token that is polled while the environment
            simulates a slow measurement.
        norm: Order of the norm used to measure the distance.
        dangling: If True, tighten the termination objective after each
            successful episode.
        box_width: Half-width of the observation and optimization space.
        dim: Number of dimensions of all spaces.
        render_mode: One of the modes listed in ``metadata``, or None.
    """

    # pylint: disable = too-many-instance-attributes

    # Domain declarations.
    metadata = {
        # All `mode` arguments to `self.render()` that we support.
        "render_modes": ["ansi", "human", "matplotlib_figures"],
        # The example is independent of all CERN accelerators.
        "cern.machine": coi.Machine.NO_MACHINE,
        # No need for communication with CERN accelerators.
        "cern.japc": False,
        # We implement cancellation for demonstration purposes.
        "cern.cancellable": True,
    }

    # The radius at which an episode is ended. We employ "reward
    # dangling", i.e. we start with a very wide radius and restrict it
    # with each successful episode, up to a certain limit. This improves
    # training speed, as the agent gathers more positive feedback early
    # in the training.
    objective = -0.05
    max_objective = -0.003
    action_space: gym.spaces.Box
    observation_space: gym.spaces.Box
    optimization_space: gym.spaces.Box

    def __init__(
        self,
        cancellation_token: cancellation.Token,
        *,
        norm: int = 2,
        dangling: bool = True,
        box_width: float = 2.0,
        dim: int = 5,
        render_mode: str | None = None,
    ):
        self.render_mode = render_mode
        self.token = cancellation_token
        self.norm = norm
        self.dangling = dangling
        self.pos: NDArray[np.double]
        self._resize(dim, box_width)
        self.figure: Figure | None = None

    def _resize(self, dim: int, box_width: float) -> None:
        """Rebuild all spaces and reset the position to the origin.

        Shared by `__init__()` and `apply_config()` so that both always
        produce spaces of the same shape and dtype.
        """
        # The spaces are created in double precision to match the
        # `NDArray[np.double]` types declared by this class. `Box`
        # defaults to single precision, and its membership test rejects
        # arrays whose dtype cannot be safely cast to the space's dtype.
        # With the default, a float64 position would never be "in" the
        # observation space and `step()` would always report truncation.
        self.action_space = gym.spaces.Box(-1.0, 1.0, shape=(dim,), dtype=np.double)
        self.observation_space = gym.spaces.Box(
            -box_width, box_width, shape=(dim,), dtype=np.double
        )
        self.optimization_space = gym.spaces.Box(
            -box_width, box_width, shape=(dim,), dtype=np.double
        )
        self.pos = np.zeros((dim,))

    @override
    def get_config(self) -> coi.Config:
        """Describe the configurable parameters and their current values.

        The fields are ``norm`` (1 or 2), ``dimensions`` (1 to 10),
        ``enable_dangling`` (bool) and ``box_width`` (non-negative).
        """
        (dim,) = self.pos.shape
        box_width = self.optimization_space.high[0]
        config = coi.Config()
        config.add("norm", self.norm, type=int, choices=(1, 2))
        config.add("dimensions", dim, type=int, range=(1, 10))
        config.add("enable_dangling", self.dangling, type=bool)
        config.add("box_width", box_width, type=float, range=(0.0, float("inf")))
        return config

    @override
    def apply_config(self, values: coi.ConfigValues) -> None:
        """Apply validated config values.

        This replaces all three spaces and resets the position to the
        origin.
        """
        self.norm = values.norm
        self.dangling = values.enable_dangling
        self._resize(values.dimensions, values.box_width)

    @override
    def reset(
        self, *, seed: int | None = None, options: coi.InfoDict | None = None
    ) -> tuple[NDArray[np.double], coi.InfoDict]:
        """Move to a random position sampled from the optimization space.

        If a seed is given, all three spaces are re-seeded from the
        environment's random-number generator. `options` is not used.

        Raises:
            cernml.cancellation.CancelledError: if a cancellation has
                been requested on the token.
        """
        super().reset(seed=seed)
        if seed is not None:
            next_seed = self.np_random.bit_generator.random_raw
            self.action_space.seed(next_seed())
            self.observation_space.seed(next_seed())
            self.optimization_space.seed(next_seed())
        self.pos = self.optimization_space.sample()
        # This is not good usage. In practice, you should only accept
        # and use cancellation tokens if your environment contains a
        # loop that waits for data. This is only for demonstration
        # purposes.
        self.token.raise_if_cancellation_requested()
        return self.pos.copy(), {}

    @override
    def step(
        self, action: NDArray[np.double]
    ) -> tuple[NDArray[np.double], float, bool, bool, coi.InfoDict]:
        """Move by `action` and measure the new distance from the origin.

        The new position is clipped to the observation space. The reward
        is the negative distance. The episode terminates when the reward
        exceeds the current objective; it is truncated when the
        unclipped position lies outside the observation space.

        Raises:
            cernml.cancellation.CancelledError: if cancelled while
                measuring. The previous position is restored and the
                cancellation is marked as completed before re-raising.
        """
        old_pos = self.pos
        next_pos = self.pos + action
        self.pos = np.clip(
            next_pos,
            self.observation_space.low,
            self.observation_space.high,
        )
        try:
            # Because cancellation is cooperative, we know this is the
            # only place where we can get cancelled.
            reward = -self._fetch_distance_slow()
        except cancellation.CancelledError:
            self.pos = old_pos
            self.token.complete_cancellation()
            raise
        terminated = reward > self.objective
        truncated = next_pos not in self.observation_space
        # Report the objective that was in effect for this step, before
        # it is possibly tightened below.
        info = {"objective": self.objective}
        if self.dangling and terminated and self.objective < self.max_objective:
            self.objective *= 0.95
        if self.render_mode == "human":
            self.render()
        return self.pos.copy(), reward, terminated, truncated, info

    @override
    def get_initial_params(
        self, *, seed: int | None = None, options: coi.InfoDict | None = None
    ) -> NDArray[np.double]:
        """Reset the environment and return the new position."""
        pos, _ = self.reset(seed=seed, options=options)
        return pos

    @override
    def compute_single_objective(self, params: NDArray[np.double]) -> float:
        """Move to `params` (clipped) and return the distance from the origin.

        Raises:
            cernml.cancellation.CancelledError: if cancelled while
                measuring. The previous position is restored and the
                cancellation is marked as completed before re-raising.
        """
        old_pos = self.pos
        self.pos = np.clip(
            params,
            self.observation_space.low,
            self.observation_space.high,
        )
        try:
            # Because cancellation is cooperative, we know this is the
            # only place where we can get cancelled.
            return self._fetch_distance_slow()
        except cancellation.CancelledError:
            self.pos = old_pos
            self.token.complete_cancellation()
            raise

    @override
    def render(self) -> t.Any:
        """Render the current position according to `render_mode`.

        Returns:
            None for ``"human"`` (a new pyplot figure is shown), a
            one-element list with a reused `Figure` for
            ``"matplotlib_figures"``, and the stringified position for
            ``"ansi"``. Any other mode is delegated to the base class.
        """
        if self.render_mode == "human":
            _, axes = plt.subplots()
            self._update_axes(axes)
            plt.show()
            return None
        if self.render_mode == "matplotlib_figures":
            if self.figure is None:
                # Create the figure lazily and keep it so that callers
                # can embed one figure and just redraw it.
                self.figure = Figure()
                axes = self.figure.subplots()
            else:
                [axes] = self.figure.axes
            self._update_axes(axes)
            return [self.figure]
        if self.render_mode == "ansi":
            return str(self.pos)
        return super().render()

    def _update_axes(self, axes: Axes) -> None:
        """Plot this environment onto the given axes.

        This method allows us to implement plotting once for both the
        "human" and the "matplotlib_figures" render mode.
        """
        axes.cla()
        axes.plot(self.pos, "o")
        axes.plot(self.observation_space.low, "k--")
        axes.plot(self.observation_space.high, "k--")
        # Dashed line at zero, i.e. at the goal.
        axes.plot(0.0 * self.observation_space.high, "k--")
        axes.set_xlabel("Axes")
        axes.set_ylabel("Position")

    def _fetch_distance_slow(self, pos: np.ndarray | None = None) -> float:
        """Get distance from the goal in a slow manner.

        This simulates interaction with the machine. We sleep for a
        while, then return the distance between the current position and
        the coordinate-space origin.

        Args:
            pos: If passed, measure the distance of this position
                instead of the current one.

        Raises:
            cernml.cancellation.CancelledError: if a cancellation
                arrives while this method sleeps.
        """
        handle = self.token.wait_handle
        with handle:
            # Wait up to 0.3 s; a true result means that a cancellation
            # was requested before the timeout expired.
            if handle.wait_for(lambda: self.token.cancellation_requested, timeout=0.3):
                raise cancellation.CancelledError
        return float(
            np.linalg.norm(pos if pos is not None else self.pos, ord=self.norm)
        )
221
222
# Register `ConfParabola` under the ID "ConfParabola-v0". `MainWindow`
# below instantiates the environment through `coi.make()` with this ID.
coi.register("ConfParabola-v0", entry_point=ConfParabola, max_episode_steps=10)
224
225
class OptimizerThread(QtCore.QThread):
    """Background Qt thread that minimizes a problem with COBYLA.

    The `step` signal is emitted after every evaluation of the
    objective function.

    Args:
        env: An optimizable problem. Its optimization space must be a
            `gym.spaces.Box`.
    """

    step = QtCore.pyqtSignal()

    def __init__(self, env: coi.SingleOptimizable) -> None:
        super().__init__()
        self.env = env
        box = env.optimization_space
        assert isinstance(box, gym.spaces.Box), box
        self.optimization_space = box

    def run(self) -> None:
        """Thread main function."""

        def constraint(params: NDArray[np.double]) -> t.SupportsFloat:
            # Largest absolute coordinate, scaled by half the box width.
            bounds = self.optimization_space
            extent = bounds.high - bounds.low
            scaled = 2 * params / extent
            return np.linalg.norm(scaled, ord=np.inf)

        def func(params: NDArray[np.double]) -> t.SupportsFloat:
            # Evaluate the problem, then notify listeners of the step.
            value = self.env.compute_single_objective(params)
            self.step.emit()
            return value

        try:
            start = self.env.get_initial_params()
            inside_box = scipy.optimize.NonlinearConstraint(constraint, 0.0, 1.0)
            result: scipy.optimize.OptimizeResult = scipy.optimize.minimize(
                func,
                x0=start,
                method="COBYLA",
                constraints=[inside_box],
                tol=0.01,
            )
            # On success, evaluate once more at the solution found.
            if result.success:
                func(result.x)
        except cancellation.CancelledError:
            print("Operation cancelled by user")
267
268
class ConfigureDialog(QtWidgets.QDialog):
    """Qt dialog that allows configuring an environment.

    The dialog shows one row per config field. Every widget reports its
    changes as a string via `set_current_value()`; the collected values
    are validated and applied when the user clicks "OK" or "Apply".

    Args:
        target: The environment to be configured.
        parent: The parent widget to attach to.
    """

    def __init__(
        self, target: coi.Configurable, parent: QtWidgets.QWidget | None = None
    ) -> None:
        super().__init__(parent)
        # Prefer the registered ID (`target.spec.id`) for the title and
        # fall back to the class name if there is none.
        spec = getattr(target, "spec", None)
        name = getattr(spec, "id", type(target).__name__)
        self.setWindowTitle(f"Configure {name} ...")
        self.target = target
        self.config = self.target.get_config()
        # Start out with the current values, so that fields the user
        # never touches are passed on unchanged.
        self.current_values = {
            field.dest: field.value for field in self.config.fields()
        }
        main_layout = QtWidgets.QVBoxLayout()
        self.setLayout(main_layout)
        params = QtWidgets.QWidget()
        main_layout.addWidget(params)
        params_layout = QtWidgets.QFormLayout()
        params.setLayout(params_layout)
        for field in self.config.fields():
            label = QtWidgets.QLabel(field.label)
            widget = self._make_field_widget(field)
            params_layout.addRow(label, widget)
        controls = QtWidgets.QDialogButtonBox(
            QtWidgets.QDialogButtonBox.Ok
            | QtWidgets.QDialogButtonBox.Apply
            | QtWidgets.QDialogButtonBox.Cancel
        )
        controls.button(controls.Ok).clicked.connect(self.on_ok_clicked)
        controls.button(controls.Apply).clicked.connect(self.on_apply_clicked)
        controls.button(controls.Cancel).clicked.connect(self.on_cancel_clicked)
        main_layout.addWidget(controls)

    def on_ok_clicked(self) -> None:
        """Apply the configs and close the window."""
        self.on_apply_clicked()
        self.accept()

    def on_apply_clicked(self) -> None:
        """Apply the configs."""
        values = self.config.validate_all(self.current_values)
        print(values)
        self.target.apply_config(values)

    def on_cancel_clicked(self) -> None:
        """Discard any changes and close the window."""
        self.reject()

    def _make_field_widget(self, field: coi.Config.Field) -> QtWidgets.QWidget:
        """Given a field, pick the best widget to configure it.

        The checks are ordered by priority: a combo box for fields with
        choices, a spin box for fields with a range, a check box for
        booleans, a validated line edit for integers and floats, a plain
        line edit for strings, and a read-only label for anything else.

        Raises:
            KeyError: if a field has a range but its value is neither an
                integer nor a float.
        """
        # pylint: disable = too-many-return-statements
        if field.choices is not None:
            combo_box = QtWidgets.QComboBox()
            combo_box.addItems(str(choice) for choice in field.choices)
            combo_box.setCurrentText(str(field.value))
            combo_box.currentTextChanged.connect(
                lambda val: self.set_current_value(field.dest, val)
            )
            return combo_box
        if field.range is not None:
            low, high = field.range
            spin_box: QtWidgets.QSpinBox | QtWidgets.QDoubleSpinBox
            # The range must be set *before* the value. A fresh spin box
            # only accepts values from 0 to 99 and silently clamps
            # anything else, so the initial value would be lost if it
            # were set first.
            if isinstance(field.value, (int, np.integer)):
                spin_box = QtWidgets.QSpinBox()
                spin_box.setRange(low, high)
                spin_box.setValue(int(field.value))
            elif isinstance(field.value, (float, np.floating)):
                spin_box = QtWidgets.QDoubleSpinBox()
                spin_box.setRange(low, high)
                spin_box.setValue(float(field.value))
            else:
                raise KeyError(type(field.value))
            # Connect only now so that the initialization above does not
            # overwrite the saved value.
            spin_box.valueChanged.connect(
                lambda val: self.set_current_value(field.dest, str(val))
            )
            return spin_box
        # `bool` must be checked before `int` because it is a subclass.
        if isinstance(field.value, (bool, np.bool_)):
            check_box = QtWidgets.QCheckBox()
            check_box.setChecked(bool(field.value))
            # Do not use `str(checked)`! `False` converts to `"False"`,
            # which would convert back to `True` via `bool(string)`.
            check_box.stateChanged.connect(
                lambda checked: self.set_current_value(
                    field.dest, "checked" if checked else ""
                )
            )
            return check_box
        if isinstance(field.value, (int, np.integer)):
            line_edit = QtWidgets.QLineEdit(str(field.value))
            line_edit.setValidator(QtGui.QIntValidator())
            line_edit.editingFinished.connect(
                lambda: self.set_current_value(field.dest, line_edit.text())
            )
            return line_edit
        if isinstance(field.value, (float, np.floating)):
            line_edit = QtWidgets.QLineEdit(str(field.value))
            line_edit.setValidator(QtGui.QDoubleValidator())
            line_edit.editingFinished.connect(
                lambda: self.set_current_value(field.dest, line_edit.text())
            )
            return line_edit
        if isinstance(field.value, str):
            line_edit = QtWidgets.QLineEdit(str(field.value))
            line_edit.editingFinished.connect(
                lambda: self.set_current_value(field.dest, line_edit.text())
            )
            return line_edit
        # Unknown type: show the value but do not allow editing it.
        return QtWidgets.QLabel(str(field.value))

    def set_current_value(self, name: str, value: str) -> None:
        """Update the saved values.

        This is called by each config widget when it changes its value.
        """
        self.current_values[name] = value
392
393
class MainWindow(QtWidgets.QMainWindow):
    """Main window of the Qt application.

    The window embeds the environment's Matplotlib figure and offers
    buttons to launch an optimization in a background thread, to cancel
    it, and to configure the environment.
    """

    def __init__(self) -> None:
        super().__init__()

        # The source stays with the GUI; the environment only receives
        # the token.
        self.cancellation_token_source = cancellation.TokenSource()
        env = coi.make(
            "ConfParabola-v0",
            cancellation_token=self.cancellation_token_source.token,
            render_mode="matplotlib_figures",
        )
        self.env = t.cast(ConfParabola, env)
        self.worker = OptimizerThread(self.env)
        self.worker.step.connect(self.on_opt_step)
        self.worker.finished.connect(self.on_opt_finished)
        self.env.reset()

        [figure] = self.env.render()
        self.canvas = FigureCanvas(figure)
        self.launch = QtWidgets.QPushButton("Launch")
        self.launch.clicked.connect(self.on_launch)
        self.cancel = QtWidgets.QPushButton("Cancel")
        self.cancel.clicked.connect(self.on_cancel)
        # There is nothing to cancel until an optimization is launched.
        self.cancel.setEnabled(False)
        self.configure_env = QtWidgets.QPushButton("Configure…")
        self.configure_env.clicked.connect(self.on_configure)

        window = QtWidgets.QWidget()
        self.setCentralWidget(window)
        main_layout = QtWidgets.QVBoxLayout(window)
        buttons_layout = QtWidgets.QHBoxLayout()
        main_layout.addWidget(self.canvas)
        main_layout.addLayout(buttons_layout)
        buttons_layout.addWidget(self.launch)
        buttons_layout.addWidget(self.cancel)
        buttons_layout.addWidget(self.configure_env)
        self.addToolBar(NavigationToolbar(self.canvas, parent=self))

    @override
    def closeEvent(self, event: QtGui.QCloseEvent) -> None:
        """Cancel any running optimization and wait for it before closing."""
        # pylint: disable = invalid-name, missing-function-docstring
        self.launch.setEnabled(False)
        self.cancel.setEnabled(False)
        self.configure_env.setEnabled(False)
        self.cancellation_token_source.cancel()
        self.worker.wait()
        event.accept()

    def on_configure(self) -> None:
        """Open the dialog to configure the environment."""
        assert coi.is_configurable(self.env)
        dialog = ConfigureDialog(self.env, parent=self)
        # A new dialog is created on every click and is owned by this
        # window. Without this flag, closed dialogs would pile up as
        # hidden children until the main window itself is destroyed.
        dialog.setAttribute(QtCore.Qt.WA_DeleteOnClose)
        dialog.open()

    def on_launch(self) -> None:
        """Disable the GUI and start optimization."""
        self.launch.setEnabled(False)
        self.cancel.setEnabled(True)
        self.configure_env.setEnabled(False)
        self.worker.start()

    def on_cancel(self) -> None:
        """Send a cancellation request."""
        self.cancellation_token_source.cancel()
        # Disable the button. `on_opt_finished()` will eventually
        # re-enable the other buttons.
        self.cancel.setEnabled(False)

    def on_opt_step(self) -> None:
        """Update the plots."""
        # `render()` redraws onto the figure that the canvas embeds.
        self.env.render()
        self.canvas.draw()

    def on_opt_finished(self) -> None:
        """Re-enable the GUI."""
        # Reset the cancellation, if it is possible. Only re-enable the
        # launch button if we could reset the cancellation (or no
        # cancellation ever occurred.)
        if self.cancellation_token_source.can_reset_cancellation:
            self.cancellation_token_source.reset_cancellation()
        if not self.cancellation_token_source.cancellation_requested:
            self.launch.setEnabled(True)
        self.cancel.setEnabled(False)
        self.configure_env.setEnabled(True)
479
480
def main(argv: list[str]) -> int:
    """Run the Qt application and return its exit code.

    Args:
        argv: The command-line arguments; you should pass `sys.argv`.
    """
    application = QtWidgets.QApplication(argv)
    main_window = MainWindow()
    main_window.show()
    # Blocks in the Qt event loop until the application quits.
    exit_code = application.exec_()
    return exit_code
487
488
if __name__ == "__main__":
    # Run the GUI and hand its exit code to the interpreter.
    exit_code = main(sys.argv)
    sys.exit(exit_code)