Source code for mqt.qecc.cc_decoder.decoder

"""LightsOut MaxSAT-based decoder for the hexagonal color code."""

from __future__ import annotations

import datetime
import json
import locale
import subprocess  # noqa: S404
from dataclasses import dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING, Any

import numpy as np
from z3 import Bool, Not, Optimize, Xor, simplify

from . import code_from_string

if TYPE_CHECKING:  # pragma: no cover
    import numpy.typing as npt
    from z3 import ModelRef

    from . import ColorCode


@dataclass
class LightsOut:
    """Lights out problem representation.

    A light is toggled by every switch listed for it in ``lights_to_switches``;
    ``switches_to_lights`` is the reverse mapping. The problem is encoded for a
    z3 ``Optimize`` instance: one Boolean per switch, a chain of helper
    Booleans per light that encodes the parity of its switches, and one soft
    constraint per switch asking for the switch to stay unpressed.
    """

    # light index -> indices of the switches that toggle this light
    lights_to_switches: dict[int, list[int]]
    # switch index -> indices of the lights toggled by this switch
    switches_to_lights: dict[int, list[int]]
    # one z3 Boolean per switch; created by ``preconstruct_z3_instance``
    switch_vars: list[Bool] | None = None
    # light index -> helper Booleans of that light's parity chain
    helper_vars: dict[int, list[Bool]] = field(default_factory=dict)
    optimizer: Optimize = field(default_factory=Optimize)

    def preconstruct_parity_constraint(self, light: int, indices: list[int]) -> None:
        """Preconstruct the parity constraints for a light.

        Adds all constraints to the optimizer that are independent of the value
        of the light.

        Args:
            light: Index of the light; selects the helper variables to use.
            indices: Indices of the switches that toggle this light.
        """
        helper_vars = self.helper_vars[light]
        assert self.switch_vars is not None

        # Chain the helpers: helper[i - 1] == switch[indices[i]] XOR helper[i].
        for i in range(1, len(indices) - 1):
            constraint = Xor(self.switch_vars[indices[i]], helper_vars[i]) == helper_vars[i - 1]
            self.optimizer.add(simplify(constraint))

        # The chain is terminated by tying the last helper to the last switch.
        constraint = self.switch_vars[indices[-1]] == helper_vars[-1]
        self.optimizer.add(simplify(constraint))

    def complete_parity_constraint(self, light: int, indices: list[int], val: bool) -> None:
        """Complete the parity constraints for a light.

        Adds the single constraint that depends on the value of the light: the
        first switch XOR the first helper must equal ``val``.

        Args:
            light: Index of the light; selects the helper variables to use.
            indices: Indices of the switches that toggle this light.
            val: Current value of the light.
        """
        helper_vars = self.helper_vars[light]
        assert self.switch_vars is not None

        constraint = Xor(self.switch_vars[indices[0]], helper_vars[0]) == val
        self.optimizer.add(simplify(constraint))

    def preconstruct_z3_instance(self) -> None:
        """Preconstruct the z3 instance for the lights-out problem.

        Creates all necessary variables and adds the known parts of the parity
        constraints. Soft constraints are added to the optimizer with default
        weights.
        """
        if self.switch_vars is None:
            self.switch_vars = [Bool(f"switch_{i}") for i in range(len(self.switches_to_lights))]

        for light, switches in self.lights_to_switches.items():
            if light not in self.helper_vars:
                # one helper fewer than there are switches on this light
                self.helper_vars[light] = [Bool(f"helper_{light}_{i}") for i in range(len(switches) - 1)]
            self.preconstruct_parity_constraint(light, switches)

        # Soft constraint per switch: prefer leaving the switch unpressed.
        for switch in self.switch_vars:
            self.optimizer.add_soft(Not(switch))

    def validate_model(self, model: ModelRef, lights: list[bool]) -> bool:
        """Validate the model by checking if pressing the switches turns off all lights.

        The check is performed on a copy, so ``lights`` is left unmodified.

        Args:
            model: Model whose switch assignment is checked.
            lights: Initial state of the lights.

        Returns:
            ``True`` if every light is off after pressing the switches that are
            set in ``model``.
        """
        assert self.switch_vars is not None

        # Work on a copy: validation must not alter the caller's light pattern
        # (this method is invoked from inside an ``assert`` in ``solve``).
        state = list(lights)
        for i, var in enumerate(self.switch_vars):
            if model[var]:
                # flip all lights that are controlled by this switch
                for light in self.switches_to_lights[i]:
                    state[light] = not state[light]

        return all(not light for light in state)

    def count_switches(self, model: ModelRef) -> int:
        """Count the number of switches that are set to true."""
        assert self.switch_vars is not None
        return sum(1 for var in self.switch_vars if model[var])

    def solve(
        self, lights: list[bool], solver_path: str = "z3"
    ) -> tuple[list[int], datetime.timedelta, datetime.timedelta]:
        """Solve the lights-out problem for a given pattern.

        Assumes that the z3 instance has already been pre-constructed.

        Args:
            lights: Value of every light, indexed by light.
            solver_path: ``"z3"`` to solve with the internal optimizer; any
                other value is treated as the path of an external solver
                executable.

        Returns:
            A tuple ``(switches, constr_time, solve_time)``. ``switches`` holds
            a 0/1 entry per switch when solved with z3 and is empty when an
            external solver is used. The two time deltas cover adding the
            pattern-specific constraints and running the solver.

        Raises:
            AssertionError: If z3 finds no solution or the model does not turn
                off all lights.
        """
        # push a new context to the optimizer
        self.optimizer.push()
        try:
            # add the problem specific constraints
            start = datetime.datetime.now()
            for light, val in enumerate(lights):
                self.complete_parity_constraint(light, self.lights_to_switches[light], val)
            constr_time = datetime.datetime.now() - start

            switches: list[int] = []
            if solver_path == "z3":
                # solve the problem
                start = datetime.datetime.now()
                result = self.optimizer.check()
                solve_time = datetime.datetime.now() - start
                assert str(result) == "sat", "No solution found"

                # validate the model
                model = self.optimizer.model()
                assert self.validate_model(model, lights), "Model is invalid"
                assert self.switch_vars is not None
                switches = [1 if model[var] else 0 for var in self.switch_vars]
            else:
                self.optimizer.set("pp.wcnf", True)
                wcnf = str(self.optimizer)
                # Note: This merely calls the solver. It does not interpret the output.
                #       This is just to measure the time it takes to solve the problem.
                with Path("./solver-out_" + solver_path.split("/")[-1] + ".txt").open(
                    "a+", encoding=locale.getpreferredencoding(False)
                ) as out:
                    start = datetime.datetime.now()
                    subprocess.run([solver_path, wcnf], stdout=out, check=False)  # noqa: S603
                    solve_time = datetime.datetime.now() - start
        finally:
            # Always pop the context, even on failure, so that the constraints
            # of this pattern do not leak into subsequent calls.
            self.optimizer.pop()

        return switches, constr_time, solve_time
