Source code for pennylane.transforms.dynamic_one_shot

# Copyright 2018-2024 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Contains the batch dimension transform.
"""

import itertools
from collections import Counter
from collections.abc import Sequence
from functools import partial, singledispatch

import numpy as np

import pennylane as qml
from pennylane.exceptions import QuantumFunctionError
from pennylane.measurements import (
    CountsMP,
    ExpectationMP,
    MeasurementProcess,
    MeasurementValue,
    MidMeasureMP,
    ProbabilityMP,
    SampleMP,
    Shots,
    VarianceMP,
)
from pennylane.tape import QuantumScript, QuantumScriptBatch
from pennylane.typing import PostprocessingFn, Result, ResultBatch, TensorLike

from .core import transform

fill_in_value = np.iinfo(np.int32).min


def is_mcm(operation):
    """Returns True if the operation is a mid-circuit measurement and False otherwise."""
    mcm = isinstance(operation, MidMeasureMP)
    return mcm or "MidCircuitMeasure" in str(type(operation))


def null_postprocessing(results):
    """A postprocessing function returned by a transform that only converts the batch of results
    into a result for a single ``QuantumTape``.
    """
    return results[0]


# pylint: disable=unused-argument
def _expand_fn(
    tape: QuantumScript, postselect_mode=None, **_
) -> tuple[QuantumScriptBatch, PostprocessingFn]:
    if not any(is_mcm(o) for o in tape.operations):
        return (tape,), null_postprocessing
    samples_present = any(isinstance(mp, SampleMP) for mp in tape.measurements)
    postselect_present = any(op.postselect is not None for op in tape.operations if is_mcm(op))
    if postselect_present and samples_present and tape.batch_size is not None:
        raise ValueError(
            "Returning qml.sample is not supported when postselecting mid-circuit "
            "measurements with broadcasting"
        )

    return qml.transforms.broadcast_expand(tape)


def _add_shot_vector_support(fn: PostprocessingFn, shots: Shots) -> PostprocessingFn:
    def new_fn(results: ResultBatch) -> Result:
        results = results[0]
        return tuple(fn((results[slice(*sl)],)) for sl in shots.bins())

    return new_fn


def _squeeze_stack(array):
    return qml.math.squeeze(qml.math.vstack(array))


[docs] @partial(transform, expand_transform=_expand_fn) def dynamic_one_shot( tape: QuantumScript, postselect_mode=None, **_ ) -> tuple[QuantumScriptBatch, PostprocessingFn]: """Transform a QNode to into several one-shot tapes to support dynamic circuit execution. This transform enables the ``"one-shot"`` mid-circuit measurement method. The ``"one-shot"`` method prompts the device to perform a series of one-shot executions, where in each execution, the ``qml.measure`` operation applies a probabilistic mid-circuit measurement to the circuit. This is in contrast with ``qml.defer_measurement``, which instead introduces an extra wire for each mid-circuit measurement. The ``"one-shot"`` method is favourable in the few-shots and several-mid-circuit-measurements limit, whereas ``qml.defer_measurements`` is favourable in the opposite limit. Args: tape (QNode or QuantumScript or Callable): a quantum circuit. Returns: qnode (QNode) or quantum function (Callable) or tuple[List[QuantumScript], function]: The transformed circuit as described in :func:`qml.transform <pennylane.transform>`. This circuit will provide the results of a dynamic execution. **Example** Most devices that support mid-circuit measurements will include this transform in its preprocessing automatically when applicable. When this is the case, any user-applied ``dynamic_one_shot`` transforms will be ignored. The recommended way to use dynamic one shot is to specify ``mcm_method="one-shot"`` in the ``qml.qnode`` decorator. .. code-block:: python dev = qml.device("default.qubit") params = np.pi / 4 * np.ones(2) @partial(qml.set_shots, shots=100) @qml.qnode(dev, mcm_method="one-shot", postselect_mode="fill-shots") def func(x, y): qml.RX(x, wires=0) m0 = qml.measure(0) qml.cond(m0, qml.RY)(y, wires=1) return qml.expval(op=m0) """ if not any(is_mcm(o) for o in tape.operations): return (tape,), null_postprocessing for m in tape.measurements: if not isinstance(m, (CountsMP, ExpectationMP, ProbabilityMP, SampleMP, VarianceMP)): raise TypeError( f"Native mid-circuit measurement mode does not support {type(m).__name__} " "measurements." ) if not tape.shots: raise QuantumFunctionError("dynamic_one_shot is only supported with finite shots.") aux_tapes = [init_auxiliary_tape(tape)] def processing_fn(results): results = results[0] if len(aux_tapes[0].measurements) == 1: results = [_squeeze_stack(tuple(results))] else: results = [ _squeeze_stack(tuple(res[i] for res in results)) for i, _ in enumerate(aux_tapes[0].measurements) ] return parse_native_mid_circuit_measurements( tape, results=results, postselect_mode=postselect_mode ) if tape.shots.has_partitioned_shots: processing_fn = _add_shot_vector_support(processing_fn, tape.shots) return aux_tapes, processing_fn
def get_legacy_capabilities(dev): """Gets the capabilities dictionary of a device.""" assert isinstance(dev, qml.devices.LegacyDeviceFacade) return dev.target_device.capabilities() def _supports_one_shot(dev: "qml.devices.Device"): """Checks whether a device supports one-shot.""" if isinstance(dev, qml.devices.LegacyDevice): return get_legacy_capabilities(dev).get("supports_mid_measure", False) return dev.name in ("default.qubit", "lightning.qubit") or ( dev.capabilities is not None and "one-shot" in dev.capabilities.supported_mcm_methods ) @dynamic_one_shot.custom_qnode_transform def _dynamic_one_shot_qnode(self, qnode, targs, tkwargs): """Custom qnode transform for ``dynamic_one_shot``.""" if tkwargs.get("device", None): raise ValueError( "Cannot provide a 'device' value directly to the dynamic_one_shot decorator " "when transforming a QNode." ) if qnode.device is not None: if not _supports_one_shot(qnode.device): raise TypeError( f"Device {qnode.device.name} does not support mid-circuit measurements and/or " "one-shot execution mode natively, and hence it does not support the " "dynamic_one_shot transform. 'default.qubit' and 'lightning.qubit' currently " "support mid-circuit measurements and the dynamic_one_shot transform." ) tkwargs.setdefault("device", qnode.device) return self.default_qnode_transform(qnode, targs, tkwargs) def init_auxiliary_tape(circuit: qml.tape.QuantumScript): """Creates an auxiliary circuit to perform one-shot mid-circuit measurement calculations. Measurements are replaced by SampleMP measurements on wires and observables found in the original measurements. Args: circuit (QuantumTape): The original QuantumScript Returns: QuantumScript: A copy of the circuit with modified measurements """ new_measurements = [] for m in circuit.measurements: if m.mv is None: if isinstance(m, VarianceMP): new_measurements.append(SampleMP(obs=m.obs)) else: new_measurements.append(m) for op in circuit.operations: if "MidCircuitMeasure" in str(type(op)): # pragma: no cover new_measurements.append(qml.sample(op.out_classical_tracers[0])) elif isinstance(op, MidMeasureMP): new_measurements.append(qml.sample(MeasurementValue([op], lambda res: res))) return qml.tape.QuantumScript( circuit.operations, new_measurements, shots=[1] * circuit.shots.total_shots, trainable_params=circuit.trainable_params, ) def _measurement_with_no_shots(measurement): return ( np.nan * np.ones_like(measurement.eigvals()) if isinstance(measurement, ProbabilityMP) else np.nan ) def _get_is_valid_has_valid(mcm_samples, all_mcms, interface): # Can't use boolean dtype array with tf, hence why conditionally setting items to 0 or 1 has_postselect = qml.math.array( [[op.postselect is not None for op in all_mcms]], like=interface, dtype=mcm_samples.dtype, ) postselect = qml.math.array( [[0 if op.postselect is None else op.postselect for op in all_mcms]], like=interface, dtype=mcm_samples.dtype, ) is_valid = qml.math.all(mcm_samples * has_postselect == postselect, axis=1) has_valid = qml.math.any(is_valid) return is_valid, has_valid # pylint: disable=unused-argument def parse_native_mid_circuit_measurements( circuit: qml.tape.QuantumScript, _removed_arg=None, # need to not break catalyst results: None | TensorLike = None, postselect_mode=None, ): """Combines, gathers and normalizes the results of native mid-circuit measurement runs. Args: circuit (QuantumTape): The original ``QuantumScript``. _removed_arg : a placeholder for an argument that used to exist. Can be removed pending update to catalyst. aux_tapes (List[QuantumTape]): List of auxiliary ``QuantumScript`` objects. results (TensorLike): Array of measurement results. postselect_mode (None | str): how to handle postselection. Returns: tuple(TensorLike): The results of the simulation. """ assert results is not None # condition needed to not break signature interface = qml.math.get_deep_interface(results) interface = "numpy" if interface == "builtins" else interface interface = "tensorflow" if interface == "tf" else interface all_mcms = [op for op in circuit.operations if is_mcm(op)] mcm_samples = qml.math.hstack( tuple(qml.math.reshape(res, (-1, 1)) for res in results[-len(all_mcms) :]) ) mcm_samples = qml.math.array(mcm_samples, like=interface) is_valid, has_valid = _get_is_valid_has_valid(mcm_samples, all_mcms, interface) mcm_samples_map = {mcm: mcm_samples[:, i : i + 1] for i, mcm in enumerate(all_mcms)} normalized_meas, m_count = [], 0 handler = _handle_measurement_qjit if qml.compiler.active() else _handle_measurement for m in circuit.measurements: if not isinstance(m, (CountsMP, ExpectationMP, ProbabilityMP, SampleMP, VarianceMP)): raise TypeError( f"Native mid-circuit measurement mode does not support {type(m).__name__} measurements." ) r, m_count = handler( m, m_count, results, mcm_samples_map, interface=interface, has_valid=has_valid, postselect_mode=postselect_mode, is_valid=is_valid, ) if isinstance(m, SampleMP): r = qml.math.squeeze(r) normalized_meas.append(r) return tuple(normalized_meas) if len(normalized_meas) > 1 else normalized_meas[0] # pylint: disable=too-many-arguments def _handle_measurement_qjit( m: MeasurementProcess, m_count: int, results, mcm_samples, *, is_valid: bool, postselect_mode, **_, ): if m.mv is not None: return ( gather_mcm_qjit(m, mcm_samples, is_valid, postselect_mode=postselect_mode), m_count, ) # pragma: no cover result = results[m_count] if isinstance(m, CountsMP): res = ( result[0][0], qml.math.sum(result[1] * qml.math.reshape(is_valid, (-1, 1)), axis=0), ) return res, m_count + 1 return gather_non_mcm(m, result, is_valid, postselect_mode=postselect_mode), m_count + 1 # pylint: disable=too-many-arguments def _handle_measurement( m: MeasurementProcess, m_count: int, results, mcm_samples, *, interface, has_valid: bool, postselect_mode, is_valid, ): if interface != "jax" and not has_valid: return _measurement_with_no_shots(m), m_count + int(m.mv is None) if m.mv is not None: return gather_mcm(m, mcm_samples, is_valid, postselect_mode=postselect_mode), m_count result = results[m_count] if not isinstance(m, CountsMP): # We don't need to cast to arrays when using qml.counts. qml.math.array is not viable # as it assumes all elements of the input are of builtin python types and not belonging # to any particular interface result = qml.math.array(result, like=interface) return gather_non_mcm(m, result, is_valid, postselect_mode=postselect_mode), m_count + 1 def gather_mcm_qjit(measurement, samples, is_valid, postselect_mode=None): # pragma: no cover """Process MCM measurements when the Catalyst compiler is active. Args: measurement (MeasurementProcess): measurement samples (dict): Mid-circuit measurement samples is_valid (TensorLike): Boolean array with the same shape as ``samples`` where the value at each index specifies whether or not the respective sample is valid. Returns: TensorLike: The combined measurement outcome """ found, meas = False, None for k, meas in samples.items(): if measurement.mv is k.out_classical_tracers[0]: found = True break if not found: raise LookupError("MCM not found") meas = qml.math.squeeze(meas) if isinstance(measurement, (CountsMP, ProbabilityMP)): interface = qml.math.get_interface(is_valid) sum_valid = qml.math.sum(is_valid) count_1 = qml.math.sum(meas * is_valid) if isinstance(measurement, CountsMP): return qml.math.array([0, 1], like=interface), qml.math.array( [sum_valid - count_1, count_1], like=interface ) if isinstance(measurement, ProbabilityMP): counts = qml.math.array([sum_valid - count_1, count_1], like=interface) return counts / sum_valid return gather_non_mcm(measurement, meas, is_valid, postselect_mode=postselect_mode) # pylint: disable=unused-argument @singledispatch def gather_non_mcm(measurement, samples, is_valid, postselect_mode=None) -> TensorLike: """Combines, gathers and normalizes several measurements with trivial measurement values. Args: measurement (MeasurementProcess): measurement samples (TensorLike): Post-processed measurement samples is_valid (TensorLike): Boolean array with the same shape as ``samples`` where the value at each index specifies whether or not the respective sample is valid. postselect_mode (None | str): the postselect mode to use. Returns: TensorLike: The combined measurement outcome """ raise TypeError( f"Native mid-circuit measurement mode does not support {type(measurement).__name__} measurements." ) # pylint: disable=unused-argument @gather_non_mcm.register def _gather_counts(measurement: CountsMP, samples, is_valid, postselect_mode=None): tmp = Counter() if measurement.all_outcomes: if isinstance(measurement.mv, Sequence): values = [list(m.branches.values()) for m in measurement.mv] values = list(itertools.product(*values)) tmp = Counter({"".join(map(str, v)): 0 for v in values}) else: values = [list(measurement.mv.branches.values())] values = list(itertools.product(*values)) tmp = Counter({float(*v): 0 for v in values}) for i, d in enumerate(samples): tmp.update({k if isinstance(k, str) else float(k): v * is_valid[i] for k, v in d.items()}) if not measurement.all_outcomes: tmp = Counter({k: v for k, v in tmp.items() if v > 0}) return dict(sorted(tmp.items())) # pylint: disable=unused-argument @gather_non_mcm.register def _gather_samples(measurement: SampleMP, samples, is_valid, postselect_mode=None): if postselect_mode == "pad-invalid-samples" and samples.ndim == 2: is_valid = qml.math.reshape(is_valid, (-1, 1)) if postselect_mode == "pad-invalid-samples": return qml.math.where(is_valid, samples, fill_in_value) if qml.math.shape(samples) == (): # single shot case samples = qml.math.reshape(samples, (-1, 1)) return samples[is_valid] # pylint: disable=unused-arguement @gather_non_mcm.register def _gather_expval(measurement: ExpectationMP, samples, is_valid, postselect_mode=None): if qml.math.get_interface(is_valid) == "tensorflow": # Tensorflow requires arrays that are used for arithmetic with each other to have the # same dtype. We don't cast if measuring samples as float tf.Tensors cannot be used to # index other tf.Tensors (is_valid is used to index valid samples). is_valid = qml.math.cast_like(is_valid, samples) return qml.math.sum(samples * is_valid) / qml.math.sum(is_valid) # pylint: disable=unused-arguement @gather_non_mcm.register def _gather_probability(measurement: ProbabilityMP, samples, is_valid, postselect_mode=None): if qml.math.get_interface(is_valid) == "tensorflow": # Tensorflow requires arrays that are used for arithmetic with each other to have the # same dtype. We don't cast if measuring samples as float tf.Tensors cannot be used to # index other tf.Tensors (is_valid is used to index valid samples). is_valid = qml.math.cast_like(is_valid, samples) return qml.math.sum(samples * qml.math.reshape(is_valid, (-1, 1)), axis=0) / qml.math.sum( is_valid ) # pylint: disable=unused-argument @gather_non_mcm.register def _gather_variance(measurement: VarianceMP, samples, is_valid, postselect_mode=None): if (interface := qml.math.get_interface(is_valid)) == "tensorflow": # Tensorflow requires arrays that are used for arithmetic with each other to have the # same dtype. We don't cast if measuring samples as float tf.Tensors cannot be used to # index other tf.Tensors (is_valid is used to index valid samples). is_valid = qml.math.cast_like(is_valid, samples) expval = qml.math.sum(samples * is_valid) / qml.math.sum(is_valid) if interface == "tensorflow": # Casting needed for tensorflow samples = qml.math.cast_like(samples, expval) is_valid = qml.math.cast_like(is_valid, expval) return qml.math.sum((samples - expval) ** 2 * is_valid) / qml.math.sum(is_valid) def gather_mcm(measurement: MeasurementProcess, samples, is_valid, postselect_mode=None): """Combines, gathers and normalizes several measurements with non-trivial measurement values. Args: measurement (MeasurementProcess): measurement samples (List[dict]): Mid-circuit measurement samples is_valid (TensorLike): Boolean array with the same shape as ``samples`` where the value at each index specifies whether or not the respective sample is valid. Returns: TensorLike: The combined measurement outcome """ interface = qml.math.get_deep_interface(is_valid) mv = measurement.mv # The following block handles measurement value lists, like ``qml.counts(op=[mcm0, mcm1, mcm2])``. if isinstance(measurement, (CountsMP, ProbabilityMP, SampleMP)) and isinstance(mv, Sequence): mcm_samples = [m.concretize(samples) for m in mv] mcm_samples = qml.math.concatenate(mcm_samples, axis=1) if isinstance(measurement, ProbabilityMP): values = [list(m.branches.values()) for m in mv] values = list(itertools.product(*values)) values = [qml.math.array([v], like=interface, dtype=mcm_samples.dtype) for v in values] # Need to use boolean functions explicitly as Tensorflow does not allow integer math # on boolean arrays counts = [ qml.math.count_nonzero( qml.math.logical_and(qml.math.all(mcm_samples == v, axis=1), is_valid) ) for v in values ] counts = qml.math.array(counts, like=interface) return counts / qml.math.sum(counts) if isinstance(measurement, CountsMP): mcm_samples = [{"".join(str(int(v)) for v in tuple(s)): 1} for s in mcm_samples] return gather_non_mcm(measurement, mcm_samples, is_valid, postselect_mode=postselect_mode) mcm_samples = qml.math.ravel(qml.math.array(mv.concretize(samples), like=interface)) if isinstance(measurement, ProbabilityMP): # Need to use boolean functions explicitly as Tensorflow does not allow integer math # on boolean arrays counts = [ qml.math.count_nonzero(qml.math.logical_and((mcm_samples == v), is_valid)) for v in list(mv.branches.values()) ] counts = qml.math.array(counts, like=interface) return counts / qml.math.sum(counts) if isinstance(measurement, CountsMP): mcm_samples = [{float(s): 1} for s in mcm_samples] return gather_non_mcm(measurement, mcm_samples, is_valid, postselect_mode=postselect_mode)