Minimal GUI for Configuring an Optimization Problem

  1"""An example of how to use the `Configurable` interface."""
  2
  3from __future__ import annotations
  4
  5import sys
  6import typing as t
  7
  8import gymnasium as gym
  9import matplotlib.pyplot as plt
 10import numpy as np
 11import scipy.optimize
 12from matplotlib.axes import Axes
 13from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
 14from matplotlib.backends.backend_qt5agg import NavigationToolbar2QT as NavigationToolbar
 15from matplotlib.figure import Figure
 16from numpy.typing import NDArray
 17from PyQt5 import QtCore, QtGui, QtWidgets
 18from typing_extensions import override
 19
 20from cernml import coi
 21from cernml.coi import cancellation
 22
 23
 24class ConfParabola(
 25    coi.OptEnv[NDArray[np.double], NDArray[np.double], NDArray[np.double]],
 26    coi.Configurable,
 27):
 28    """Example implementation of `OptEnv`.
 29
 30    The goal of this environment is to find the center of a parabola.
 31    """
 32
 33    # pylint: disable = too-many-instance-attributes
 34
 35    # Domain declarations.
 36    metadata = {
 37        # All `mode` arguments to `self.render()` that we support.
 38        "render_modes": ["ansi", "human", "matplotlib_figures"],
 39        # The example is independent of all CERN accelerators.
 40        "cern.machine": coi.Machine.NO_MACHINE,
 41        # No need for communication with CERN accelerators.
 42        "cern.japc": False,
 43        # We implement cancellation for demonstration purposes.
 44        "cern.cancellable": True,
 45    }
 46
 47    # The radius at which an episode is ended. We employ "reward
 48    # dangling", i.e. we start with a very wide radius and restrict it
 49    # with each successful episode, up to a certain limit. This improves
 50    # training speed, as the agent gathers more positive feedback early
 51    # in the training.
 52    objective = -0.05
 53    max_objective = -0.003
 54    action_space: gym.spaces.Box
 55    observation_space: gym.spaces.Box
 56    optimization_space: gym.spaces.Box
 57
 58    def __init__(
 59        self,
 60        cancellation_token: cancellation.Token,
 61        *,
 62        norm: int = 2,
 63        dangling: bool = True,
 64        box_width: float = 2.0,
 65        dim: int = 5,
 66        render_mode: str | None = None,
 67    ):
 68        self.render_mode = render_mode
 69        self.token = cancellation_token
 70        self.norm = norm
 71        self.dangling = dangling
 72        self.action_space = gym.spaces.Box(-1.0, 1.0, shape=(dim,))
 73        self.observation_space = gym.spaces.Box(-box_width, box_width, shape=(dim,))
 74        self.optimization_space = gym.spaces.Box(-box_width, box_width, shape=(dim,))
 75        self.pos: NDArray[np.double] = np.zeros((dim,))
 76        self.figure: Figure | None = None
 77
 78    @override
 79    def get_config(self) -> coi.Config:
 80        (dim,) = self.pos.shape
 81        box_width = self.optimization_space.high[0]
 82        config = coi.Config()
 83        config.add("norm", self.norm, type=int, choices=(1, 2))
 84        config.add("dimensions", dim, type=int, range=(1, 10))
 85        config.add("enable_dangling", self.dangling, type=bool)
 86        config.add("box_width", box_width, type=float, range=(0.0, float("inf")))
 87        return config
 88
 89    @override
 90    def apply_config(self, values: coi.ConfigValues) -> None:
 91        self.norm = values.norm
 92        self.dangling = values.enable_dangling
 93        box_width = values.box_width
 94        dim = values.dimensions
 95        self.observation_space = gym.spaces.Box(-box_width, box_width, shape=(dim,))
 96        self.action_space = gym.spaces.Box(-1.0, 1.0, shape=(dim,))
 97        self.optimization_space = gym.spaces.Box(-box_width, box_width, shape=(dim,))
 98        self.pos = np.zeros((dim,))
 99