def simulate_error_rate(code: ColorCode, error_rate: float, nr_sims: int, solver_path: str = "z3") -> dict[str, Any]:
    """Simulate the logical error rate for a given distance and error rate.

    For each of ``nr_sims`` rounds a random error is sampled (each of the
    ``code.n`` entries is 1 with probability ``error_rate``), its syndrome is
    decoded with the lights-out solver, and the residual is checked against the
    logicals in ``code.L``.

    Args:
        code: The code to simulate.
        error_rate: Probability of an error per sampled entry.
        nr_sims: Number of simulation rounds.
        solver_path: Forwarded to ``LightsOut.solve``.

    Returns:
        A dictionary with the keys ``lattice``, ``distance``, ``p``,
        ``logical_error_rates``, ``logical_error_rate_ebs``,
        ``preconstr_time``, ``avg_constr_time``, ``avg_solve_time``,
        ``avg_total_time`` and ``min_wts_logical_err``. All times are given in
        microseconds.
    """
    problem = LightsOut(code.faces_to_qubits, code.qubits_to_faces)

    start = datetime.datetime.now()
    problem.preconstruct_z3_instance()
    preconstr_time = datetime.datetime.now() - start

    # ``timedelta.microseconds`` is only the sub-second component (0..999999),
    # so durations of a second or more would be reported wrapped. Dividing by a
    # one-microsecond delta yields the total duration in microseconds.
    one_microsecond = datetime.timedelta(microseconds=1)

    # -1 marks "no logical error observed yet" for the respective logical
    min_wt_logicals: npt.NDArray[np.int_] = np.full(len(code.L), -1).astype(int)
    logical_errors: npt.NDArray[np.int_] = np.zeros(len(code.L)).astype(int)
    avg_constr_time = 0.0
    avg_solve_time = 0.0
    rng = np.random.default_rng()
    for i in range(nr_sims):
        # sample error
        error = rng.choice([0, 1], size=code.n, p=[1 - error_rate, error_rate])

        # get syndrome
        syndrome = code.get_syndrome(error)
        lights = [bool(b) for b in syndrome]

        # compute estimate
        estimate, constr_time, solve_time = problem.solve(lights, solver_path=solver_path)

        # the estimate is empty when the solver output is not interpreted
        if estimate:
            # check if the estimate is correct
            residual = (error + np.array(estimate)) % 2
            for logical in range(len(code.L)):
                if (code.L[logical] @ residual % 2).any():
                    logical_errors[logical] += 1
                    wt = np.sum(residual)
                    # compute the min weight of a logical error
                    if min_wt_logicals[logical] == -1 or wt < min_wt_logicals[logical]:
                        min_wt_logicals[logical] = int(wt)
                    # a sample is attributed to the first violated logical only
                    break

        # compute rolling average of the times
        avg_constr_time = (avg_constr_time * i + constr_time / one_microsecond) / (i + 1)
        avg_solve_time = (avg_solve_time * i + solve_time / one_microsecond) / (i + 1)

    logical_error_rates: list[float] = [nr_errors / nr_sims for nr_errors in logical_errors]
    # binomial standard error of each estimated rate
    logical_error_rate_ebs: list[float] = [np.sqrt((1 - ler) * ler / nr_sims) for ler in logical_error_rates]
    avg_total_time = avg_constr_time + avg_solve_time

    return {
        "lattice": code.lattice_type,
        "distance": code.distance,
        "p": error_rate,
        "logical_error_rates": logical_error_rates,
        "logical_error_rate_ebs": logical_error_rate_ebs,
        "preconstr_time": preconstr_time // one_microsecond,
        "avg_constr_time": avg_constr_time,
        "avg_solve_time": avg_solve_time,
        "avg_total_time": avg_total_time,
        "min_wts_logical_err": min_wt_logicals.tolist(),
    }
def run(
    lattice_type: str,
    distance: int,
    error_rate: float,
    nr_sims: int = 10000,
    results_dir: str = "./results_maxsat",
    solver: str = "z3",
) -> None:
    """Run the decoding simulation for a given distance and error rate.

    Builds the code from ``lattice_type`` and ``distance``, simulates
    ``nr_sims`` rounds at ``error_rate`` and stores the resulting statistics
    as a JSON file inside ``results_dir`` (created if it does not exist).
    The file name encodes the lattice type, the distance, the error rate
    rounded to four digits and the last path component of ``solver``.
    """
    color_code = code_from_string(lattice_type, distance)
    results = simulate_error_rate(color_code, error_rate, nr_sims, solver)

    # keep only the part of the solver path after the last "/"
    solver_name = solver.rsplit("/", 1)[-1]
    out_name = (
        f"./code={color_code.lattice_type},distance={color_code.distance},"
        f"p={round(error_rate, 4)},solver={solver_name}.json"
    )

    out_dir = Path(results_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    (out_dir / out_name).write_text(json.dumps(results))