100    @override
101    def reset(
102        self, *, seed: int | None = None, options: coi.InfoDict | None = None
103    ) -> tuple[NDArray[np.double], coi.InfoDict]:
104        super().reset(seed=seed)
105        if seed is not None:
106            next_seed = self.np_random.bit_generator.random_raw
107            self.action_space.seed(next_seed())
108            self.observation_space.seed(next_seed())
109            self.optimization_space.seed(next_seed())
110        self.pos = self.optimization_space.sample()
111        # This is not good usage. In practice, you should only accept
112        # and use cancellation tokens if your environment contains a
113        # loop that waits for data. This is only for demonstration
114        # purposes.
115        self.token.raise_if_cancellation_requested()
116        return self.pos.copy(), {}
117
118    @override
119    def step(
120        self, action: NDArray[np.double]
121    ) -> tuple[NDArray[np.double], float, bool, bool, coi.InfoDict]:
122        old_pos = self.pos
123        next_pos = self.pos + action
124        self.pos = np.clip(
125            next_pos,
126            self.observation_space.low,
127            self.observation_space.high,
128        )
129        try:
130            # Because cancellation is cooperative, we know this is the
131            # only place where we can get cancelled.
132            reward = -self._fetch_distance_slow()
133        except cancellation.CancelledError:
134            self.pos = old_pos
135            self.token.complete_cancellation()
136            raise
137        terminated = reward > self.objective
138        truncated = next_pos not in self.observation_space
139        info = {"objective": self.objective}
140        if self.dangling and terminated and self.objective < self.max_objective:
141            self.objective *= 0.95
142        if self.render_mode == "human":
143            self.render()
144        return self.pos.copy(), reward, terminated, truncated, info
145
146    @override
147    def get_initial_params(
148        self, *, seed: int | None = None, options: coi.InfoDict | None = None
149    ) -> NDArray[np.double]:
150        pos, _ = self.reset(seed=seed, options=options)
151        return pos
152
153    @override
154    def compute_single_objective(self, params: NDArray[np.double]) -> float:
155        old_pos = self.pos
156        self.pos = np.clip(
157            params,
158            self.observation_space.low,
159            self.observation_space.high,
160        )
161        try:
162            # Because cancellation is cooperative, we know this is the
163            # only place where we can get cancelled.
164            return self._fetch_distance_slow()
165        except cancellation.CancelledError:
166            self.pos = old_pos
167            self.token.complete_cancellation()
168            raise
169
170    @override
171    def render(self) -> t.Any:
172        if self.render_mode == "human":
173            _, axes = plt.subplots()
174            self._update_axes(axes)
175            plt.show()
176            return None
177        if self.render_mode == "matplotlib_figures":
178            if self.figure is None:
179                self.figure = Figure()
180                axes = self.figure.subplots()
181            else:
182                [axes] = self.figure.axes
183            self._update_axes(axes)
184            return [self.figure]
185        if self.render_mode == "ansi":
186            return str(self.pos)
187        return super().render()
188
189    def _update_axes(self, axes: Axes) -> None:
190        """Plot this environment onto the given axes.
191
192        This method allows us to implement plotting once for both the
193        "human" and the "matplotlib_figures" render mode.
194        """
195        axes.cla()
196        axes.plot(self.pos, "o")
197        axes.plot(self.observation_space.low, "k--")
198        axes.plot(self.observation_space.high, "k--")
199        axes.plot(0.0 * self.observation_space.high, "k--")
200        axes.set_xlabel("Axes")
201        axes.set_ylabel("Position")
202
203    def _fetch_distance_slow(self, pos: np.ndarray | None = None) -> float:
204        """Get distance from the goal in a slow manner.
205
206        This simulates interaction with the machine. We sleep for a
207        while, then return the distance between the current position and
208        the coordinate-space origin.
209
210        Raises:
211            cernml.cancellation.CancelledError: if a cancellation
212                arrives while this method sleeps.
213        """
214        handle = self.token.wait_handle
215        with handle:
216            if handle.wait_for(lambda: self.token.cancellation_requested, timeout=0.3):
217                raise cancellation.CancelledError
218        return float(
219            np.linalg.norm(pos if pos is not None else self.pos, ord=self.norm)
220        )
221
222
# Register the example under a fixed ID so that it can be instantiated
# by name via `coi.make()`.
coi.register(
    "ConfParabola-v0",
    entry_point=ConfParabola,
    max_episode_steps=10,
)
224
225
class OptimizerThread(QtCore.QThread):
    """Background Qt thread that minimizes a problem with COBYLA.

    The signal `step` is emitted after every evaluation of the
    objective function.

    Args:
        env: An optimizable problem. Its optimization space must be a
            `gym.spaces.Box`.
    """

    step = QtCore.pyqtSignal()

    def __init__(self, env: coi.SingleOptimizable) -> None:
        super().__init__()
        self.env = env
        box = env.optimization_space
        assert isinstance(box, gym.spaces.Box), box
        self.optimization_space = box

    def _scaled_inf_norm(self, params: NDArray[np.double]) -> t.SupportsFloat:
        """Return the largest ``|params|`` in units of half the box width.

        Note that only the width of the box enters the formula, not
        its offset from the origin.
        """
        box = self.optimization_space
        extent = box.high - box.low
        return np.linalg.norm(2 * params / extent, ord=np.inf)

    def _evaluate(self, params: NDArray[np.double]) -> t.SupportsFloat:
        """Evaluate the objective and notify listeners via `step`."""
        value = self.env.compute_single_objective(params)
        self.step.emit()
        return value

    def run(self) -> None:
        """Thread main function."""
        try:
            start = self.env.get_initial_params()
            within_box = scipy.optimize.NonlinearConstraint(
                self._scaled_inf_norm, 0.0, 1.0
            )
            result: scipy.optimize.OptimizeResult = scipy.optimize.minimize(
                self._evaluate,
                x0=start,
                method="COBYLA",
                constraints=[within_box],
                tol=0.01,
            )
            # Evaluate once more at the optimum so that the problem
            # (and any plot of it) ends up at the final result.
            if result.success:
                self._evaluate(result.x)
        except cancellation.CancelledError:
            print("Operation cancelled by user")
267
268
class ConfigureDialog(QtWidgets.QDialog):
    """Qt dialog that allows configuring an environment.

    The dialog shows one row per config field. Edits are collected in
    `current_values` and are only validated and applied to the target
    when the user clicks *OK* or *Apply*.

    Args:
        target: The environment to be configured.
        parent: The parent widget to attach to.
    """

    def __init__(
        self, target: coi.Configurable, parent: QtWidgets.QWidget | None = None
    ) -> None:
        super().__init__(parent)
        # Prefer the registered ID; fall back to the class name for
        # objects that carry no `spec`.
        spec = getattr(target, "spec", None)
        name = getattr(spec, "id", type(target).__name__)
        self.setWindowTitle(f"Configure {name} ...")
        self.target = target
        self.config = self.target.get_config()
        # Start out with the unmodified values of every field; the
        # widgets overwrite individual entries with the user's input.
        self.current_values = {
            field.dest: field.value for field in self.config.fields()
        }
        main_layout = QtWidgets.QVBoxLayout()
        self.setLayout(main_layout)
        params = QtWidgets.QWidget()
        main_layout.addWidget(params)
        params_layout = QtWidgets.QFormLayout()
        params.setLayout(params_layout)
        for field in self.config.fields():
            label = QtWidgets.QLabel(field.label)
            widget = self._make_field_widget(field)
            params_layout.addRow(label, widget)
        controls = QtWidgets.QDialogButtonBox(
            QtWidgets.QDialogButtonBox.Ok
            | QtWidgets.QDialogButtonBox.Apply
            | QtWidgets.QDialogButtonBox.Cancel
        )
        controls.button(controls.Ok).clicked.connect(self.on_ok_clicked)
        controls.button(controls.Apply).clicked.connect(self.on_apply_clicked)
        controls.button(controls.Cancel).clicked.connect(self.on_cancel_clicked)
        main_layout.addWidget(controls)

    def _apply_current_values(self) -> None:
        """Validate the collected values and pass them to the target."""
        values = self.config.validate_all(self.current_values)
        print(values)
        self.target.apply_config(values)

    def on_ok_clicked(self) -> None:
        """Apply the configs and close the window."""
        self._apply_current_values()
        self.accept()

    def on_apply_clicked(self) -> None:
        """Apply the configs."""
        self._apply_current_values()

    def on_cancel_clicked(self) -> None:
        """Discard any changes and close the window."""
        self.reject()

    def _make_field_widget(self, field: coi.Config.Field) -> QtWidgets.QWidget:
        """Given a field, pick the best widget to configure it.

        The choice is made in this order: a combo box for fields with
        `choices`, a spin box for fields with a `range`, a check box
        for booleans, a validated line edit for integers and floats, a
        plain line edit for strings, and a read-only label otherwise.

        Raises:
            KeyError: if a field has a range but its value is neither
                an integer nor a float.
        """
        # pylint: disable = too-many-return-statements
        if field.choices is not None:
            combo_box = QtWidgets.QComboBox()
            combo_box.addItems(str(choice) for choice in field.choices)
            combo_box.setCurrentText(str(field.value))
            combo_box.currentTextChanged.connect(
                lambda val: self.set_current_value(field.dest, val)
            )
            return combo_box
        if field.range is not None:
            low, high = field.range
            spin_box: QtWidgets.QSpinBox | QtWidgets.QDoubleSpinBox
            # The range must be set *before* the value. A spin box
            # clamps any value to its current range and the default
            # range is only 0 to 99 (99.99 for the double variant), so
            # the opposite order silently corrupts negative or large
            # initial values.
            if isinstance(field.value, (int, np.integer)):
                spin_box = QtWidgets.QSpinBox()
                spin_box.setRange(low, high)
                spin_box.setValue(int(field.value))
            elif isinstance(field.value, (float, np.floating)):
                spin_box = QtWidgets.QDoubleSpinBox()
                spin_box.setRange(low, high)
                spin_box.setValue(float(field.value))
            else:
                raise KeyError(type(field.value))
            # Connect last so that the initialization above does not
            # register as a user edit.
            spin_box.valueChanged.connect(
                lambda val: self.set_current_value(field.dest, str(val))
            )
            return spin_box
        # This check must come before the integer check because `bool`
        # is a subclass of `int`.
        if isinstance(field.value, (bool, np.bool_)):
            check_box = QtWidgets.QCheckBox()
            check_box.setChecked(bool(field.value))
            # Do not use `str(checked)`! `False` converts to `"False"`,
            # which would convert back to `True` via `bool(string)`.
            check_box.stateChanged.connect(
                lambda checked: self.set_current_value(
                    field.dest, "checked" if checked else ""
                )
            )
            return check_box
        if isinstance(field.value, (int, np.integer)):
            line_edit = QtWidgets.QLineEdit(str(field.value))
            line_edit.setValidator(QtGui.QIntValidator())
            line_edit.editingFinished.connect(
                lambda: self.set_current_value(field.dest, line_edit.text())
            )
            return line_edit
        if isinstance(field.value, (float, np.floating)):
            line_edit = QtWidgets.QLineEdit(str(field.value))
            line_edit.setValidator(QtGui.QDoubleValidator())
            line_edit.editingFinished.connect(
                lambda: self.set_current_value(field.dest, line_edit.text())
            )
            return line_edit
        if isinstance(field.value, str):
            line_edit = QtWidgets.QLineEdit(str(field.value))
            line_edit.editingFinished.connect(
                lambda: self.set_current_value(field.dest, line_edit.text())
            )
            return line_edit
        # Unknown type: show the value but offer no way to edit it.
        return QtWidgets.QLabel(str(field.value))

    def set_current_value(self, name: str, value: str) -> None:
        """Update the saved values.

        This is called by each config widget when it changes its value.
        """
        self.current_values[name] = value
392
393
class MainWindow(QtWidgets.QMainWindow):
    """Top-level window: a plot canvas plus launch/cancel/configure buttons."""

    def __init__(self) -> None:
        super().__init__()

        # The token source stays with the GUI; the environment only
        # receives the token to listen on.
        self.cancellation_token_source = cancellation.TokenSource()
        self.env = t.cast(
            ConfParabola,
            coi.make(
                "ConfParabola-v0",
                cancellation_token=self.cancellation_token_source.token,
                render_mode="matplotlib_figures",
            ),
        )
        self.worker = OptimizerThread(self.env)
        self.worker.step.connect(self.on_opt_step)
        self.worker.finished.connect(self.on_opt_finished)
        self.env.reset()

        # In "matplotlib_figures" mode, the environment hands out the
        # one figure that it keeps redrawing; embed it into a canvas.
        [env_figure] = self.env.render()
        self.canvas = FigureCanvas(env_figure)

        self.launch = QtWidgets.QPushButton("Launch")
        self.launch.clicked.connect(self.on_launch)

        self.cancel = QtWidgets.QPushButton("Cancel")
        self.cancel.clicked.connect(self.on_cancel)
        # Nothing to cancel until an optimization has been launched.
        self.cancel.setEnabled(False)

        self.configure_env = QtWidgets.QPushButton("Configure…")
        self.configure_env.clicked.connect(self.on_configure)

        central = QtWidgets.QWidget()
        self.setCentralWidget(central)
        root_layout = QtWidgets.QVBoxLayout(central)
        button_row = QtWidgets.QHBoxLayout()
        root_layout.addWidget(self.canvas)
        root_layout.addLayout(button_row)
        for button in (self.launch, self.cancel, self.configure_env):
            button_row.addWidget(button)
        self.addToolBar(NavigationToolbar(self.canvas, parent=self))

    @override
    def closeEvent(self, event: QtGui.QCloseEvent) -> None:
        # pylint: disable = invalid-name, missing-function-docstring
        # Lock the GUI, then request cancellation and block until the
        # worker thread has stopped before accepting the close event.
        for button in (self.launch, self.cancel, self.configure_env):
            button.setEnabled(False)
        self.cancellation_token_source.cancel()
        self.worker.wait()
        event.accept()

    def on_configure(self) -> None:
        """Show the configuration dialog for the environment."""
        assert coi.is_configurable(self.env)
        config_dialog = ConfigureDialog(self.env, parent=self)
        config_dialog.open()

    def on_launch(self) -> None:
        """Lock the GUI (except for *Cancel*) and start the worker."""
        self.launch.setEnabled(False)
        self.cancel.setEnabled(True)
        self.configure_env.setEnabled(False)
        self.worker.start()

    def on_cancel(self) -> None:
        """Request cancellation of the running optimization."""
        self.cancellation_token_source.cancel()
        # Only this button is disabled here; the remaining buttons are
        # handled by `on_opt_finished()` once the worker has stopped.
        self.cancel.setEnabled(False)

    def on_opt_step(self) -> None:
        """Redraw the environment's figure and refresh the canvas."""
        self.env.render()
        self.canvas.draw()

    def on_opt_finished(self) -> None:
        """Unlock the GUI after the worker has finished."""
        source = self.cancellation_token_source
        # Try to make the token reusable. The launch button only comes
        # back if no cancellation request is left pending afterwards,
        # i.e. if the reset worked or nothing was ever cancelled.
        if source.can_reset_cancellation:
            source.reset_cancellation()
        if not source.cancellation_requested:
            self.launch.setEnabled(True)
        self.cancel.setEnabled(False)
        self.configure_env.setEnabled(True)
479
480
def main(argv: list[str]) -> int:
    """Run the GUI until it is closed.

    Args:
        argv: The command-line arguments; you should pass `sys.argv`.

    Returns:
        The return value of the Qt event loop.
    """
    application = QtWidgets.QApplication(argv)
    # Keep a reference for the whole lifetime of the event loop.
    main_window = MainWindow()
    main_window.show()
    exit_code = application.exec_()
    return exit_code
487
488
# Script entry point: forward the exit code of the GUI to the shell.
if __name__ == "__main__":
    raise SystemExit(main(sys.argv))