From 11104cc35ac26ab4849b707f5d756d9d46cb1907 Mon Sep 17 00:00:00 2001 From: Renyi Chen Date: Tue, 25 Feb 2025 23:25:08 -0800 Subject: [PATCH 01/19] Support merge 1-qubit gates for parameterized circuits Note: if users expect update sweeps together with merging, it's not supported yet. It's a todo item to be supported in the followup PRs. --- cirq-core/cirq/__init__.py | 1 + cirq-core/cirq/transformers/__init__.py | 1 + .../transformers/merge_single_qubit_gates.py | 88 +++++++++++++++++++ .../merge_single_qubit_gates_test.py | 71 +++++++++++++++ 4 files changed, 161 insertions(+) diff --git a/cirq-core/cirq/__init__.py b/cirq-core/cirq/__init__.py index 9c34025663c..0fd226691bd 100644 --- a/cirq-core/cirq/__init__.py +++ b/cirq-core/cirq/__init__.py @@ -378,6 +378,7 @@ merge_single_qubit_gates_to_phased_x_and_z as merge_single_qubit_gates_to_phased_x_and_z, merge_single_qubit_gates_to_phxz as merge_single_qubit_gates_to_phxz, merge_single_qubit_moments_to_phxz as merge_single_qubit_moments_to_phxz, + merge_into_symbolized_phxz as merge_into_symbolized_phxz, optimize_for_target_gateset as optimize_for_target_gateset, parameterized_2q_op_to_sqrt_iswap_operations as parameterized_2q_op_to_sqrt_iswap_operations, prepare_two_qubit_state_using_cz as prepare_two_qubit_state_using_cz, diff --git a/cirq-core/cirq/transformers/__init__.py b/cirq-core/cirq/transformers/__init__.py index 0c32780bc17..7cad239d350 100644 --- a/cirq-core/cirq/transformers/__init__.py +++ b/cirq-core/cirq/transformers/__init__.py @@ -101,6 +101,7 @@ merge_single_qubit_gates_to_phased_x_and_z as merge_single_qubit_gates_to_phased_x_and_z, merge_single_qubit_gates_to_phxz as merge_single_qubit_gates_to_phxz, merge_single_qubit_moments_to_phxz as merge_single_qubit_moments_to_phxz, + merge_into_symbolized_phxz as merge_into_symbolized_phxz, ) from cirq.transformers.qubit_management_transformers import ( diff --git a/cirq-core/cirq/transformers/merge_single_qubit_gates.py b/cirq-core/cirq/transformers/merge_single_qubit_gates.py index c48e73fae8e..afe9fd5ff55 100644 --- a/cirq-core/cirq/transformers/merge_single_qubit_gates.py +++ b/cirq-core/cirq/transformers/merge_single_qubit_gates.py @@ -14,10 +14,14 @@ """Transformer passes to combine adjacent single-qubit rotations.""" +import enum +import warnings from typing import Optional, TYPE_CHECKING + from cirq import circuits, ops, protocols from cirq.transformers import merge_k_qubit_gates, transformer_api, transformer_primitives +from cirq.study import sweepable from cirq.transformers.analytical_decompositions import single_qubit_decompositions if TYPE_CHECKING: @@ -152,3 +156,87 @@ def merge_func(m1: 'cirq.Moment', m2: 'cirq.Moment') -> Optional['cirq.Moment']: deep=context.deep if context else False, tags_to_ignore=tuple(tags_to_ignore), ).unfreeze(copy=False) + + +@transformer_api.transformer +def merge_into_symbolized_phxz( + circuit: 'cirq.AbstractCircuit', + *, + context: Optional['cirq.TransformerContext'] = None, + sweeps: Optional['sweepable.Sweepable'] = None, + atol: float = 1e-8, +) -> 'cirq.Circuit': + """Merge consecutive single qubit gates into connected symbolized PhasedXZ gates. + + Specifically, if at least one of the consecutive gates is symbolized, then the merged gate + will be a symbolized gate. + + e.g., X-Y-H-phxz(sa, sx, sz) ---transform---> phxz(sa, sx, sz) + + Note, we only consider merging non-parameterized gates to symbolized phxz with + 3 degrees of freedom, meaning that gates like Z^exp_symbol will be considered non-mergable. + + Args: + circuit: Input circuit to transform. It will not be modified. + sweeps: Sweeps of the symbols in the input circuit, updated Sweeps will be returned + based on the transformation. + context: `cirq.TransformerContext` storing common configurable options for transformers. + atol: Absolute tolerance to angle error. Larger values allow more negligible gates to be + dropped, smaller values increase accuracy. + + Returns: + Copy of the transformed input circuit. + """ + + # TODO(#6994): support returning update sweeps when sweeps are provided. + if sweeps is not None: + raise NotImplementedError("To be supported in #6994.") + + if not protocols.is_parameterized(circuit): + warnings.warn( + "Expect parameterized circuits. " + "Please use cirq.merge_single_qubit_gates_to_phxz instead.", + UserWarning, + ) + return merge_single_qubit_gates_to_phxz(circuit, context=context, atol=atol) + + # Merge all non parameterized single qubit gates first. + circuit = merge_single_qubit_gates_to_phxz(circuit, context=context, atol=atol) + + def _merge_func(op1: 'cirq.Operation', op2: 'cirq.Operation'): + + class _MergeGateType(enum.Enum): + MERAGABLE_NON_PARAMETERIZED = 0 + MERAGABLE_PARAMETERIZED_PHXZ = 1 + NON_MERGEABLE = 2 + + def _categorize(op: 'cirq.Operation') -> _MergeGateType: + if protocols.has_unitary(op) and protocols.num_qubits(op) == 1: + return _MergeGateType.MERAGABLE_NON_PARAMETERIZED + if isinstance(op.gate, ops.PhasedXZGate) and protocols.is_parameterized(op): + return _MergeGateType.MERAGABLE_PARAMETERIZED_PHXZ + return _MergeGateType.NON_MERGEABLE + + merge_type1 = _categorize(op1) + merge_type2 = _categorize(op2) + + if ( + merge_type1 == _MergeGateType.NON_MERGEABLE + or merge_type2 == _MergeGateType.NON_MERGEABLE + ): + return None + + # absorb the non-parameterized gate into the parameterized gate. + if merge_type1 == _MergeGateType.MERAGABLE_PARAMETERIZED_PHXZ: + return op1 + if merge_type2 == _MergeGateType.MERAGABLE_PARAMETERIZED_PHXZ: + return op2 + + return None # pragma: no cover + + return transformer_primitives.merge_operations( + circuit, + _merge_func, + deep=context.deep if context else False, + tags_to_ignore=context.tags_to_ignore if context else (), + ).unfreeze() diff --git a/cirq-core/cirq/transformers/merge_single_qubit_gates_test.py b/cirq-core/cirq/transformers/merge_single_qubit_gates_test.py index 8ea1fd3d273..8ded06baa87 100644 --- a/cirq-core/cirq/transformers/merge_single_qubit_gates_test.py +++ b/cirq-core/cirq/transformers/merge_single_qubit_gates_test.py @@ -14,7 +14,10 @@ from typing import List +import pytest +import sympy import cirq +from cirq.study.sweeps import Points def assert_optimizes(optimized: cirq.AbstractCircuit, expected: cirq.AbstractCircuit): @@ -231,3 +234,71 @@ def test_merge_single_qubit_moments_to_phased_x_and_z_global_phase(): c = cirq.Circuit(cirq.GlobalPhaseGate(1j).on()) c2 = cirq.merge_single_qubit_gates_to_phased_x_and_z(c) assert c == c2 + + +def test_merge_into_symbolized_phxz(): + """Test case diagram. + Input circuit: + 0: ───X───────@───H[ignore]───H───X───PhXZ(a=a1,x=x1,z=z1)───X───PhXZ(a=a2,x=x2,z=z2)───H─── + │ ║ + 1: ───Y^0.5───@───M─────────────────────────────────────────────────────────────────────╫─── + ║ ║ + m: ═══════════════@═════════════════════════════════════════════════════════════════════^═══ + Expected output: + 0: ───PhXZ(a=-1,x=1,z=0)──────@───H[ignore]───PhXZ(a=a1,x=x1,z=z1)───H─── + │ ║ + 1: ───PhXZ(a=0.5,x=0.5,z=0)───@───M──────────────────────────────────╫─── + ║ ║ + m: ═══════════════════════════════@══════════════════════════════════^═══ + """ + a, b = cirq.LineQubit.range(2) + sa1, sa2 = [sympy.Symbol(a) for a in ["a1", "a2"]] + sx1, sx2 = [sympy.Symbol(x) for x in ["x1", "x2"]] + sz1, sz2 = [sympy.Symbol(z) for z in ["z1", "z2"]] + input_circuit = cirq.Circuit( + cirq.X(a), + cirq.Y(b) ** 0.5, + cirq.CZ(a, b), + cirq.H(a).with_tags("ignore"), + cirq.H(a), + cirq.X(a), + _phxz(sa1, sx1, sz1).on(a), + cirq.X(a), + _phxz(sa2, sx2, sz2).on(a), + cirq.measure(b, key="m"), + cirq.H(a).with_classical_controls("m"), + ) + context = cirq.TransformerContext(tags_to_ignore=["ignore"]) + assert_optimizes( + optimized=cirq.merge_into_symbolized_phxz(input_circuit, context=context), + expected=cirq.Circuit( + _phxz(-1, 1, 0).on(a), + _phxz(0.5, 0.5, 0).on(b), + cirq.CZ(a, b), + cirq.H(a).with_tags("ignore"), + _phxz(sa1, sx1, sz1).on(a), + cirq.measure(b, key="m"), + cirq.H(a).with_classical_controls("m"), + ), + ) + + +def test_merge_into_symbolized_phxz_other_symbolized_gates(): + a = cirq.NamedQubit('a') + input_circuit = cirq.Circuit(_phxz(1, 1, 1).on(a), cirq.H(a) ** sympy.Symbol("exp")) + assert_optimizes( + optimized=cirq.merge_into_symbolized_phxz(input_circuit), expected=input_circuit + ) + + +def test_merge_into_symbolized_phxz_non_symbolized_input(): + a = cirq.NamedQubit('a') + with pytest.warns(UserWarning): + cirq.merge_into_symbolized_phxz(cirq.Circuit(cirq.H(a), cirq.H(a))) + + +def test_merge_into_symbolized_phxz_with_sweeps(): + with pytest.raises(NotImplementedError): + cirq.merge_into_symbolized_phxz( + cirq.Circuit(), sweeps=[Points(key="x", points=[0.1, 0.2, 0.5])] + ) From f45f90e931ff4eec281eff739be6dc378616edfe Mon Sep 17 00:00:00 2001 From: Renyi Chen Date: Tue, 1 Apr 2025 22:44:30 -0700 Subject: [PATCH 02/19] Reimplement. --- cirq-core/cirq/__init__.py | 2 +- cirq-core/cirq/transformers/__init__.py | 2 +- .../gauge_compiling/gauge_compiling.py | 9 +- .../transformers/merge_single_qubit_gates.py | 205 +++++++++++++----- .../merge_single_qubit_gates_test.py | 96 ++++---- 5 files changed, 204 insertions(+), 110 deletions(-) diff --git a/cirq-core/cirq/__init__.py b/cirq-core/cirq/__init__.py index 0fd226691bd..12916f66a33 100644 --- a/cirq-core/cirq/__init__.py +++ b/cirq-core/cirq/__init__.py @@ -377,8 +377,8 @@ merge_operations_to_circuit_op as merge_operations_to_circuit_op, merge_single_qubit_gates_to_phased_x_and_z as merge_single_qubit_gates_to_phased_x_and_z, merge_single_qubit_gates_to_phxz as merge_single_qubit_gates_to_phxz, + merge_single_qubit_gates_to_phxz_symbolized as merge_single_qubit_gates_to_phxz_symbolized, merge_single_qubit_moments_to_phxz as merge_single_qubit_moments_to_phxz, - merge_into_symbolized_phxz as merge_into_symbolized_phxz, optimize_for_target_gateset as optimize_for_target_gateset, parameterized_2q_op_to_sqrt_iswap_operations as parameterized_2q_op_to_sqrt_iswap_operations, prepare_two_qubit_state_using_cz as prepare_two_qubit_state_using_cz, diff --git a/cirq-core/cirq/transformers/__init__.py b/cirq-core/cirq/transformers/__init__.py index 7cad239d350..833a914ca85 100644 --- a/cirq-core/cirq/transformers/__init__.py +++ b/cirq-core/cirq/transformers/__init__.py @@ -100,8 +100,8 @@ from cirq.transformers.merge_single_qubit_gates import ( merge_single_qubit_gates_to_phased_x_and_z as merge_single_qubit_gates_to_phased_x_and_z, merge_single_qubit_gates_to_phxz as merge_single_qubit_gates_to_phxz, + merge_single_qubit_gates_to_phxz_symbolized as merge_single_qubit_gates_to_phxz_symbolized, merge_single_qubit_moments_to_phxz as merge_single_qubit_moments_to_phxz, - merge_into_symbolized_phxz as merge_into_symbolized_phxz, ) from cirq.transformers.qubit_management_transformers import ( diff --git a/cirq-core/cirq/transformers/gauge_compiling/gauge_compiling.py b/cirq-core/cirq/transformers/gauge_compiling/gauge_compiling.py index 3ad835904a9..ffee54f281e 100644 --- a/cirq-core/cirq/transformers/gauge_compiling/gauge_compiling.py +++ b/cirq-core/cirq/transformers/gauge_compiling/gauge_compiling.py @@ -25,12 +25,11 @@ import sympy from attrs import field, frozen -from cirq import circuits, ops +from cirq.transformers import transformer_api +from cirq import ops, circuits from cirq.protocols import unitary_protocol from cirq.protocols.has_unitary_protocol import has_unitary -from cirq.study import sweepable -from cirq.study.sweeps import Points, Zip -from cirq.transformers import transformer_api +from cirq.study.sweeps import Points, Sweep, Zip from cirq.transformers.analytical_decompositions import single_qubit_decompositions @@ -256,7 +255,7 @@ def as_sweep( N: int, context: Optional[transformer_api.TransformerContext] = None, prng: Optional[np.random.Generator] = None, - ) -> Tuple[circuits.AbstractCircuit, sweepable.Sweepable]: + ) -> Tuple[circuits.AbstractCircuit, Sweep]: """Generates a parameterized circuit with *N* sets of sweepable parameters. Args: diff --git a/cirq-core/cirq/transformers/merge_single_qubit_gates.py b/cirq-core/cirq/transformers/merge_single_qubit_gates.py index afe9fd5ff55..d153c9f983c 100644 --- a/cirq-core/cirq/transformers/merge_single_qubit_gates.py +++ b/cirq-core/cirq/transformers/merge_single_qubit_gates.py @@ -14,14 +14,15 @@ """Transformer passes to combine adjacent single-qubit rotations.""" -import enum +import itertools import warnings -from typing import Optional, TYPE_CHECKING +from typing import Dict, List, Optional, Tuple, TYPE_CHECKING +import sympy from cirq import circuits, ops, protocols -from cirq.transformers import merge_k_qubit_gates, transformer_api, transformer_primitives -from cirq.study import sweepable +from cirq.study.sweeps import Points, Sweep, Zip +from cirq.transformers import merge_k_qubit_gates, transformer_api, transformer_primitives, align from cirq.transformers.analytical_decompositions import single_qubit_decompositions if TYPE_CHECKING: @@ -158,27 +159,43 @@ def merge_func(m1: 'cirq.Moment', m2: 'cirq.Moment') -> Optional['cirq.Moment']: ).unfreeze(copy=False) +def _values_of_sweep(sweep: Sweep, key: str | sympy.Symbol): + p = sympy.Symbol(key) if isinstance(key, str) else key + return [resolver.value_of(p) for resolver in sweep] + + @transformer_api.transformer -def merge_into_symbolized_phxz( +def merge_single_qubit_gates_to_phxz_symbolized( circuit: 'cirq.AbstractCircuit', *, context: Optional['cirq.TransformerContext'] = None, - sweeps: Optional['sweepable.Sweepable'] = None, + sweep: Sweep, atol: float = 1e-8, -) -> 'cirq.Circuit': - """Merge consecutive single qubit gates into connected symbolized PhasedXZ gates. - - Specifically, if at least one of the consecutive gates is symbolized, then the merged gate - will be a symbolized gate. - - e.g., X-Y-H-phxz(sa, sx, sz) ---transform---> phxz(sa, sx, sz) - - Note, we only consider merging non-parameterized gates to symbolized phxz with - 3 degrees of freedom, meaning that gates like Z^exp_symbol will be considered non-mergable. +) -> Tuple['cirq.Circuit', Sweep]: + """Merge consecutive single qubit gates as PhasedXZ Gates. Symbolize if any of the consecutive gates is symbolized. + + Example: + # pylint: disable=line-too-long + >>> q0, q1 = cirq.LineQubit.range(2) + >>> c = cirq.Circuit(cirq.X(q0),cirq.CZ(q0,q1)**sympy.Symbol("cz_exp"),cirq.Y(q0)**sympy.Symbol("y_exp"),cirq.X(q0)) + >>> print(c) + 0: ───X───@──────────Y^y_exp───X─── + │ + 1: ───────@^cz_exp───────────────── + >>> new_circuit, new_sweep = cirq.merge_single_qubit_gates_to_phxz_symbolized(\ + c, sweep=cirq.Points(key="cz_exp", points=[0, 1]) * cirq.Points(key="y_exp", points=[0, 1])\ + ) + >>> print(new_circuit) + 0: ───PhXZ(a=-1,x=1,z=0)───@──────────PhXZ(a=a0,x=x0,z=z0)─── + │ + 1: ────────────────────────@^cz_exp────────────────────────── + >>> print(new_sweep) + cirq.Points('z0', [0, -1.0, 0, -1.0]) + cirq.Points('x0', [1, 0.0, 1, 0.0]) + cirq.Points('a0', [-1.0, -0.5, -1.0, -0.5]) + cirq.Points('cz_exp', [0, 0, 1, 1]) + # pylint: disable=line-too-long Args: circuit: Input circuit to transform. It will not be modified. - sweeps: Sweeps of the symbols in the input circuit, updated Sweeps will be returned + sweep: Sweep of the symbols in the input circuit, updated Sweep will be returned based on the transformation. context: `cirq.TransformerContext` storing common configurable options for transformers. atol: Absolute tolerance to angle error. Larger values allow more negligible gates to be @@ -187,10 +204,7 @@ def merge_into_symbolized_phxz( Returns: Copy of the transformed input circuit. """ - - # TODO(#6994): support returning update sweeps when sweeps are provided. - if sweeps is not None: - raise NotImplementedError("To be supported in #6994.") + deep = context.deep if context else False if not protocols.is_parameterized(circuit): warnings.warn( @@ -200,43 +214,128 @@ def merge_into_symbolized_phxz( ) return merge_single_qubit_gates_to_phxz(circuit, context=context, atol=atol) - # Merge all non parameterized single qubit gates first. - circuit = merge_single_qubit_gates_to_phxz(circuit, context=context, atol=atol) + # Tag symbolized single qubit op. + symbolized_single_tag = "_symbolized_single" - def _merge_func(op1: 'cirq.Operation', op2: 'cirq.Operation'): + circuit_tagged = transformer_primitives.map_operations( + circuit, + lambda op, _: ( + op.with_tags(symbolized_single_tag) + if protocols.is_parameterized(op) and len(op.qubits) == 1 + else op + ), + deep=deep, + ) - class _MergeGateType(enum.Enum): - MERAGABLE_NON_PARAMETERIZED = 0 - MERAGABLE_PARAMETERIZED_PHXZ = 1 - NON_MERGEABLE = 2 + # Symbols of the single qubit symbolized ops. + single_qubit_gate_symbols: set[sympy.Symbol] = set().union( + *[ + protocols.parameter_symbols(op) if symbolized_single_tag in op.tags else set() + for op in circuit_tagged.all_operations() + ] + ) + # Remaing symbols, e.g., 2 qubit gates' symbols. Sweep of those symbols keeps unchanged. + remaining_symbols = protocols.parameter_symbols(circuit) - single_qubit_gate_symbols - def _categorize(op: 'cirq.Operation') -> _MergeGateType: - if protocols.has_unitary(op) and protocols.num_qubits(op) == 1: - return _MergeGateType.MERAGABLE_NON_PARAMETERIZED - if isinstance(op.gate, ops.PhasedXZGate) and protocols.is_parameterized(op): - return _MergeGateType.MERAGABLE_PARAMETERIZED_PHXZ - return _MergeGateType.NON_MERGEABLE + sweep_of_single: Sweep = Zip( + *[Points(key=k, points=_values_of_sweep(sweep, k)) for k in single_qubit_gate_symbols] + ) - merge_type1 = _categorize(op1) - merge_type2 = _categorize(op2) + # Get all resolved circuits from all sets of resolvers in sweep. + resolved_circuits = [ + protocols.resolve_parameters(circuit_tagged, resolver) for resolver in sweep_of_single + ] + + # Store the number of merges for all set of resolvers, + # it should be the same for all resolved circuits. + merge_counts: list[int] = [] + merged_circuits = [] + phxz_tag_prefix = "_phxz" + tag_iter: itertools.count + + def rewriter(circuit_op: 'cirq.CircuitOperation') -> 'cirq.OP_TREE': + nonlocal tag_iter + tag: Optional[str] = None + u = protocols.unitary(circuit_op) + if protocols.num_qubits(circuit_op) == 0: + return ops.GlobalPhaseGate(u[0, 0]).on() + for op in circuit_op.circuit.all_operations(): + if symbolized_single_tag in op.tags: + # Record parameterizations info via tags. + tag = f"{phxz_tag_prefix}_{next(tag_iter)}" + break + gate = single_qubit_decompositions.single_qubit_matrix_to_phxz(u, atol) or ops.I + op = gate.on(circuit_op.qubits[0]) + if not gate: + return [] + return op.with_tags(tag) if tag else op + + for resolved_circuit in resolved_circuits: + tag_iter = itertools.count(start=0, step=1) + merged_circuits.append( + merge_k_qubit_gates.merge_k_qubit_unitaries( + resolved_circuit, k=1, context=context, rewriter=rewriter + ) + ) + merge_counts.append(next(tag_iter)) - if ( - merge_type1 == _MergeGateType.NON_MERGEABLE - or merge_type2 == _MergeGateType.NON_MERGEABLE - ): - return None + if not all(count == merge_counts[0] for count in merge_counts): + raise RuntimeError("Different resolvers in sweep result different merged strcuture.") + + # Get the output circuit from the first resolved circuits. + merge_tags: set[str] = {f"{phxz_tag_prefix}_{i}" for i in range(merge_counts[0])} + new_symbols: set[str] = set().union( + *[{f"x{i}", f"z{i}", f"a{i}"} for i in range(merge_counts[0])] + ) - # absorb the non-parameterized gate into the parameterized gate. - if merge_type1 == _MergeGateType.MERAGABLE_PARAMETERIZED_PHXZ: - return op1 - if merge_type2 == _MergeGateType.MERAGABLE_PARAMETERIZED_PHXZ: - return op2 + def _map_func(op: 'cirq.Operation', _): + """Maps op with tag `_phxz_i` to a symbolzied `PhasedXZGate(xi,zi,ai)`""" + the_merge_tag = merge_tags.intersection(op.tags) + if len(the_merge_tag) == 0: + return op + if len(the_merge_tag) > 1: + raise RuntimeError("Multiple merge tags found.") + sid = the_merge_tag.pop().split("_")[-1] + phxz_params = { + "x_exponent": sympy.Symbol(f"x{sid}"), + "z_exponent": sympy.Symbol(f"z{sid}"), + "axis_phase_exponent": sympy.Symbol(f"a{sid}"), + } + return ops.PhasedXZGate(**phxz_params).on(*op.qubits) + + output_circuit: 'cirq.Circuit' = align.align_right( + transformer_primitives.map_operations(merged_circuits[0].freeze(), _map_func, deep=deep) + ) - return None # pragma: no cover + values_by_params: Dict[str, List[float]] = { + **{s: [] for s in new_symbols}, # New symbols introduced in merging + **{ + s: _values_of_sweep(sweep, s) for s in remaining_symbols + }, # Existing symbols in ops that are not merged, e.g., symbols in 2 qubit gates. + } + + # Get parameterization for the merged phxz gates. + for merged_circuit in merged_circuits: + for op in merged_circuit.all_operations(): + the_merge_tag = merge_tags.intersection(op.tags) + if len(the_merge_tag) == 0: + continue + if len(the_merge_tag) > 1: + raise RuntimeError("Multiple merge tags found.") + sid = the_merge_tag.pop().split("_")[-1] + x, z, a = 0.0, 0.0, 0.0 # Identity gate's parameters. + if isinstance(op.gate, ops.PhasedXZGate): + x, z, a = op.gate.x_exponent, op.gate.z_exponent, op.gate.axis_phase_exponent + elif op.gate is not ops.I: + raise RuntimeError( + f"Expect the merged gate to be a PhasedXZGate or IdentityGate. But got {op.gate}." + ) + values_by_params[f"x{sid}"].append(x) + values_by_params[f"z{sid}"].append(z) + values_by_params[f"a{sid}"].append(a) + + new_sweep: Sweep = Zip( + *[Points(key=key, points=values) for key, values in values_by_params.items()] + ) - return transformer_primitives.merge_operations( - circuit, - _merge_func, - deep=context.deep if context else False, - tags_to_ignore=context.tags_to_ignore if context else (), - ).unfreeze() + return output_circuit.unfreeze(copy=False), new_sweep diff --git a/cirq-core/cirq/transformers/merge_single_qubit_gates_test.py b/cirq-core/cirq/transformers/merge_single_qubit_gates_test.py index 8ded06baa87..0aa65bfc01c 100644 --- a/cirq-core/cirq/transformers/merge_single_qubit_gates_test.py +++ b/cirq-core/cirq/transformers/merge_single_qubit_gates_test.py @@ -84,7 +84,7 @@ def test_merge_single_qubit_gates_to_phased_x_and_z_deep(): cirq.testing.assert_same_circuits(c_new, c_expected) -def _phxz(a: float, x: float, z: float): +def _phxz(a: float | sympy.Symbol, x: float | sympy.Symbol, z: float | sympy.Symbol): return cirq.PhasedXZGate(axis_phase_exponent=a, x_exponent=x, z_exponent=z) @@ -239,66 +239,62 @@ def test_merge_single_qubit_moments_to_phased_x_and_z_global_phase(): def test_merge_into_symbolized_phxz(): """Test case diagram. Input circuit: - 0: ───X───────@───H[ignore]───H───X───PhXZ(a=a1,x=x1,z=z1)───X───PhXZ(a=a2,x=x2,z=z2)───H─── - │ ║ - 1: ───Y^0.5───@───M─────────────────────────────────────────────────────────────────────╫─── - ║ ║ - m: ═══════════════@═════════════════════════════════════════════════════════════════════^═══ + 0: ───X─────────@──────────H[ignore]───H───X───PhXZ(a=a0,x=x0,z=z0)───X───PhXZ(a=a1,x=x1,z=z1)─── + │ + 1: ───H^h_exp───@^cz_exp───────────────────────────────────────────────────────────────────────── Expected output: - 0: ───PhXZ(a=-1,x=1,z=0)──────@───H[ignore]───PhXZ(a=a1,x=x1,z=z1)───H─── - │ ║ - 1: ───PhXZ(a=0.5,x=0.5,z=0)───@───M──────────────────────────────────╫─── - ║ ║ - m: ═══════════════════════════════@══════════════════════════════════^═══ + 0: ───PhXZ(a=-1,x=1,z=0)─────@──────────H[ignore]───PhXZ(a=a1,x=x1,z=z1)─── + │ + 1: ───PhXZ(a=a0,x=x0,z=z0)───@^cz_exp────────────────────────────────────── """ a, b = cirq.LineQubit.range(2) - sa1, sa2 = [sympy.Symbol(a) for a in ["a1", "a2"]] - sx1, sx2 = [sympy.Symbol(x) for x in ["x1", "x2"]] - sz1, sz2 = [sympy.Symbol(z) for z in ["z1", "z2"]] + sa0, sa1 = [sympy.Symbol(a) for a in ["a0", "a1"]] + sx0, sx1 = [sympy.Symbol(x) for x in ["x0", "x1"]] + sz0, sz1 = [sympy.Symbol(z) for z in ["z0", "z1"]] input_circuit = cirq.Circuit( - cirq.X(a), - cirq.Y(b) ** 0.5, - cirq.CZ(a, b), - cirq.H(a).with_tags("ignore"), - cirq.H(a), - cirq.X(a), - _phxz(sa1, sx1, sz1).on(a), - cirq.X(a), - _phxz(sa2, sx2, sz2).on(a), - cirq.measure(b, key="m"), - cirq.H(a).with_classical_controls("m"), + cirq.Moment(cirq.X(a), cirq.H(b) ** sympy.Symbol("h_exp")), + cirq.Moment(cirq.CZ(a, b) ** sympy.Symbol("cz_exp")), + cirq.Moment(cirq.H(a).with_tags("ignore")), + cirq.Moment(cirq.H(a)), + cirq.Moment(cirq.X(a)), + cirq.Moment(_phxz(sa0, sx0, sz0).on(a)), + cirq.Moment(cirq.X(a)), + cirq.Moment(_phxz(sa1, sx1, sz1).on(a)), ) context = cirq.TransformerContext(tags_to_ignore=["ignore"]) - assert_optimizes( - optimized=cirq.merge_into_symbolized_phxz(input_circuit, context=context), - expected=cirq.Circuit( - _phxz(-1, 1, 0).on(a), - _phxz(0.5, 0.5, 0).on(b), - cirq.CZ(a, b), - cirq.H(a).with_tags("ignore"), - _phxz(sa1, sx1, sz1).on(a), - cirq.measure(b, key="m"), - cirq.H(a).with_classical_controls("m"), - ), + sweep = cirq.Zip( + cirq.Points(key="h_exp", points=[0, 1]), + cirq.Points(key="cz_exp", points=[0, 1]), + cirq.Points(key="a0", points=[0, 1]), + cirq.Points(key="x0", points=[0, 1]), + cirq.Points(key="z0", points=[0, 1]), + cirq.Points(key="a1", points=[0, 1]), + cirq.Points(key="x1", points=[0, 1]), + cirq.Points(key="z1", points=[0, 1]), ) - - -def test_merge_into_symbolized_phxz_other_symbolized_gates(): - a = cirq.NamedQubit('a') - input_circuit = cirq.Circuit(_phxz(1, 1, 1).on(a), cirq.H(a) ** sympy.Symbol("exp")) - assert_optimizes( - optimized=cirq.merge_into_symbolized_phxz(input_circuit), expected=input_circuit + output_circuit, new_sweep = cirq.merge_single_qubit_gates_to_phxz_symbolized( + input_circuit, context=context, sweep=sweep + ) + expected = cirq.Circuit( + cirq.Moment(_phxz(-1, 1, 0).on(a), _phxz(sa0, sx0, sz0).on(b)), + cirq.Moment(cirq.CZ(a, b) ** sympy.Symbol("cz_exp")), + cirq.Moment(cirq.H(a).with_tags("ignore")), + cirq.Moment(_phxz(sa1, sx1, sz1).on(a)), ) + assert_optimizes(output_circuit, expected) + + # Check the unitaries are preserved for each set of sweep paramerization. + for old_resolver, new_resolver in zip(sweep, new_sweep): + cirq.testing.assert_circuits_have_same_unitary_given_final_permutation( + cirq.resolve_parameters(input_circuit, old_resolver), + cirq.resolve_parameters(output_circuit, new_resolver), + {q: q for q in input_circuit.all_qubits()}, + ) def test_merge_into_symbolized_phxz_non_symbolized_input(): a = cirq.NamedQubit('a') with pytest.warns(UserWarning): - cirq.merge_into_symbolized_phxz(cirq.Circuit(cirq.H(a), cirq.H(a))) - - -def test_merge_into_symbolized_phxz_with_sweeps(): - with pytest.raises(NotImplementedError): - cirq.merge_into_symbolized_phxz( - cirq.Circuit(), sweeps=[Points(key="x", points=[0.1, 0.2, 0.5])] + cirq.merge_single_qubit_gates_to_phxz_symbolized( + cirq.Circuit(cirq.H(a), cirq.H(a)), sweep=cirq.Points(key="a", points=[0.1, 0.2, 0.5]) ) From a48d4bba91400bd5c24e30cd34b9b6c22a555a17 Mon Sep 17 00:00:00 2001 From: Renyi Chen Date: Tue, 8 Apr 2025 19:22:49 -0700 Subject: [PATCH 03/19] fix checks. --- .../gauge_compiling/gauge_compiling.py | 4 +- .../transformers/merge_single_qubit_gates.py | 298 +++++++++++------- .../merge_single_qubit_gates_test.py | 111 ++++++- 3 files changed, 289 insertions(+), 124 deletions(-) diff --git a/cirq-core/cirq/transformers/gauge_compiling/gauge_compiling.py b/cirq-core/cirq/transformers/gauge_compiling/gauge_compiling.py index ffee54f281e..f15dc08e6ee 100644 --- a/cirq-core/cirq/transformers/gauge_compiling/gauge_compiling.py +++ b/cirq-core/cirq/transformers/gauge_compiling/gauge_compiling.py @@ -25,11 +25,11 @@ import sympy from attrs import field, frozen -from cirq.transformers import transformer_api -from cirq import ops, circuits +from cirq import circuits, ops from cirq.protocols import unitary_protocol from cirq.protocols.has_unitary_protocol import has_unitary from cirq.study.sweeps import Points, Sweep, Zip +from cirq.transformers import transformer_api from cirq.transformers.analytical_decompositions import single_qubit_decompositions diff --git a/cirq-core/cirq/transformers/merge_single_qubit_gates.py b/cirq-core/cirq/transformers/merge_single_qubit_gates.py index d153c9f983c..78cfd6db091 100644 --- a/cirq-core/cirq/transformers/merge_single_qubit_gates.py +++ b/cirq-core/cirq/transformers/merge_single_qubit_gates.py @@ -15,14 +15,13 @@ """Transformer passes to combine adjacent single-qubit rotations.""" import itertools -import warnings -from typing import Dict, List, Optional, Tuple, TYPE_CHECKING +from typing import Dict, Hashable, List, Optional, Tuple, TYPE_CHECKING import sympy from cirq import circuits, ops, protocols from cirq.study.sweeps import Points, Sweep, Zip -from cirq.transformers import merge_k_qubit_gates, transformer_api, transformer_primitives, align +from cirq.transformers import align, merge_k_qubit_gates, transformer_api, transformer_primitives from cirq.transformers.analytical_decompositions import single_qubit_decompositions if TYPE_CHECKING: @@ -159,115 +158,53 @@ def merge_func(m1: 'cirq.Moment', m2: 'cirq.Moment') -> Optional['cirq.Moment']: ).unfreeze(copy=False) +# ---------------------------------------------------------------------- +# Impl merge_single_qubit_gates_to_phxz_symbolized: Start +# ---------------------------------------------------------------------- + + def _values_of_sweep(sweep: Sweep, key: str | sympy.Symbol): p = sympy.Symbol(key) if isinstance(key, str) else key return [resolver.value_of(p) for resolver in sweep] -@transformer_api.transformer -def merge_single_qubit_gates_to_phxz_symbolized( - circuit: 'cirq.AbstractCircuit', - *, - context: Optional['cirq.TransformerContext'] = None, - sweep: Sweep, - atol: float = 1e-8, -) -> Tuple['cirq.Circuit', Sweep]: - """Merge consecutive single qubit gates as PhasedXZ Gates. Symbolize if any of the consecutive gates is symbolized. - - Example: - # pylint: disable=line-too-long - >>> q0, q1 = cirq.LineQubit.range(2) - >>> c = cirq.Circuit(cirq.X(q0),cirq.CZ(q0,q1)**sympy.Symbol("cz_exp"),cirq.Y(q0)**sympy.Symbol("y_exp"),cirq.X(q0)) - >>> print(c) - 0: ───X───@──────────Y^y_exp───X─── - │ - 1: ───────@^cz_exp───────────────── - >>> new_circuit, new_sweep = cirq.merge_single_qubit_gates_to_phxz_symbolized(\ - c, sweep=cirq.Points(key="cz_exp", points=[0, 1]) * cirq.Points(key="y_exp", points=[0, 1])\ - ) - >>> print(new_circuit) - 0: ───PhXZ(a=-1,x=1,z=0)───@──────────PhXZ(a=a0,x=x0,z=z0)─── - │ - 1: ────────────────────────@^cz_exp────────────────────────── - >>> print(new_sweep) - cirq.Points('z0', [0, -1.0, 0, -1.0]) + cirq.Points('x0', [1, 0.0, 1, 0.0]) + cirq.Points('a0', [-1.0, -0.5, -1.0, -0.5]) + cirq.Points('cz_exp', [0, 0, 1, 1]) - # pylint: disable=line-too-long +def _merge_single_qubit_gates_to_circuit_op_symbolized( + resolved_circuits: List['cirq.AbstractCircuit'], + symbolized_single_tag: str, + context: Optional['cirq.TransformerContext'], + atol: float, +) -> Tuple[List['cirq.Circuit'], frozenset[str], frozenset[str]]: + """Helper function to merge single qubit ops of resolved circuits to ops of CircuitOperation + type using merge_k_qubit_unitaries. Args: - circuit: Input circuit to transform. It will not be modified. - sweep: Sweep of the symbols in the input circuit, updated Sweep will be returned - based on the transformation. - context: `cirq.TransformerContext` storing common configurable options for transformers. - atol: Absolute tolerance to angle error. Larger values allow more negligible gates to be - dropped, smaller values increase accuracy. + resolved_circuits: A list of circuits where symbols have been replaced with concrete values. + symbolized_single_tag: The tag applied to single-qubit operations that originally contained symbols + before parameterizations. Returns: - Copy of the transformed input circuit. + Tuple of merge counts, merged circuits, and merge tags. """ - deep = context.deep if context else False - - if not protocols.is_parameterized(circuit): - warnings.warn( - "Expect parameterized circuits. " - "Please use cirq.merge_single_qubit_gates_to_phxz instead.", - UserWarning, - ) - return merge_single_qubit_gates_to_phxz(circuit, context=context, atol=atol) - - # Tag symbolized single qubit op. - symbolized_single_tag = "_symbolized_single" - - circuit_tagged = transformer_primitives.map_operations( - circuit, - lambda op, _: ( - op.with_tags(symbolized_single_tag) - if protocols.is_parameterized(op) and len(op.qubits) == 1 - else op - ), - deep=deep, - ) - - # Symbols of the single qubit symbolized ops. - single_qubit_gate_symbols: set[sympy.Symbol] = set().union( - *[ - protocols.parameter_symbols(op) if symbolized_single_tag in op.tags else set() - for op in circuit_tagged.all_operations() - ] - ) - # Remaing symbols, e.g., 2 qubit gates' symbols. Sweep of those symbols keeps unchanged. - remaining_symbols = protocols.parameter_symbols(circuit) - single_qubit_gate_symbols - - sweep_of_single: Sweep = Zip( - *[Points(key=k, points=_values_of_sweep(sweep, k)) for k in single_qubit_gate_symbols] - ) - - # Get all resolved circuits from all sets of resolvers in sweep. - resolved_circuits = [ - protocols.resolve_parameters(circuit_tagged, resolver) for resolver in sweep_of_single - ] - - # Store the number of merges for all set of resolvers, - # it should be the same for all resolved circuits. - merge_counts: list[int] = [] - merged_circuits = [] - phxz_tag_prefix = "_phxz" + merge_counts: list[int] = [] # number of merges per resolved_circuit + merged_circuits: list['cirq.Circuit'] = [] tag_iter: itertools.count + phxz_tag_prefix = "_phxz" def rewriter(circuit_op: 'cirq.CircuitOperation') -> 'cirq.OP_TREE': nonlocal tag_iter tag: Optional[str] = None + u = protocols.unitary(circuit_op) if protocols.num_qubits(circuit_op) == 0: return ops.GlobalPhaseGate(u[0, 0]).on() + # If any of the op in the merged circuit_op is a symbolized single qubit gate, + # tag the merged phxz gate with next tag id, for further parameterization references. for op in circuit_op.circuit.all_operations(): if symbolized_single_tag in op.tags: - # Record parameterizations info via tags. tag = f"{phxz_tag_prefix}_{next(tag_iter)}" break gate = single_qubit_decompositions.single_qubit_matrix_to_phxz(u, atol) or ops.I op = gate.on(circuit_op.qubits[0]) - if not gate: - return [] return op.with_tags(tag) if tag else op for resolved_circuit in resolved_circuits: @@ -280,22 +217,49 @@ def rewriter(circuit_op: 'cirq.CircuitOperation') -> 'cirq.OP_TREE': merge_counts.append(next(tag_iter)) if not all(count == merge_counts[0] for count in merge_counts): - raise RuntimeError("Different resolvers in sweep result different merged strcuture.") + raise RuntimeError("Different resolvers in sweep resulted in different merged structures.") - # Get the output circuit from the first resolved circuits. - merge_tags: set[str] = {f"{phxz_tag_prefix}_{i}" for i in range(merge_counts[0])} - new_symbols: set[str] = set().union( - *[{f"x{i}", f"z{i}", f"a{i}"} for i in range(merge_counts[0])] + merge_tags: frozenset[str] = frozenset( + {f"{phxz_tag_prefix}_{i}" for i in range(merge_counts[0])} + ) + new_symbols: frozenset[str] = frozenset( + set().union(*[{f"x{i}", f"z{i}", f"a{i}"} for i in range(merge_counts[0])]) ) + return merged_circuits, merge_tags, new_symbols + + +def _get_merge_tag_id(merge_tags: frozenset[str], op_tags: Tuple[Hashable, ...]) -> Optional[str]: + """Extract the id `i` from the merge tag `_phxz_i` if it exists.""" + the_merge_tag: set[str] = set(merge_tags.intersection(op_tags)) + if len(the_merge_tag) == 0: + return None + if len(the_merge_tag) > 1: + raise RuntimeError("Multiple merge tags found.") + return the_merge_tag.pop().split("_")[-1] + + +def _map_merged_ops_to_symbolized_phxz( + circuit: 'cirq.Circuit', merge_tags: frozenset[str], deep: bool +) -> 'cirq.Circuit': + """Maps merged operations (tagged with merge_tags) in the circuit to symbolized PhasedXZGates. + + Args: + circuit: Circuit with merge tags to be mapped. + merge_tags: The set of tags used to identify the merged PhasedXZ gates that need to be + symbolized. + deep: Whether to perform the mapping recursively within CircuitOperations. + + Returns: + A new circuit where tagged PhasedXZ gates are replaced by symbolized versions. + """ + + # Map merged ops to `PhasedXZGate(xi,zi,ai)` based on the tag "_phxz_i". def _map_func(op: 'cirq.Operation', _): - """Maps op with tag `_phxz_i` to a symbolzied `PhasedXZGate(xi,zi,ai)`""" - the_merge_tag = merge_tags.intersection(op.tags) - if len(the_merge_tag) == 0: + """Maps an op with tag `_phxz_i` to a symbolzied `PhasedXZGate(xi,zi,ai)`""" + sid = _get_merge_tag_id(merge_tags, op.tags) + if sid is None: return op - if len(the_merge_tag) > 1: - raise RuntimeError("Multiple merge tags found.") - sid = the_merge_tag.pop().split("_")[-1] phxz_params = { "x_exponent": sympy.Symbol(f"x{sid}"), "z_exponent": sympy.Symbol(f"z{sid}"), @@ -303,39 +267,145 @@ def _map_func(op: 'cirq.Operation', _): } return ops.PhasedXZGate(**phxz_params).on(*op.qubits) - output_circuit: 'cirq.Circuit' = align.align_right( - transformer_primitives.map_operations(merged_circuits[0].freeze(), _map_func, deep=deep) + return align.align_right( + transformer_primitives.map_operations(circuit.freeze(), _map_func, deep=deep) ) + +def _parameterize_merged_circuits( + merged_circuits: List['cirq.Circuit'], + merge_tags: frozenset[str], + new_symbols: frozenset[str], + remaining_symbols: frozenset[str], + sweep: Sweep, +) -> Sweep: + """Parameterizes the merged circuits and returns a new sweep.""" values_by_params: Dict[str, List[float]] = { - **{s: [] for s in new_symbols}, # New symbols introduced in merging + **{s: [] for s in new_symbols}, # New symbols introduced during merging **{ s: _values_of_sweep(sweep, s) for s in remaining_symbols - }, # Existing symbols in ops that are not merged, e.g., symbols in 2 qubit gates. + }, # Existing symbols in ops that were not merged, e.g., symbols in 2-qubit gates. } - # Get parameterization for the merged phxz gates. for merged_circuit in merged_circuits: for op in merged_circuit.all_operations(): - the_merge_tag = merge_tags.intersection(op.tags) - if len(the_merge_tag) == 0: + sid = _get_merge_tag_id(merge_tags, op.tags) + if sid is None: continue - if len(the_merge_tag) > 1: - raise RuntimeError("Multiple merge tags found.") - sid = the_merge_tag.pop().split("_")[-1] - x, z, a = 0.0, 0.0, 0.0 # Identity gate's parameters. + x, z, a = 0.0, 0.0, 0.0 # Identity gate's parameters if isinstance(op.gate, ops.PhasedXZGate): x, z, a = op.gate.x_exponent, op.gate.z_exponent, op.gate.axis_phase_exponent elif op.gate is not ops.I: raise RuntimeError( - f"Expect the merged gate to be a PhasedXZGate or IdentityGate. But got {op.gate}." + f"Expected the merged gate to be a PhasedXZGate or IdentityGate," + f" but got {op.gate}." ) values_by_params[f"x{sid}"].append(x) values_by_params[f"z{sid}"].append(z) values_by_params[f"a{sid}"].append(a) - new_sweep: Sweep = Zip( - *[Points(key=key, points=values) for key, values in values_by_params.items()] + return Zip(*[Points(key=key, points=values) for key, values in values_by_params.items()]) + + +def merge_single_qubit_gates_to_phxz_symbolized( + circuit: 'cirq.AbstractCircuit', + *, + context: Optional['cirq.TransformerContext'] = None, + sweep: Sweep, + atol: float = 1e-8, +) -> Tuple['cirq.Circuit', Sweep]: + """Merge consecutive single qubit gates as PhasedXZ Gates. Symbolize if any of the consecutive + gates is symbolized. + + Example: + >>> q0, q1 = cirq.LineQubit.range(2) + >>> c = cirq.Circuit(\ + cirq.X(q0),\ + cirq.CZ(q0,q1)**sympy.Symbol("cz_exp"),\ + cirq.Y(q0)**sympy.Symbol("y_exp"),\ + cirq.X(q0)) + >>> print(c) + 0: ───X───@──────────Y^y_exp───X─── + │ + 1: ───────@^cz_exp───────────────── + >>> new_circuit, new_sweep = cirq.merge_single_qubit_gates_to_phxz_symbolized(\ + c, sweep=cirq.Zip(cirq.Points(key="cz_exp", points=[0, 1]),\ + cirq.Points(key="y_exp", points=[0, 1]))) + >>> print(new_circuit) + 0: ───PhXZ(a=-1,x=1,z=0)───@──────────PhXZ(a=a0,x=x0,z=z0)─── + │ + 1: ────────────────────────@^cz_exp────────────────────────── + >>> assert new_sweep[0] == cirq.ParamResolver({'a0': -1, 'x0': 1, 'z0': 0, 'cz_exp': 0}) + >>> assert new_sweep[1] == cirq.ParamResolver({'a0': -0.5, 'x0': 0, 'z0': -1, 'cz_exp': 1}) + + Args: + circuit: Input circuit to transform. It will not be modified. + sweep: Sweep of the symbols in the input circuit, updated Sweep will be returned + based on the transformation. + context: `cirq.TransformerContext` storing common configurable options for transformers. + atol: Absolute tolerance to angle error. Larger values allow more negligible gates to be + dropped, smaller values increase accuracy. + + Returns: + Copy of the transformed input circuit. + """ + deep = context.deep if context else False + + # Tag symbolized single-qubit op. + symbolized_single_tag = "_symbolized_single" + + circuit_tagged = transformer_primitives.map_operations( + circuit, + lambda op, _: ( + op.with_tags(symbolized_single_tag) + if protocols.is_parameterized(op) and len(op.qubits) == 1 + else op + ), + deep=deep, + ) + + # Step 0, isolate single qubit symbolized symbols and resolve the circuit on them. + + single_qubit_gate_symbols: frozenset[sympy.Symbol] = frozenset( + set().union( + *[ + protocols.parameter_symbols(op) if symbolized_single_tag in op.tags else set() + for op in circuit_tagged.all_operations() + ] + ) ) + # If all single qubit gates are not parameterized, call the nonparamerized version of + # the transformer. + if not single_qubit_gate_symbols: + return merge_single_qubit_gates_to_phxz(circuit, context=context, atol=atol), sweep + # Remaining symbols, e.g., 2 qubit gates' symbols. Sweep of those symbols keeps unchanged. + remaining_symbols: frozenset[sympy.Symbol] = frozenset( + protocols.parameter_symbols(circuit) - single_qubit_gate_symbols + ) + sweep_of_single: Sweep = Zip( + *[Points(key=k, points=_values_of_sweep(sweep, k)) for k in single_qubit_gate_symbols] + ) + # Get all resolved circuits from all sets of resolvers in the sweep. + resolved_circuits = [ + protocols.resolve_parameters(circuit_tagged, resolver) for resolver in sweep_of_single + ] + + # Step 1, merge single qubit gates of resolved circuits using merge_k_qubit_unitaries. + merged_circuits, merge_tags, new_symbols = _merge_single_qubit_gates_to_circuit_op_symbolized( + resolved_circuits, symbolized_single_tag, context, atol + ) + + # Step 2, get the new symbolzied circuit as new_sweep by mapping merged operations. + new_circuit = _map_merged_ops_to_symbolized_phxz(merged_circuits[0], merge_tags, deep) + + # Step 3, get N sets of parameterizations as new_sweep. + new_sweep = _parameterize_merged_circuits( + merged_circuits, merge_tags, new_symbols, remaining_symbols, sweep + ) + + return new_circuit.unfreeze(copy=False), new_sweep + - return output_circuit.unfreeze(copy=False), new_sweep +# ---------------------------------------------------------------------- +# Impl merge_single_qubit_gates_to_phxz_symbolized: End +# ---------------------------------------------------------------------- diff --git a/cirq-core/cirq/transformers/merge_single_qubit_gates_test.py b/cirq-core/cirq/transformers/merge_single_qubit_gates_test.py index 0aa65bfc01c..779d53e7df4 100644 --- a/cirq-core/cirq/transformers/merge_single_qubit_gates_test.py +++ b/cirq-core/cirq/transformers/merge_single_qubit_gates_test.py @@ -13,11 +13,12 @@ # limitations under the License. from typing import List +from unittest.mock import Mock, patch import pytest import sympy + import cirq -from cirq.study.sweeps import Points def assert_optimizes(optimized: cirq.AbstractCircuit, expected: cirq.AbstractCircuit): @@ -236,9 +237,10 @@ def test_merge_single_qubit_moments_to_phased_x_and_z_global_phase(): assert c == c2 -def test_merge_into_symbolized_phxz(): +def test_merge_single_qubit_gates_to_phxz_symbolized(): """Test case diagram. Input circuit: + # pylint: disable=line-too-long 0: ───X─────────@──────────H[ignore]───H───X───PhXZ(a=a0,x=x0,z=z0)───X───PhXZ(a=a1,x=x1,z=z1)─── │ 1: ───H^h_exp───@^cz_exp───────────────────────────────────────────────────────────────────────── @@ -246,6 +248,7 @@ def test_merge_into_symbolized_phxz(): 0: ───PhXZ(a=-1,x=1,z=0)─────@──────────H[ignore]───PhXZ(a=a1,x=x1,z=z1)─── │ 1: ───PhXZ(a=a0,x=x0,z=z0)───@^cz_exp────────────────────────────────────── + # pylint: enable=line-too-long """ a, b = cirq.LineQubit.range(2) sa0, sa1 = [sympy.Symbol(a) for a in ["a0", "a1"]] @@ -292,9 +295,101 @@ def test_merge_into_symbolized_phxz(): ) -def test_merge_into_symbolized_phxz_non_symbolized_input(): - a = cirq.NamedQubit('a') - with pytest.warns(UserWarning): - cirq.merge_single_qubit_gates_to_phxz_symbolized( - cirq.Circuit(cirq.H(a), cirq.H(a)), sweep=cirq.Points(key="a", points=[0.1, 0.2, 0.5]) - ) +def test_merge_single_qubit_gates_to_phxz_symbolized_non_parameterized_singles(): + """Test merge_single_qubit_gates_to_phxz_symbolized when all single qubit gates are not + parameterized.""" + + a, b = cirq.LineQubit.range(2) + input_circuit = cirq.Circuit(cirq.H(a), cirq.H(a), cirq.CZ(a, b) ** sympy.Symbol("exp")) + expected_circuit = cirq.merge_single_qubit_gates_to_phxz(input_circuit) + output_circuit, _ = cirq.merge_single_qubit_gates_to_phxz_symbolized( + input_circuit, sweep=cirq.Points(key="exp", points=[0.1, 0.2, 0.5]) + ) + assert_optimizes(output_circuit, expected_circuit) + + +def test_merge_single_qubit_gates_to_phxz_symbolized_with_global_phases(): + a = cirq.NamedQubit("a") + input_circuit = cirq.Circuit( + cirq.GlobalPhaseGate(1j).on(), cirq.X(a), cirq.Y(a) ** sympy.Symbol("y_exp") + ) + new_circuit, _ = cirq.merge_single_qubit_gates_to_phxz_symbolized( + input_circuit, sweep=cirq.Points(key="y_exp", points=[0, 1]) + ) + expected_circuit = cirq.Circuit( + cirq.GlobalPhaseGate(1j).on(), + _phxz(sympy.Symbol("a0"), sympy.Symbol("x0"), sympy.Symbol("z0")).on(a), + ) + + assert_optimizes(new_circuit, expected_circuit) + + +def test_merge_single_qubit_gates_to_phxz_symbolized_different_structures_error(): + """Tests that the function raises a RuntimeError if merged structures of the circuit differ + for different parameterizations.""" + a = cirq.NamedQubit("a") + circuit = cirq.Circuit(cirq.H(a) ** sympy.Symbol("exp")) + sweep = cirq.Points(key="exp", points=[0.1, 0.2]) + + with patch( + "cirq.protocols.resolve_parameters", + side_effect=[ + cirq.Circuit(cirq.H(a).with_tags("_symbolized_single")), + cirq.Circuit(cirq.H(a)), + ], + ): + with pytest.raises( + RuntimeError, + match="Different resolvers in sweep resulted in different merged structures.", + ): + cirq.merge_single_qubit_gates_to_phxz_symbolized(circuit, sweep=sweep) + + +def test_merge_single_qubit_gates_to_phxz_symbolized_multiple_phxz_tags_error(): + """Tests that the function raises a RuntimeError of incorrect merges.""" + a, b = cirq.LineQubit.range(2) + circuit = cirq.Circuit( + cirq.H(a) ** sympy.Symbol("exp1"), + cirq.X(a), + cirq.CZ(a, b), + cirq.Y(a), + cirq.H(a) ** sympy.Symbol("exp2"), + ) + sweep = cirq.Points(key="exp1", points=[0.1, 0.2]) * cirq.Points(key="exp2", points=[0.1, 0.2]) + + mock_iter = Mock() + mock_iter.__next__ = Mock(return_value=2) + + with patch( + "cirq.transformers.merge_k_qubit_gates.merge_k_qubit_unitaries", + return_value=cirq.Circuit(cirq.H(a).with_tags("_phxz_0", "_phxz_1")), + ): + with patch("itertools.count", return_value=mock_iter): + with pytest.raises(RuntimeError, match="Multiple merge tags found."): + cirq.merge_single_qubit_gates_to_phxz_symbolized(circuit, sweep=sweep) + + +def test_merge_single_qubit_gates_to_phxz_symbolized_unexpected_gate_error(): + """Tests that the function raises a RuntimeError of unexpected gate.""" + a, b = cirq.LineQubit.range(2) + circuit = cirq.Circuit( + cirq.H(a) ** sympy.Symbol("exp1"), + cirq.X(a), + cirq.CZ(a, b), + cirq.Y(a), + cirq.H(a) ** sympy.Symbol("exp2"), + ) + sweep = cirq.Points(key="exp1", points=[0.1, 0.2]) * cirq.Points(key="exp2", points=[0.1, 0.2]) + + mock_iter = Mock() + mock_iter.__next__ = Mock(return_value=2) + + with patch( + "cirq.transformers.analytical_decompositions" + ".single_qubit_decompositions.single_qubit_matrix_to_phxz", + return_value=cirq.H, + ): + with pytest.raises( + RuntimeError, match="Expected the merged gate to be a PhasedXZGate or IdentityGate." + ): + cirq.merge_single_qubit_gates_to_phxz_symbolized(circuit, sweep=sweep) From 1ddf7ce56d8dd6cdb3467f84bbf8e7126a625df6 Mon Sep 17 00:00:00 2001 From: Renyi Chen Date: Wed, 9 Apr 2025 13:07:12 -0700 Subject: [PATCH 04/19] fix lint --- cirq-core/cirq/transformers/merge_single_qubit_gates.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cirq-core/cirq/transformers/merge_single_qubit_gates.py b/cirq-core/cirq/transformers/merge_single_qubit_gates.py index 78cfd6db091..ff910d483dd 100644 --- a/cirq-core/cirq/transformers/merge_single_qubit_gates.py +++ b/cirq-core/cirq/transformers/merge_single_qubit_gates.py @@ -179,8 +179,8 @@ def _merge_single_qubit_gates_to_circuit_op_symbolized( Args: resolved_circuits: A list of circuits where symbols have been replaced with concrete values. - symbolized_single_tag: The tag applied to single-qubit operations that originally contained symbols - before parameterizations. + symbolized_single_tag: The tag applied to single-qubit operations that originally + contained symbols before parameterizations. Returns: Tuple of merge counts, merged circuits, and merge tags. From 3c465073d855a9647b3c354f666ca25756079fe3 Mon Sep 17 00:00:00 2001 From: Renyi Chen Date: Wed, 9 Apr 2025 13:23:23 -0700 Subject: [PATCH 05/19] fix docstring and func name --- .../cirq/transformers/merge_single_qubit_gates.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/cirq-core/cirq/transformers/merge_single_qubit_gates.py b/cirq-core/cirq/transformers/merge_single_qubit_gates.py index ff910d483dd..233d6988ed8 100644 --- a/cirq-core/cirq/transformers/merge_single_qubit_gates.py +++ b/cirq-core/cirq/transformers/merge_single_qubit_gates.py @@ -168,14 +168,14 @@ def _values_of_sweep(sweep: Sweep, key: str | sympy.Symbol): return [resolver.value_of(p) for resolver in sweep] -def _merge_single_qubit_gates_to_circuit_op_symbolized( +def _merge_single_qubit_gates_to_phxz_symbolized( resolved_circuits: List['cirq.AbstractCircuit'], symbolized_single_tag: str, context: Optional['cirq.TransformerContext'], atol: float, ) -> Tuple[List['cirq.Circuit'], frozenset[str], frozenset[str]]: - """Helper function to merge single qubit ops of resolved circuits to ops of CircuitOperation - type using merge_k_qubit_unitaries. + """Helper function to merge single qubit ops of resolved circuits to PhasedXZ ops + using merge_k_qubit_unitaries. Args: resolved_circuits: A list of circuits where symbols have been replaced with concrete values. @@ -183,7 +183,8 @@ def _merge_single_qubit_gates_to_circuit_op_symbolized( contained symbols before parameterizations. Returns: - Tuple of merge counts, merged circuits, and merge tags. + Tuple of merge_counts, merged_circuits, and merge_tags, where + merged ops in merged_circuits are tagged by merge_tags. """ merge_counts: list[int] = [] # number of merges per resolved_circuit merged_circuits: list['cirq.Circuit'] = [] @@ -391,7 +392,7 @@ def merge_single_qubit_gates_to_phxz_symbolized( ] # Step 1, merge single qubit gates of resolved circuits using merge_k_qubit_unitaries. - merged_circuits, merge_tags, new_symbols = _merge_single_qubit_gates_to_circuit_op_symbolized( + merged_circuits, merge_tags, new_symbols = _merge_single_qubit_gates_to_phxz_symbolized( resolved_circuits, symbolized_single_tag, context, atol ) From ecb6a838d652a97f9f46fb6f34dd8c79333f71af Mon Sep 17 00:00:00 2001 From: Renyi Chen Date: Fri, 18 Apr 2025 01:51:49 -0700 Subject: [PATCH 06/19] Expose more transformers tag_transformers: remove_tags, index_tags. symbolize: symbolize_single_qubit_gates_by_indexed_tags --- cirq-core/cirq/__init__.py | 3 + cirq-core/cirq/transformers/__init__.py | 6 + .../transformers/merge_single_qubit_gates.py | 261 ++++++++---------- .../merge_single_qubit_gates_test.py | 42 +-- cirq-core/cirq/transformers/symbolize.py | 88 ++++++ cirq-core/cirq/transformers/symbolize_test.py | 52 ++++ .../cirq/transformers/tag_transformers.py | 95 +++++++ .../transformers/tag_transformers_test.py | 67 +++++ 8 files changed, 420 insertions(+), 194 deletions(-) create mode 100644 cirq-core/cirq/transformers/symbolize.py create mode 100644 cirq-core/cirq/transformers/symbolize_test.py create mode 100644 cirq-core/cirq/transformers/tag_transformers.py create mode 100644 cirq-core/cirq/transformers/tag_transformers_test.py diff --git a/cirq-core/cirq/__init__.py b/cirq-core/cirq/__init__.py index 12916f66a33..3301196040d 100644 --- a/cirq-core/cirq/__init__.py +++ b/cirq-core/cirq/__init__.py @@ -363,6 +363,7 @@ eject_z as eject_z, expand_composite as expand_composite, HardCodedInitialMapper as HardCodedInitialMapper, + index_tags as index_tags, is_negligible_turn as is_negligible_turn, LineInitialMapper as LineInitialMapper, MappingManager as MappingManager, @@ -379,6 +380,7 @@ merge_single_qubit_gates_to_phxz as merge_single_qubit_gates_to_phxz, merge_single_qubit_gates_to_phxz_symbolized as merge_single_qubit_gates_to_phxz_symbolized, merge_single_qubit_moments_to_phxz as merge_single_qubit_moments_to_phxz, + symbolize_single_qubit_gates_by_indexed_tags as symbolize_single_qubit_gates_by_indexed_tags, optimize_for_target_gateset as optimize_for_target_gateset, parameterized_2q_op_to_sqrt_iswap_operations as parameterized_2q_op_to_sqrt_iswap_operations, prepare_two_qubit_state_using_cz as prepare_two_qubit_state_using_cz, @@ -386,6 +388,7 @@ prepare_two_qubit_state_using_sqrt_iswap as prepare_two_qubit_state_using_sqrt_iswap, quantum_shannon_decomposition as quantum_shannon_decomposition, RouteCQC as RouteCQC, + remove_tags as remove_tags, routed_circuit_with_mapping as routed_circuit_with_mapping, SqrtIswapTargetGateset as SqrtIswapTargetGateset, single_qubit_matrix_to_gates as single_qubit_matrix_to_gates, diff --git a/cirq-core/cirq/transformers/__init__.py b/cirq-core/cirq/transformers/__init__.py index 833a914ca85..56fd23fd97f 100644 --- a/cirq-core/cirq/transformers/__init__.py +++ b/cirq-core/cirq/transformers/__init__.py @@ -104,6 +104,12 @@ merge_single_qubit_moments_to_phxz as merge_single_qubit_moments_to_phxz, ) +from cirq.transformers.tag_transformers import index_tags as index_tags, remove_tags as remove_tags +from cirq.transformers.symbolize import ( + symbolize_single_qubit_gates_by_indexed_tags as symbolize_single_qubit_gates_by_indexed_tags, +) + + from cirq.transformers.qubit_management_transformers import ( map_clean_and_borrowable_qubits as map_clean_and_borrowable_qubits, ) diff --git a/cirq-core/cirq/transformers/merge_single_qubit_gates.py b/cirq-core/cirq/transformers/merge_single_qubit_gates.py index 233d6988ed8..cff04faa702 100644 --- a/cirq-core/cirq/transformers/merge_single_qubit_gates.py +++ b/cirq-core/cirq/transformers/merge_single_qubit_gates.py @@ -14,15 +14,22 @@ """Transformer passes to combine adjacent single-qubit rotations.""" -import itertools -from typing import Dict, Hashable, List, Optional, Tuple, TYPE_CHECKING +from typing import Callable, Dict, Hashable, List, Optional, Tuple, TYPE_CHECKING import sympy from cirq import circuits, ops, protocols +from cirq.study.result import TMeasurementKey from cirq.study.sweeps import Points, Sweep, Zip -from cirq.transformers import align, merge_k_qubit_gates, transformer_api, transformer_primitives +from cirq.transformers import ( + align, + merge_k_qubit_gates, + symbolize, + transformer_api, + transformer_primitives, +) from cirq.transformers.analytical_decompositions import single_qubit_decompositions +from cirq.transformers.tag_transformers import index_tags, remove_tags if TYPE_CHECKING: import cirq @@ -69,6 +76,7 @@ def merge_single_qubit_gates_to_phxz( circuit: 'cirq.AbstractCircuit', *, context: Optional['cirq.TransformerContext'] = None, + merge_tags_fn: Optional[Callable[['cirq.CircuitOperation'], List[Hashable]]] = None, atol: float = 1e-8, ) -> 'cirq.Circuit': """Replaces runs of single qubit rotations with a single optional `cirq.PhasedXZGate`. @@ -79,6 +87,8 @@ def merge_single_qubit_gates_to_phxz( Args: circuit: Input circuit to transform. It will not be modified. context: `cirq.TransformerContext` storing common configurable options for transformers. + merge_tag: If provided, tag merged PhXZ gate with it. + merge_tags_fn: A callable returns the tags to be added to the merged operation. atol: Absolute tolerance to angle error. Larger values allow more negligible gates to be dropped, smaller values increase accuracy. @@ -86,12 +96,15 @@ def merge_single_qubit_gates_to_phxz( Copy of the transformed input circuit. """ - def rewriter(op: 'cirq.CircuitOperation') -> 'cirq.OP_TREE': - u = protocols.unitary(op) - if protocols.num_qubits(op) == 0: + def rewriter(circuit_op: 'cirq.CircuitOperation') -> 'cirq.OP_TREE': + + u = protocols.unitary(circuit_op) + if protocols.num_qubits(circuit_op) == 0: return ops.GlobalPhaseGate(u[0, 0]).on() - gate = single_qubit_decompositions.single_qubit_matrix_to_phxz(u, atol) - return gate(op.qubits[0]) if gate else [] + + gate = single_qubit_decompositions.single_qubit_matrix_to_phxz(u, atol) or ops.I + phxz_op = gate.on(circuit_op.qubits[0]) + return phxz_op.with_tags(*merge_tags_fn(circuit_op)) if merge_tags_fn else phxz_op return merge_k_qubit_gates.merge_k_qubit_unitaries( circuit, k=1, context=context, rewriter=rewriter @@ -158,141 +171,33 @@ def merge_func(m1: 'cirq.Moment', m2: 'cirq.Moment') -> Optional['cirq.Moment']: ).unfreeze(copy=False) -# ---------------------------------------------------------------------- -# Impl merge_single_qubit_gates_to_phxz_symbolized: Start -# ---------------------------------------------------------------------- - - -def _values_of_sweep(sweep: Sweep, key: str | sympy.Symbol): +def _values_of_sweep(sweep: Sweep, key: TMeasurementKey): p = sympy.Symbol(key) if isinstance(key, str) else key return [resolver.value_of(p) for resolver in sweep] -def _merge_single_qubit_gates_to_phxz_symbolized( - resolved_circuits: List['cirq.AbstractCircuit'], - symbolized_single_tag: str, - context: Optional['cirq.TransformerContext'], - atol: float, -) -> Tuple[List['cirq.Circuit'], frozenset[str], frozenset[str]]: - """Helper function to merge single qubit ops of resolved circuits to PhasedXZ ops - using merge_k_qubit_unitaries. - - Args: - resolved_circuits: A list of circuits where symbols have been replaced with concrete values. - symbolized_single_tag: The tag applied to single-qubit operations that originally - contained symbols before parameterizations. - - Returns: - Tuple of merge_counts, merged_circuits, and merge_tags, where - merged ops in merged_circuits are tagged by merge_tags. - """ - merge_counts: list[int] = [] # number of merges per resolved_circuit - merged_circuits: list['cirq.Circuit'] = [] - tag_iter: itertools.count - phxz_tag_prefix = "_phxz" - - def rewriter(circuit_op: 'cirq.CircuitOperation') -> 'cirq.OP_TREE': - nonlocal tag_iter - tag: Optional[str] = None - - u = protocols.unitary(circuit_op) - if protocols.num_qubits(circuit_op) == 0: - return ops.GlobalPhaseGate(u[0, 0]).on() - # If any of the op in the merged circuit_op is a symbolized single qubit gate, - # tag the merged phxz gate with next tag id, for further parameterization references. - for op in circuit_op.circuit.all_operations(): - if symbolized_single_tag in op.tags: - tag = f"{phxz_tag_prefix}_{next(tag_iter)}" - break - gate = single_qubit_decompositions.single_qubit_matrix_to_phxz(u, atol) or ops.I - op = gate.on(circuit_op.qubits[0]) - return op.with_tags(tag) if tag else op - - for resolved_circuit in resolved_circuits: - tag_iter = itertools.count(start=0, step=1) - merged_circuits.append( - merge_k_qubit_gates.merge_k_qubit_unitaries( - resolved_circuit, k=1, context=context, rewriter=rewriter - ) - ) - merge_counts.append(next(tag_iter)) - - if not all(count == merge_counts[0] for count in merge_counts): - raise RuntimeError("Different resolvers in sweep resulted in different merged structures.") - - merge_tags: frozenset[str] = frozenset( - {f"{phxz_tag_prefix}_{i}" for i in range(merge_counts[0])} - ) - new_symbols: frozenset[str] = frozenset( - set().union(*[{f"x{i}", f"z{i}", f"a{i}"} for i in range(merge_counts[0])]) - ) - - return merged_circuits, merge_tags, new_symbols - - -def _get_merge_tag_id(merge_tags: frozenset[str], op_tags: Tuple[Hashable, ...]) -> Optional[str]: - """Extract the id `i` from the merge tag `_phxz_i` if it exists.""" - the_merge_tag: set[str] = set(merge_tags.intersection(op_tags)) - if len(the_merge_tag) == 0: - return None - if len(the_merge_tag) > 1: - raise RuntimeError("Multiple merge tags found.") - return the_merge_tag.pop().split("_")[-1] - - -def _map_merged_ops_to_symbolized_phxz( - circuit: 'cirq.Circuit', merge_tags: frozenset[str], deep: bool -) -> 'cirq.Circuit': - """Maps merged operations (tagged with merge_tags) in the circuit to symbolized PhasedXZGates. - - Args: - circuit: Circuit with merge tags to be mapped. - merge_tags: The set of tags used to identify the merged PhasedXZ gates that need to be - symbolized. - deep: Whether to perform the mapping recursively within CircuitOperations. - - Returns: - A new circuit where tagged PhasedXZ gates are replaced by symbolized versions. - """ - - # Map merged ops to `PhasedXZGate(xi,zi,ai)` based on the tag "_phxz_i". - def _map_func(op: 'cirq.Operation', _): - """Maps an op with tag `_phxz_i` to a symbolzied `PhasedXZGate(xi,zi,ai)`""" - sid = _get_merge_tag_id(merge_tags, op.tags) - if sid is None: - return op - phxz_params = { - "x_exponent": sympy.Symbol(f"x{sid}"), - "z_exponent": sympy.Symbol(f"z{sid}"), - "axis_phase_exponent": sympy.Symbol(f"a{sid}"), - } - return ops.PhasedXZGate(**phxz_params).on(*op.qubits) - - return align.align_right( - transformer_primitives.map_operations(circuit.freeze(), _map_func, deep=deep) - ) - - -def _parameterize_merged_circuits( - merged_circuits: List['cirq.Circuit'], - merge_tags: frozenset[str], - new_symbols: frozenset[str], - remaining_symbols: frozenset[str], +def _parameterize_phxz_in_circuits( + circuit_list: List['cirq.Circuit'], + merge_tag_prefix: str, + phxz_symbols: frozenset[sympy.Symbol], + remaining_symbols: frozenset[sympy.Symbol], sweep: Sweep, ) -> Sweep: - """Parameterizes the merged circuits and returns a new sweep.""" + """Parameterizes the circuits and returns a new sweep.""" values_by_params: Dict[str, List[float]] = { - **{s: [] for s in new_symbols}, # New symbols introduced during merging - **{ - s: _values_of_sweep(sweep, s) for s in remaining_symbols - }, # Existing symbols in ops that were not merged, e.g., symbols in 2-qubit gates. + **{str(s): [] for s in phxz_symbols}, + **{str(s): _values_of_sweep(sweep, s) for s in remaining_symbols}, } - for merged_circuit in merged_circuits: - for op in merged_circuit.all_operations(): - sid = _get_merge_tag_id(merge_tags, op.tags) - if sid is None: + for circuit in circuit_list: + for op in circuit.all_operations(): + the_merge_tag: Optional[str] = None + for tag in op.tags: + if str(tag).startswith(merge_tag_prefix): + the_merge_tag = str(tag) + if not the_merge_tag: continue + sid = the_merge_tag.rsplit("_", maxsplit=-1)[-1] x, z, a = 0.0, 0.0, 0.0 # Identity gate's parameters if isinstance(op.gate, ops.PhasedXZGate): x, z, a = op.gate.x_exponent, op.gate.z_exponent, op.gate.axis_phase_exponent @@ -308,6 +213,15 @@ def _parameterize_merged_circuits( return Zip(*[Points(key=key, points=values) for key, values in values_by_params.items()]) +def _all_tags_startswith(circuit: 'cirq.AbstractCircuit', startswith: str): + tag_set: set[Hashable] = set() + for op in circuit.all_operations(): + for tag in op.tags: + if str(tag).startswith(startswith): + tag_set.add(tag) + return tag_set + + def merge_single_qubit_gates_to_phxz_symbolized( circuit: 'cirq.AbstractCircuit', *, @@ -353,7 +267,7 @@ def merge_single_qubit_gates_to_phxz_symbolized( deep = context.deep if context else False # Tag symbolized single-qubit op. - symbolized_single_tag = "_symbolized_single" + symbolized_single_tag = "TMP-TAG-symbolized-single" circuit_tagged = transformer_primitives.map_operations( circuit, @@ -366,7 +280,6 @@ def merge_single_qubit_gates_to_phxz_symbolized( ) # Step 0, isolate single qubit symbolized symbols and resolve the circuit on them. - single_qubit_gate_symbols: frozenset[sympy.Symbol] = frozenset( set().union( *[ @@ -378,11 +291,7 @@ def merge_single_qubit_gates_to_phxz_symbolized( # If all single qubit gates are not parameterized, call the nonparamerized version of # the transformer. if not single_qubit_gate_symbols: - return merge_single_qubit_gates_to_phxz(circuit, context=context, atol=atol), sweep - # Remaining symbols, e.g., 2 qubit gates' symbols. Sweep of those symbols keeps unchanged. - remaining_symbols: frozenset[sympy.Symbol] = frozenset( - protocols.parameter_symbols(circuit) - single_qubit_gate_symbols - ) + return (merge_single_qubit_gates_to_phxz(circuit, context=context, atol=atol), sweep) sweep_of_single: Sweep = Zip( *[Points(key=k, points=_values_of_sweep(sweep, k)) for k in single_qubit_gate_symbols] ) @@ -391,22 +300,68 @@ def merge_single_qubit_gates_to_phxz_symbolized( protocols.resolve_parameters(circuit_tagged, resolver) for resolver in sweep_of_single ] - # Step 1, merge single qubit gates of resolved circuits using merge_k_qubit_unitaries. - merged_circuits, merge_tags, new_symbols = _merge_single_qubit_gates_to_phxz_symbolized( - resolved_circuits, symbolized_single_tag, context, atol - ) + # Step 1, merge single qubit gates per resolved circuit, preserving the "symbolized_single_tag". + merged_circuits: List['cirq.Circuit'] = [] + phxz_symbols: set[sympy.Symbols] = set() + for resolved_circuit in resolved_circuits: + merged_circuit = index_tags( + merge_single_qubit_gates_to_phxz( + resolved_circuit, + context=context, + merge_tags_fn=lambda circuit_op: ( + [symbolized_single_tag] + if any( + symbolized_single_tag in set(op.tags) + for op in circuit_op.circuit.all_operations() + ) + else [] + ), + atol=atol, + ), + target_tags={symbolized_single_tag}, + context=context, + ) + merged_circuits.append(merged_circuit) + + if not all( + _all_tags_startswith(merged_circuits[0], startswith=symbolized_single_tag) + == _all_tags_startswith(merged_circuit, startswith=symbolized_single_tag) + for merged_circuit in merged_circuits + ): + raise RuntimeError("Different resolvers in sweep resulted in different merged structures.") - # Step 2, get the new symbolzied circuit as new_sweep by mapping merged operations. - new_circuit = _map_merged_ops_to_symbolized_phxz(merged_circuits[0], merge_tags, deep) + # Step 2, get the new symbolized circuit by mapping merged operations. + new_circuit = align.align_right( + remove_tags( + symbolize.symbolize_single_qubit_gates_by_indexed_tags( + merged_circuits[0], tag_prefix=symbolized_single_tag + ), + remove_if=lambda tag: tag.startswith(symbolized_single_tag), + ) + ) # Step 3, get N sets of parameterizations as new_sweep. - new_sweep = _parameterize_merged_circuits( - merged_circuits, merge_tags, new_symbols, remaining_symbols, sweep + phxz_symbols: frozenset[sympy.Symbol] = frozenset( + set().union( + *[ + set( + [ + sympy.Symbol(tag.replace(f"{symbolized_single_tag}_", s)) + for s in ["x", "z", "a"] + ] + ) + for tag in _all_tags_startswith( + merged_circuits[0], startswith=symbolized_single_tag + ) + ] + ) + ) + # Remaining symbols, e.g., 2 qubit gates' symbols. Sweep of those symbols keeps unchanged. + remaining_symbols: frozenset[sympy.Symbol] = frozenset( + protocols.parameter_symbols(circuit) - single_qubit_gate_symbols + ) + new_sweep = _parameterize_phxz_in_circuits( + merged_circuits, symbolized_single_tag, phxz_symbols, remaining_symbols, sweep ) return new_circuit.unfreeze(copy=False), new_sweep - - -# ---------------------------------------------------------------------- -# Impl merge_single_qubit_gates_to_phxz_symbolized: End -# ---------------------------------------------------------------------- diff --git a/cirq-core/cirq/transformers/merge_single_qubit_gates_test.py b/cirq-core/cirq/transformers/merge_single_qubit_gates_test.py index 779d53e7df4..6962d96cb23 100644 --- a/cirq-core/cirq/transformers/merge_single_qubit_gates_test.py +++ b/cirq-core/cirq/transformers/merge_single_qubit_gates_test.py @@ -308,22 +308,6 @@ def test_merge_single_qubit_gates_to_phxz_symbolized_non_parameterized_singles() assert_optimizes(output_circuit, expected_circuit) -def test_merge_single_qubit_gates_to_phxz_symbolized_with_global_phases(): - a = cirq.NamedQubit("a") - input_circuit = cirq.Circuit( - cirq.GlobalPhaseGate(1j).on(), cirq.X(a), cirq.Y(a) ** sympy.Symbol("y_exp") - ) - new_circuit, _ = cirq.merge_single_qubit_gates_to_phxz_symbolized( - input_circuit, sweep=cirq.Points(key="y_exp", points=[0, 1]) - ) - expected_circuit = cirq.Circuit( - cirq.GlobalPhaseGate(1j).on(), - _phxz(sympy.Symbol("a0"), sympy.Symbol("x0"), sympy.Symbol("z0")).on(a), - ) - - assert_optimizes(new_circuit, expected_circuit) - - def test_merge_single_qubit_gates_to_phxz_symbolized_different_structures_error(): """Tests that the function raises a RuntimeError if merged structures of the circuit differ for different parameterizations.""" @@ -334,7 +318,7 @@ def test_merge_single_qubit_gates_to_phxz_symbolized_different_structures_error( with patch( "cirq.protocols.resolve_parameters", side_effect=[ - cirq.Circuit(cirq.H(a).with_tags("_symbolized_single")), + cirq.Circuit(cirq.H(a).with_tags("TMP-TAG-symbolized-single")), cirq.Circuit(cirq.H(a)), ], ): @@ -345,30 +329,6 @@ def test_merge_single_qubit_gates_to_phxz_symbolized_different_structures_error( cirq.merge_single_qubit_gates_to_phxz_symbolized(circuit, sweep=sweep) -def test_merge_single_qubit_gates_to_phxz_symbolized_multiple_phxz_tags_error(): - """Tests that the function raises a RuntimeError of incorrect merges.""" - a, b = cirq.LineQubit.range(2) - circuit = cirq.Circuit( - cirq.H(a) ** sympy.Symbol("exp1"), - cirq.X(a), - cirq.CZ(a, b), - cirq.Y(a), - cirq.H(a) ** sympy.Symbol("exp2"), - ) - sweep = cirq.Points(key="exp1", points=[0.1, 0.2]) * cirq.Points(key="exp2", points=[0.1, 0.2]) - - mock_iter = Mock() - mock_iter.__next__ = Mock(return_value=2) - - with patch( - "cirq.transformers.merge_k_qubit_gates.merge_k_qubit_unitaries", - return_value=cirq.Circuit(cirq.H(a).with_tags("_phxz_0", "_phxz_1")), - ): - with patch("itertools.count", return_value=mock_iter): - with pytest.raises(RuntimeError, match="Multiple merge tags found."): - cirq.merge_single_qubit_gates_to_phxz_symbolized(circuit, sweep=sweep) - - def test_merge_single_qubit_gates_to_phxz_symbolized_unexpected_gate_error(): """Tests that the function raises a RuntimeError of unexpected gate.""" a, b = cirq.LineQubit.range(2) diff --git a/cirq-core/cirq/transformers/symbolize.py b/cirq-core/cirq/transformers/symbolize.py new file mode 100644 index 00000000000..7f948dc8a13 --- /dev/null +++ b/cirq-core/cirq/transformers/symbolize.py @@ -0,0 +1,88 @@ +# Copyright 2025 The Cirq Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Hashable, Optional, TYPE_CHECKING +import re +import sympy + +from cirq import ops +from cirq.transformers import transformer_api, transformer_primitives + +if TYPE_CHECKING: + import cirq + + +@transformer_api.transformer +def symbolize_single_qubit_gates_by_indexed_tags( + circuit: 'cirq.AbstractCircuit', + *, + context: Optional['cirq.TransformerContext'] = None, + tag_prefix: Optional[str] = "TO-PHXZ", +) -> 'cirq.Circuit': + """Symbolize single qubit operations by indexed tags prefixed by tag_prefix. + + Example: + >>> q0, q1 = cirq.LineQubit.range(2) + >>> c = cirq.Circuit(\ + cirq.X(q0).with_tags("phxz_0"),\ + cirq.CZ(q0,q1),\ + cirq.Y(q0).with_tags("phxz_1"),\ + cirq.X(q0)) + >>> print(c) + 0: ───X["phxz_0"]───@───Y["phxz_1"]───X─── + │ + 1: ─────────────────@───────────────────── + >>> new_circuit = cirq.symbolize_single_qubit_gates_by_indexed_tags(\ + c, tag_prefix="phxz") + >>> print(new_circuit) + 0: ───PhXZ(a=a0,x=x0,z=z0)───@───PhXZ(a=a1,x=x1,z=z1)─X─ + │ + 1: ────────────────────────@────────────────────────── + + Args: + circuit: Input circuit to apply the transformations on. The input circuit is not mutated. + context: `cirq.TransformerContext` storing common configurable options for transformers. + tag_prefix: The prefix of the tag. + + Returns: + Copy of the transformed input circuit. + """ + + def _map_func(op: 'cirq.Operation', _): + """Maps an op with tag `{tag_prefix}_i` to a symbolzied `PhasedXZGate(xi,zi,ai)`.""" + tags: set[Hashable] = set(op.tags) + tag_id: None | int = None + for tag in tags: + if re.fullmatch(f"{tag_prefix}_\\d+", tag): + if tag_id is None: + tag_id = tag.split("_")[-1] + else: + raise ValueError(f"Multiple tags are prefixed with {tag_prefix}.") + if not tag_id: + return op + tags.remove(f"{tag_prefix}_{tag_id}") + phxz_params = { + "x_exponent": sympy.Symbol(f"x{tag_id}"), + "z_exponent": sympy.Symbol(f"z{tag_id}"), + "axis_phase_exponent": sympy.Symbol(f"a{tag_id}"), + } + + return ops.PhasedXZGate(**phxz_params).on(*op.qubits).with_tags(*tags) + + return transformer_primitives.map_operations( + circuit.freeze(), + _map_func, + deep=context.deep if context else False, + tags_to_ignore=context.tags_to_ignore if context else set(), + ).unfreeze(copy=False) diff --git a/cirq-core/cirq/transformers/symbolize_test.py b/cirq-core/cirq/transformers/symbolize_test.py new file mode 100644 index 00000000000..6cf2e9f3836 --- /dev/null +++ b/cirq-core/cirq/transformers/symbolize_test.py @@ -0,0 +1,52 @@ +# Copyright 2025 The Cirq Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import sympy + +import cirq + + +def test_symbolize_single_qubit_gates_by_indexed_tags_success(): + q = cirq.NamedQubit("a") + input_circuit = cirq.Circuit( + cirq.X(q).with_tags("TO-PHXZ_1"), + cirq.Y(q).with_tags("tag1"), + cirq.Z(q).with_tags("TO-PHXZ_0"), + ) + output_circuit = cirq.symbolize_single_qubit_gates_by_indexed_tags(input_circuit) + cirq.testing.assert_same_circuits( + output_circuit, + cirq.Circuit( + cirq.PhasedXZGate( + x_exponent=sympy.Symbol("x1"), + z_exponent=sympy.Symbol("z1"), + axis_phase_exponent=sympy.Symbol("a1"), + ).on(q), + cirq.Y(q).with_tags("tag1"), + cirq.PhasedXZGate( + x_exponent=sympy.Symbol("x0"), + z_exponent=sympy.Symbol("z0"), + axis_phase_exponent=sympy.Symbol("a0"), + ).on(q), + ), + ) + + +def test_symbolize_single_qubit_gates_by_indexed_tags_multiple_tags(): + q = cirq.NamedQubit("a") + input_circuit = cirq.Circuit(cirq.X(q).with_tags("TO-PHXZ_0", "TO-PHXZ_2")) + + with pytest.raises(ValueError, match="Multiple tags are prefixed with TO-PHXZ."): + cirq.symbolize_single_qubit_gates_by_indexed_tags(input_circuit) diff --git a/cirq-core/cirq/transformers/tag_transformers.py b/cirq-core/cirq/transformers/tag_transformers.py new file mode 100644 index 00000000000..9664b5e2b36 --- /dev/null +++ b/cirq-core/cirq/transformers/tag_transformers.py @@ -0,0 +1,95 @@ +# Copyright 2025 The Cirq Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import itertools +from typing import Callable, Hashable, Optional, TYPE_CHECKING + + +from cirq.transformers import transformer_api, transformer_primitives + +if TYPE_CHECKING: + import cirq + + +@transformer_api.transformer +def index_tags( + circuit: 'cirq.AbstractCircuit', + *, + context: Optional['cirq.TransformerContext'] = None, + target_tags: frozenset[Hashable] = frozenset(), + index_if: Callable[[Hashable], bool] = lambda _: True, +) -> 'cirq.Circuit': + """Indexes all the tags in target_tags tag_0, tag_1, .... + + Args: + circuit: Input circuit to apply the transformations on. The input circuit is not mutated. + context: `cirq.TransformerContext` storing common configurable options for transformers. + target_tags: Tags to be indexed. + index_if: A callable that returns True if its tags should be indexed. Defaults to True. + + Returns: + Copy of the transformed input circuit. + """ + tag_iter_by_tags = {tag: itertools.count(start=0, step=1) for tag in target_tags} + tags_to_ignore = context.tags_to_ignore if context else set() + + def _map_func(op: 'cirq.Operation', _) -> 'cirq.OP_TREE': + tag_set = set(op.tags) + if not index_if(op) or tag_set.intersection(tags_to_ignore): # Skip indexing + return op + nonlocal tag_iter_by_tags + for tag in target_tags.intersection(op.tags): + tag_set.remove(tag) + tag_set.add(f"{tag}_{next(tag_iter_by_tags[tag])}") + + return op.untagged.with_tags(*tag_set) + + return transformer_primitives.map_operations( + circuit, _map_func, deep=context.deep if context else False + ).freeze(copy=False) + + +@transformer_api.transformer +def remove_tags( + circuit: 'cirq.AbstractCircuit', + *, + context: Optional['cirq.TransformerContext'] = None, + target_tags: frozenset[Hashable] = frozenset(), + remove_if: Callable[[Hashable], bool] = lambda _: False, +) -> 'cirq.Circuit': + """Remove tags from the operations based on the input args. + + Args: + circuit: Input circuit to apply the transformations on. The input circuit is not mutated. + context: `cirq.TransformerContext` storing common configurable options for transformers. + target_tags: Tags to be removed. + remove_if: A callable(tag) that returns True if the tag should be removed. Defaults to False. + + Returns: + Copy of the transformed input circuit. + """ + if context and target_tags.intersection(context.tags_to_ignore or set()): + raise ValueError("Can't remove tags in context.tags_to_ignore.") + + def _map_func(op: 'cirq.Operation', _) -> 'cirq.OP_TREE': + remaing_tags = set() + for tag in op.tags: + if not remove_if(tag) and tag not in target_tags: + remaing_tags.add(tag) + + return op.untagged.with_tags(*remaing_tags) + + return transformer_primitives.map_operations( + circuit, _map_func, deep=context.deep if context else False + ).freeze(copy=False) diff --git a/cirq-core/cirq/transformers/tag_transformers_test.py b/cirq-core/cirq/transformers/tag_transformers_test.py new file mode 100644 index 00000000000..ec118a99ee8 --- /dev/null +++ b/cirq-core/cirq/transformers/tag_transformers_test.py @@ -0,0 +1,67 @@ +# Copyright 2025 The Cirq Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import cirq + + +def check_same_circuit_with_same_tag_sets(circuit1, circuit2): + for op1, op2 in zip(circuit1.all_operations(), circuit2.all_operations()): + assert set(op1.tags) == set(op2.tags) + assert op1.untagged == op2.untagged + + +def test_index_tags(): + q0, q1 = cirq.LineQubit.range(2) + input_circuit = cirq.Circuit( + cirq.X(q0).with_tags("tag1", "tag2"), + cirq.Y(q1).with_tags("tag1"), + cirq.CZ(q0, q1).with_tags("tag2"), + ) + expected_circuit = cirq.Circuit( + cirq.X(q0).with_tags("tag1_0", "tag2_0"), + cirq.Y(q1).with_tags("tag1_1"), + cirq.CZ(q0, q1).with_tags("tag2_1"), + ) + check_same_circuit_with_same_tag_sets( + cirq.index_tags(input_circuit, target_tags={"tag1", "tag2"}), expected_circuit + ) + + +def test_remove_tags(): + q0, q1 = cirq.LineQubit.range(2) + input_circuit = cirq.Circuit( + cirq.X(q0).with_tags("tag1", "tag2"), + cirq.Y(q1).with_tags("tag1"), + cirq.CZ(q0, q1).with_tags("tag2"), + ) + expected_circuit = cirq.Circuit( + cirq.X(q0).with_tags("tag2"), cirq.Y(q1), cirq.CZ(q0, q1).with_tags("tag2") + ) + cirq.testing.assert_equivalent_op_tree( + cirq.remove_tags(input_circuit, target_tags={"tag1"}), expected_circuit + ) + + +def test_remove_tags_via_remove_if(): + q0, q1 = cirq.LineQubit.range(2) + input_circuit = cirq.Circuit( + cirq.X(q0).with_tags("tag1", "tag2"), + cirq.Y(q1).with_tags("not_tag1"), + cirq.CZ(q0, q1).with_tags("tag2"), + ) + expected_circuit = cirq.Circuit(cirq.X(q0), cirq.Y(q1).with_tags("not_tag1"), cirq.CZ(q0, q1)) + cirq.testing.assert_equivalent_op_tree( + cirq.remove_tags(input_circuit, remove_if=lambda tag: tag.startswith("tag")), + expected_circuit, + ) From a61b1e3bbed0b311569376373b4612fb5492d488 Mon Sep 17 00:00:00 2001 From: Renyi Chen Date: Tue, 22 Apr 2025 18:32:57 -0700 Subject: [PATCH 07/19] fix --- cirq-core/cirq/transformers/symbolize.py | 6 +++--- cirq-core/cirq/transformers/tag_transformers.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/cirq-core/cirq/transformers/symbolize.py b/cirq-core/cirq/transformers/symbolize.py index 7f948dc8a13..839f653944e 100644 --- a/cirq-core/cirq/transformers/symbolize.py +++ b/cirq-core/cirq/transformers/symbolize.py @@ -46,9 +46,9 @@ def symbolize_single_qubit_gates_by_indexed_tags( >>> new_circuit = cirq.symbolize_single_qubit_gates_by_indexed_tags(\ c, tag_prefix="phxz") >>> print(new_circuit) - 0: ───PhXZ(a=a0,x=x0,z=z0)───@───PhXZ(a=a1,x=x1,z=z1)─X─ - │ - 1: ────────────────────────@────────────────────────── + 0: ───PhXZ(a=a0,x=x0,z=z0)───@───PhXZ(a=a1,x=x1,z=z1)───X─── + │ + 1: ──────────────────────────@────────────────────────────── Args: circuit: Input circuit to apply the transformations on. The input circuit is not mutated. diff --git a/cirq-core/cirq/transformers/tag_transformers.py b/cirq-core/cirq/transformers/tag_transformers.py index 9664b5e2b36..2d6ba09fea3 100644 --- a/cirq-core/cirq/transformers/tag_transformers.py +++ b/cirq-core/cirq/transformers/tag_transformers.py @@ -57,7 +57,7 @@ def _map_func(op: 'cirq.Operation', _) -> 'cirq.OP_TREE': return transformer_primitives.map_operations( circuit, _map_func, deep=context.deep if context else False - ).freeze(copy=False) + ).unfreeze(copy=False) @transformer_api.transformer @@ -92,4 +92,4 @@ def _map_func(op: 'cirq.Operation', _) -> 'cirq.OP_TREE': return transformer_primitives.map_operations( circuit, _map_func, deep=context.deep if context else False - ).freeze(copy=False) + ).unfreeze(copy=False) From 830d2ebdf0f3ed13412d04fad5abfb38097f5b07 Mon Sep 17 00:00:00 2001 From: Renyi Chen Date: Tue, 22 Apr 2025 18:52:58 -0700 Subject: [PATCH 08/19] fix format, type, lint --- .../transformers/merge_single_qubit_gates.py | 42 +++++++------------ cirq-core/cirq/transformers/symbolize.py | 8 ++-- .../cirq/transformers/tag_transformers.py | 11 +++-- 3 files changed, 27 insertions(+), 34 deletions(-) diff --git a/cirq-core/cirq/transformers/merge_single_qubit_gates.py b/cirq-core/cirq/transformers/merge_single_qubit_gates.py index cff04faa702..48b1e1ce901 100644 --- a/cirq-core/cirq/transformers/merge_single_qubit_gates.py +++ b/cirq-core/cirq/transformers/merge_single_qubit_gates.py @@ -179,8 +179,8 @@ def _values_of_sweep(sweep: Sweep, key: TMeasurementKey): def _parameterize_phxz_in_circuits( circuit_list: List['cirq.Circuit'], merge_tag_prefix: str, - phxz_symbols: frozenset[sympy.Symbol], - remaining_symbols: frozenset[sympy.Symbol], + phxz_symbols: set[sympy.Symbol], + remaining_symbols: set[sympy.Symbol], sweep: Sweep, ) -> Sweep: """Parameterizes the circuits and returns a new sweep.""" @@ -280,13 +280,11 @@ def merge_single_qubit_gates_to_phxz_symbolized( ) # Step 0, isolate single qubit symbolized symbols and resolve the circuit on them. - single_qubit_gate_symbols: frozenset[sympy.Symbol] = frozenset( - set().union( - *[ - protocols.parameter_symbols(op) if symbolized_single_tag in op.tags else set() - for op in circuit_tagged.all_operations() - ] - ) + single_qubit_gate_symbols: set[sympy.Symbol] = set().union( + *[ + protocols.parameter_symbols(op) if symbolized_single_tag in op.tags else set() + for op in circuit_tagged.all_operations() + ] ) # If all single qubit gates are not parameterized, call the nonparamerized version of # the transformer. @@ -302,7 +300,6 @@ def merge_single_qubit_gates_to_phxz_symbolized( # Step 1, merge single qubit gates per resolved circuit, preserving the "symbolized_single_tag". merged_circuits: List['cirq.Circuit'] = [] - phxz_symbols: set[sympy.Symbols] = set() for resolved_circuit in resolved_circuits: merged_circuit = index_tags( merge_single_qubit_gates_to_phxz( @@ -336,28 +333,21 @@ def merge_single_qubit_gates_to_phxz_symbolized( symbolize.symbolize_single_qubit_gates_by_indexed_tags( merged_circuits[0], tag_prefix=symbolized_single_tag ), - remove_if=lambda tag: tag.startswith(symbolized_single_tag), + remove_if=lambda tag: str(tag).startswith(symbolized_single_tag), ) ) # Step 3, get N sets of parameterizations as new_sweep. - phxz_symbols: frozenset[sympy.Symbol] = frozenset( - set().union( - *[ - set( - [ - sympy.Symbol(tag.replace(f"{symbolized_single_tag}_", s)) - for s in ["x", "z", "a"] - ] - ) - for tag in _all_tags_startswith( - merged_circuits[0], startswith=symbolized_single_tag - ) - ] - ) + phxz_symbols: set[sympy.Symbol] = set().union( + *[ + set( + [sympy.Symbol(tag.replace(f"{symbolized_single_tag}_", s)) for s in ["x", "z", "a"]] + ) + for tag in _all_tags_startswith(merged_circuits[0], startswith=symbolized_single_tag) + ] ) # Remaining symbols, e.g., 2 qubit gates' symbols. Sweep of those symbols keeps unchanged. - remaining_symbols: frozenset[sympy.Symbol] = frozenset( + remaining_symbols: set[sympy.Symbol] = set( protocols.parameter_symbols(circuit) - single_qubit_gate_symbols ) new_sweep = _parameterize_phxz_in_circuits( diff --git a/cirq-core/cirq/transformers/symbolize.py b/cirq-core/cirq/transformers/symbolize.py index 839f653944e..009127ae44b 100644 --- a/cirq-core/cirq/transformers/symbolize.py +++ b/cirq-core/cirq/transformers/symbolize.py @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Hashable, Optional, TYPE_CHECKING import re +from typing import Hashable, Optional, TYPE_CHECKING import sympy from cirq import ops @@ -64,9 +64,9 @@ def _map_func(op: 'cirq.Operation', _): tags: set[Hashable] = set(op.tags) tag_id: None | int = None for tag in tags: - if re.fullmatch(f"{tag_prefix}_\\d+", tag): + if re.fullmatch(f"{tag_prefix}_\\d+", str(tag)): if tag_id is None: - tag_id = tag.split("_")[-1] + tag_id = int(str(tag).rsplit("_", maxsplit=-1)[-1]) else: raise ValueError(f"Multiple tags are prefixed with {tag_prefix}.") if not tag_id: @@ -84,5 +84,5 @@ def _map_func(op: 'cirq.Operation', _): circuit.freeze(), _map_func, deep=context.deep if context else False, - tags_to_ignore=context.tags_to_ignore if context else set(), + tags_to_ignore=context.tags_to_ignore if context else [], ).unfreeze(copy=False) diff --git a/cirq-core/cirq/transformers/tag_transformers.py b/cirq-core/cirq/transformers/tag_transformers.py index 2d6ba09fea3..8adaabdcaa7 100644 --- a/cirq-core/cirq/transformers/tag_transformers.py +++ b/cirq-core/cirq/transformers/tag_transformers.py @@ -13,7 +13,7 @@ # limitations under the License. import itertools -from typing import Callable, Hashable, Optional, TYPE_CHECKING +from typing import Callable, Hashable, Optional, Sequence, TYPE_CHECKING from cirq.transformers import transformer_api, transformer_primitives @@ -27,7 +27,7 @@ def index_tags( circuit: 'cirq.AbstractCircuit', *, context: Optional['cirq.TransformerContext'] = None, - target_tags: frozenset[Hashable] = frozenset(), + target_tags: Optional[set[Hashable]] = None, index_if: Callable[[Hashable], bool] = lambda _: True, ) -> 'cirq.Circuit': """Indexes all the tags in target_tags tag_0, tag_1, .... @@ -41,6 +41,7 @@ def index_tags( Returns: Copy of the transformed input circuit. """ + target_tags = set(target_tags) if target_tags else set() tag_iter_by_tags = {tag: itertools.count(start=0, step=1) for tag in target_tags} tags_to_ignore = context.tags_to_ignore if context else set() @@ -65,7 +66,7 @@ def remove_tags( circuit: 'cirq.AbstractCircuit', *, context: Optional['cirq.TransformerContext'] = None, - target_tags: frozenset[Hashable] = frozenset(), + target_tags: Optional[set[Hashable]] = None, remove_if: Callable[[Hashable], bool] = lambda _: False, ) -> 'cirq.Circuit': """Remove tags from the operations based on the input args. @@ -74,11 +75,13 @@ def remove_tags( circuit: Input circuit to apply the transformations on. The input circuit is not mutated. context: `cirq.TransformerContext` storing common configurable options for transformers. target_tags: Tags to be removed. - remove_if: A callable(tag) that returns True if the tag should be removed. Defaults to False. + remove_if: A callable(tag) that returns True if the tag should be removed. + Defaults to False. Returns: Copy of the transformed input circuit. """ + target_tags = set(target_tags) if target_tags else set() if context and target_tags.intersection(context.tags_to_ignore or set()): raise ValueError("Can't remove tags in context.tags_to_ignore.") From b09d33c30c2bc6682087162e051cbd3f1d9b691d Mon Sep 17 00:00:00 2001 From: Renyi Chen Date: Tue, 22 Apr 2025 19:06:34 -0700 Subject: [PATCH 09/19] fix --- cirq-core/cirq/transformers/merge_single_qubit_gates.py | 6 +++++- cirq-core/cirq/transformers/symbolize.py | 3 ++- cirq-core/cirq/transformers/tag_transformers.py | 7 +++---- 3 files changed, 10 insertions(+), 6 deletions(-) diff --git a/cirq-core/cirq/transformers/merge_single_qubit_gates.py b/cirq-core/cirq/transformers/merge_single_qubit_gates.py index 48b1e1ce901..4e89086e065 100644 --- a/cirq-core/cirq/transformers/merge_single_qubit_gates.py +++ b/cirq-core/cirq/transformers/merge_single_qubit_gates.py @@ -327,6 +327,10 @@ def merge_single_qubit_gates_to_phxz_symbolized( ): raise RuntimeError("Different resolvers in sweep resulted in different merged structures.") + import logging + + logging.info(f"\n{merged_circuits[0]}") + # Step 2, get the new symbolized circuit by mapping merged operations. new_circuit = align.align_right( remove_tags( @@ -354,4 +358,4 @@ def merge_single_qubit_gates_to_phxz_symbolized( merged_circuits, symbolized_single_tag, phxz_symbols, remaining_symbols, sweep ) - return new_circuit.unfreeze(copy=False), new_sweep + return new_circuit, new_sweep diff --git a/cirq-core/cirq/transformers/symbolize.py b/cirq-core/cirq/transformers/symbolize.py index 009127ae44b..8ef89b116df 100644 --- a/cirq-core/cirq/transformers/symbolize.py +++ b/cirq-core/cirq/transformers/symbolize.py @@ -14,6 +14,7 @@ import re from typing import Hashable, Optional, TYPE_CHECKING + import sympy from cirq import ops @@ -69,7 +70,7 @@ def _map_func(op: 'cirq.Operation', _): tag_id = int(str(tag).rsplit("_", maxsplit=-1)[-1]) else: raise ValueError(f"Multiple tags are prefixed with {tag_prefix}.") - if not tag_id: + if tag_id is None: return op tags.remove(f"{tag_prefix}_{tag_id}") phxz_params = { diff --git a/cirq-core/cirq/transformers/tag_transformers.py b/cirq-core/cirq/transformers/tag_transformers.py index 8adaabdcaa7..b545433b01e 100644 --- a/cirq-core/cirq/transformers/tag_transformers.py +++ b/cirq-core/cirq/transformers/tag_transformers.py @@ -13,8 +13,7 @@ # limitations under the License. import itertools -from typing import Callable, Hashable, Optional, Sequence, TYPE_CHECKING - +from typing import Callable, Hashable, Optional, TYPE_CHECKING from cirq.transformers import transformer_api, transformer_primitives @@ -41,7 +40,7 @@ def index_tags( Returns: Copy of the transformed input circuit. """ - target_tags = set(target_tags) if target_tags else set() + target_tags = target_tags or set() tag_iter_by_tags = {tag: itertools.count(start=0, step=1) for tag in target_tags} tags_to_ignore = context.tags_to_ignore if context else set() @@ -81,7 +80,7 @@ def remove_tags( Returns: Copy of the transformed input circuit. """ - target_tags = set(target_tags) if target_tags else set() + target_tags = target_tags or set() if context and target_tags.intersection(context.tags_to_ignore or set()): raise ValueError("Can't remove tags in context.tags_to_ignore.") From dac4e32b90ebe256394337da0f9d4207734f4c41 Mon Sep 17 00:00:00 2001 From: Renyi Chen Date: Tue, 22 Apr 2025 19:32:51 -0700 Subject: [PATCH 10/19] fix tags_to_ignore and coverage --- .../cirq/transformers/tag_transformers.py | 11 +++++---- .../transformers/tag_transformers_test.py | 24 +++++++++++++++++-- 2 files changed, 29 insertions(+), 6 deletions(-) diff --git a/cirq-core/cirq/transformers/tag_transformers.py b/cirq-core/cirq/transformers/tag_transformers.py index b545433b01e..8bc281a2b3b 100644 --- a/cirq-core/cirq/transformers/tag_transformers.py +++ b/cirq-core/cirq/transformers/tag_transformers.py @@ -56,7 +56,7 @@ def _map_func(op: 'cirq.Operation', _) -> 'cirq.OP_TREE': return op.untagged.with_tags(*tag_set) return transformer_primitives.map_operations( - circuit, _map_func, deep=context.deep if context else False + circuit, _map_func, deep=context.deep if context else False, tags_to_ignore=tags_to_ignore ).unfreeze(copy=False) @@ -70,6 +70,8 @@ def remove_tags( ) -> 'cirq.Circuit': """Remove tags from the operations based on the input args. + Note: context.tags_to_ignore has higher priority than target_tags and remove_if. + Args: circuit: Input circuit to apply the transformations on. The input circuit is not mutated. context: `cirq.TransformerContext` storing common configurable options for transformers. @@ -81,8 +83,6 @@ def remove_tags( Copy of the transformed input circuit. """ target_tags = target_tags or set() - if context and target_tags.intersection(context.tags_to_ignore or set()): - raise ValueError("Can't remove tags in context.tags_to_ignore.") def _map_func(op: 'cirq.Operation', _) -> 'cirq.OP_TREE': remaing_tags = set() @@ -93,5 +93,8 @@ def _map_func(op: 'cirq.Operation', _) -> 'cirq.OP_TREE': return op.untagged.with_tags(*remaing_tags) return transformer_primitives.map_operations( - circuit, _map_func, deep=context.deep if context else False + circuit, + _map_func, + deep=context.deep if context else False, + tags_to_ignore=context.tags_to_ignore if context else [], ).unfreeze(copy=False) diff --git a/cirq-core/cirq/transformers/tag_transformers_test.py b/cirq-core/cirq/transformers/tag_transformers_test.py index ec118a99ee8..bfb90c7dde5 100644 --- a/cirq-core/cirq/transformers/tag_transformers_test.py +++ b/cirq-core/cirq/transformers/tag_transformers_test.py @@ -48,7 +48,7 @@ def test_remove_tags(): expected_circuit = cirq.Circuit( cirq.X(q0).with_tags("tag2"), cirq.Y(q1), cirq.CZ(q0, q1).with_tags("tag2") ) - cirq.testing.assert_equivalent_op_tree( + check_same_circuit_with_same_tag_sets( cirq.remove_tags(input_circuit, target_tags={"tag1"}), expected_circuit ) @@ -61,7 +61,27 @@ def test_remove_tags_via_remove_if(): cirq.CZ(q0, q1).with_tags("tag2"), ) expected_circuit = cirq.Circuit(cirq.X(q0), cirq.Y(q1).with_tags("not_tag1"), cirq.CZ(q0, q1)) - cirq.testing.assert_equivalent_op_tree( + check_same_circuit_with_same_tag_sets( cirq.remove_tags(input_circuit, remove_if=lambda tag: tag.startswith("tag")), expected_circuit, ) + + +def test_remove_tags_with_tags_to_ignore(): + q0, q1 = cirq.LineQubit.range(2) + input_circuit = cirq.Circuit( + cirq.X(q0).with_tags("tag1", "tag0"), + cirq.Y(q1).with_tags("not_tag1"), + cirq.CZ(q0, q1).with_tags("tag2"), + ) + expected_circuit = cirq.Circuit( + cirq.X(q0).with_tags("tag1", "tag0"), cirq.Y(q1).with_tags("not_tag1"), cirq.CZ(q0, q1) + ) + check_same_circuit_with_same_tag_sets( + cirq.remove_tags( + input_circuit, + remove_if=lambda tag: tag.startswith("tag"), + context=cirq.TransformerContext(tags_to_ignore=["tag0"]), + ), + expected_circuit, + ) From 2982730abdea5ea3bf5447baf3935ed280fa3090 Mon Sep 17 00:00:00 2001 From: Renyi Chen Date: Tue, 22 Apr 2025 19:34:33 -0700 Subject: [PATCH 11/19] rm logging code.. --- cirq-core/cirq/transformers/merge_single_qubit_gates.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/cirq-core/cirq/transformers/merge_single_qubit_gates.py b/cirq-core/cirq/transformers/merge_single_qubit_gates.py index 4e89086e065..efb2da5dc4c 100644 --- a/cirq-core/cirq/transformers/merge_single_qubit_gates.py +++ b/cirq-core/cirq/transformers/merge_single_qubit_gates.py @@ -327,10 +327,6 @@ def merge_single_qubit_gates_to_phxz_symbolized( ): raise RuntimeError("Different resolvers in sweep resulted in different merged structures.") - import logging - - logging.info(f"\n{merged_circuits[0]}") - # Step 2, get the new symbolized circuit by mapping merged operations. new_circuit = align.align_right( remove_tags( From 37163c08e0e5a179c481c2c3b7ba8eec23c3a4a1 Mon Sep 17 00:00:00 2001 From: Renyi Chen Date: Tue, 22 Apr 2025 19:44:09 -0700 Subject: [PATCH 12/19] rm unnecessary and wrong utils --- .../cirq/transformers/merge_single_qubit_gates.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/cirq-core/cirq/transformers/merge_single_qubit_gates.py b/cirq-core/cirq/transformers/merge_single_qubit_gates.py index efb2da5dc4c..edcc846ae33 100644 --- a/cirq-core/cirq/transformers/merge_single_qubit_gates.py +++ b/cirq-core/cirq/transformers/merge_single_qubit_gates.py @@ -171,11 +171,6 @@ def merge_func(m1: 'cirq.Moment', m2: 'cirq.Moment') -> Optional['cirq.Moment']: ).unfreeze(copy=False) -def _values_of_sweep(sweep: Sweep, key: TMeasurementKey): - p = sympy.Symbol(key) if isinstance(key, str) else key - return [resolver.value_of(p) for resolver in sweep] - - def _parameterize_phxz_in_circuits( circuit_list: List['cirq.Circuit'], merge_tag_prefix: str, @@ -186,7 +181,7 @@ def _parameterize_phxz_in_circuits( """Parameterizes the circuits and returns a new sweep.""" values_by_params: Dict[str, List[float]] = { **{str(s): [] for s in phxz_symbols}, - **{str(s): _values_of_sweep(sweep, s) for s in remaining_symbols}, + **{str(s): [resolver.value_of(s) for resolver in sweep] for s in remaining_symbols}, } for circuit in circuit_list: @@ -291,7 +286,10 @@ def merge_single_qubit_gates_to_phxz_symbolized( if not single_qubit_gate_symbols: return (merge_single_qubit_gates_to_phxz(circuit, context=context, atol=atol), sweep) sweep_of_single: Sweep = Zip( - *[Points(key=k, points=_values_of_sweep(sweep, k)) for k in single_qubit_gate_symbols] + *[ + Points(key=k, points=[resolver.value_of(k) for resolver in sweep]) + for k in single_qubit_gate_symbols + ] ) # Get all resolved circuits from all sets of resolvers in the sweep. resolved_circuits = [ From b680e5008e61d9f99f51b682b4ef47c967d14001 Mon Sep 17 00:00:00 2001 From: Renyi Chen Date: Tue, 22 Apr 2025 19:46:29 -0700 Subject: [PATCH 13/19] fix tag_transformer --- cirq-core/cirq/transformers/tag_transformers.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/cirq-core/cirq/transformers/tag_transformers.py b/cirq-core/cirq/transformers/tag_transformers.py index 8bc281a2b3b..c4429cb1b46 100644 --- a/cirq-core/cirq/transformers/tag_transformers.py +++ b/cirq-core/cirq/transformers/tag_transformers.py @@ -42,11 +42,10 @@ def index_tags( """ target_tags = target_tags or set() tag_iter_by_tags = {tag: itertools.count(start=0, step=1) for tag in target_tags} - tags_to_ignore = context.tags_to_ignore if context else set() def _map_func(op: 'cirq.Operation', _) -> 'cirq.OP_TREE': tag_set = set(op.tags) - if not index_if(op) or tag_set.intersection(tags_to_ignore): # Skip indexing + if not index_if(op): return op nonlocal tag_iter_by_tags for tag in target_tags.intersection(op.tags): @@ -56,7 +55,10 @@ def _map_func(op: 'cirq.Operation', _) -> 'cirq.OP_TREE': return op.untagged.with_tags(*tag_set) return transformer_primitives.map_operations( - circuit, _map_func, deep=context.deep if context else False, tags_to_ignore=tags_to_ignore + circuit, + _map_func, + deep=context.deep if context else False, + tags_to_ignore=context.tags_to_ignore if context else [], ).unfreeze(copy=False) From 875179b7de213fd91a99996d5165ef2946f07a7d Mon Sep 17 00:00:00 2001 From: Renyi Chen Date: Tue, 22 Apr 2025 19:52:26 -0700 Subject: [PATCH 14/19] fix TParamValComplex --- cirq-core/cirq/transformers/merge_single_qubit_gates.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/cirq-core/cirq/transformers/merge_single_qubit_gates.py b/cirq-core/cirq/transformers/merge_single_qubit_gates.py index edcc846ae33..2e7c8ad3c87 100644 --- a/cirq-core/cirq/transformers/merge_single_qubit_gates.py +++ b/cirq-core/cirq/transformers/merge_single_qubit_gates.py @@ -19,7 +19,6 @@ import sympy from cirq import circuits, ops, protocols -from cirq.study.result import TMeasurementKey from cirq.study.sweeps import Points, Sweep, Zip from cirq.transformers import ( align, @@ -179,7 +178,7 @@ def _parameterize_phxz_in_circuits( sweep: Sweep, ) -> Sweep: """Parameterizes the circuits and returns a new sweep.""" - values_by_params: Dict[str, List[float]] = { + values_by_params: Dict[str, 'cirq.TParamValComplex'] = { **{str(s): [] for s in phxz_symbols}, **{str(s): [resolver.value_of(s) for resolver in sweep] for s in remaining_symbols}, } @@ -287,7 +286,7 @@ def merge_single_qubit_gates_to_phxz_symbolized( return (merge_single_qubit_gates_to_phxz(circuit, context=context, atol=atol), sweep) sweep_of_single: Sweep = Zip( *[ - Points(key=k, points=[resolver.value_of(k) for resolver in sweep]) + Points(key=k, points=[float(resolver.value_of(k)) for resolver in sweep]) for k in single_qubit_gate_symbols ] ) From 414adda971c15975cdef6a2624176542620144b5 Mon Sep 17 00:00:00 2001 From: Renyi Chen Date: Tue, 22 Apr 2025 21:17:52 -0700 Subject: [PATCH 15/19] fix param resolve, type casting etc. --- .../transformers/merge_single_qubit_gates.py | 35 +++++++++++-------- 1 file changed, 20 insertions(+), 15 deletions(-) diff --git a/cirq-core/cirq/transformers/merge_single_qubit_gates.py b/cirq-core/cirq/transformers/merge_single_qubit_gates.py index 2e7c8ad3c87..ec707aea914 100644 --- a/cirq-core/cirq/transformers/merge_single_qubit_gates.py +++ b/cirq-core/cirq/transformers/merge_single_qubit_gates.py @@ -14,12 +14,13 @@ """Transformer passes to combine adjacent single-qubit rotations.""" -from typing import Callable, Dict, Hashable, List, Optional, Tuple, TYPE_CHECKING +from typing import Callable, cast, Dict, Hashable, List, Optional, Tuple, TYPE_CHECKING import sympy from cirq import circuits, ops, protocols -from cirq.study.sweeps import Points, Sweep, Zip +from cirq.study.sweeps import dict_to_zip_sweep, ListSweep, ProductOrZipSweepLike, Sweep, Zip +from cirq.study.resolver import ParamResolver from cirq.transformers import ( align, merge_k_qubit_gates, @@ -170,6 +171,14 @@ def merge_func(m1: 'cirq.Moment', m2: 'cirq.Moment') -> Optional['cirq.Moment']: ).unfreeze(copy=False) +def _sweep_on_symbols(sweep: Sweep, symbols: set[sympy.Symbol]) -> Sweep: + new_resolvers: List['cirq.ParamResolver'] = [] + for resolver in sweep: + param_dict: 'cirq.ParamMappingType' = {s: resolver.value_of(s) for s in symbols} + new_resolvers.append(ParamResolver(param_dict)) + return ListSweep(new_resolvers) + + def _parameterize_phxz_in_circuits( circuit_list: List['cirq.Circuit'], merge_tag_prefix: str, @@ -178,10 +187,7 @@ def _parameterize_phxz_in_circuits( sweep: Sweep, ) -> Sweep: """Parameterizes the circuits and returns a new sweep.""" - values_by_params: Dict[str, 'cirq.TParamValComplex'] = { - **{str(s): [] for s in phxz_symbols}, - **{str(s): [resolver.value_of(s) for resolver in sweep] for s in remaining_symbols}, - } + values_by_params: Dict[str, List[float]] = {**{str(s): [] for s in phxz_symbols}} for circuit in circuit_list: for op in circuit.all_operations(): @@ -204,7 +210,10 @@ def _parameterize_phxz_in_circuits( values_by_params[f"z{sid}"].append(z) values_by_params[f"a{sid}"].append(a) - return Zip(*[Points(key=key, points=values) for key, values in values_by_params.items()]) + return Zip( + dict_to_zip_sweep(cast(ProductOrZipSweepLike, values_by_params)), + _sweep_on_symbols(sweep, remaining_symbols), + ) def _all_tags_startswith(circuit: 'cirq.AbstractCircuit', startswith: str): @@ -284,18 +293,14 @@ def merge_single_qubit_gates_to_phxz_symbolized( # the transformer. if not single_qubit_gate_symbols: return (merge_single_qubit_gates_to_phxz(circuit, context=context, atol=atol), sweep) - sweep_of_single: Sweep = Zip( - *[ - Points(key=k, points=[float(resolver.value_of(k)) for resolver in sweep]) - for k in single_qubit_gate_symbols - ] - ) + sweep_of_single: Sweep = _sweep_on_symbols(sweep, single_qubit_gate_symbols) # Get all resolved circuits from all sets of resolvers in the sweep. resolved_circuits = [ protocols.resolve_parameters(circuit_tagged, resolver) for resolver in sweep_of_single ] - # Step 1, merge single qubit gates per resolved circuit, preserving the "symbolized_single_tag". + # Step 1, merge single qubit gates per resolved circuit, preserving + # the symbolized_single_tag with indexes. merged_circuits: List['cirq.Circuit'] = [] for resolved_circuit in resolved_circuits: merged_circuit = index_tags( @@ -324,7 +329,7 @@ def merge_single_qubit_gates_to_phxz_symbolized( ): raise RuntimeError("Different resolvers in sweep resulted in different merged structures.") - # Step 2, get the new symbolized circuit by mapping merged operations. + # Step 2, get the new symbolized circuit by symbolization on indexed symbolized_single_tag. new_circuit = align.align_right( remove_tags( symbolize.symbolize_single_qubit_gates_by_indexed_tags( From f197a89771b480112c8dbea3e7af1286596cdbb4 Mon Sep 17 00:00:00 2001 From: Renyi Chen Date: Tue, 22 Apr 2025 21:20:52 -0700 Subject: [PATCH 16/19] fix format --- cirq-core/cirq/transformers/merge_single_qubit_gates.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cirq-core/cirq/transformers/merge_single_qubit_gates.py b/cirq-core/cirq/transformers/merge_single_qubit_gates.py index ec707aea914..51279d3bbbc 100644 --- a/cirq-core/cirq/transformers/merge_single_qubit_gates.py +++ b/cirq-core/cirq/transformers/merge_single_qubit_gates.py @@ -19,8 +19,8 @@ import sympy from cirq import circuits, ops, protocols -from cirq.study.sweeps import dict_to_zip_sweep, ListSweep, ProductOrZipSweepLike, Sweep, Zip from cirq.study.resolver import ParamResolver +from cirq.study.sweeps import dict_to_zip_sweep, ListSweep, ProductOrZipSweepLike, Sweep, Zip from cirq.transformers import ( align, merge_k_qubit_gates, From 807f35f53ff1928266d1224691008e2efd29a2a8 Mon Sep 17 00:00:00 2001 From: Renyi Chen Date: Tue, 22 Apr 2025 21:43:57 -0700 Subject: [PATCH 17/19] fix checks --- cirq-core/cirq/transformers/symbolize.py | 6 +++--- cirq-core/cirq/transformers/tag_transformers.py | 4 ---- 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/cirq-core/cirq/transformers/symbolize.py b/cirq-core/cirq/transformers/symbolize.py index 8ef89b116df..72459692f5b 100644 --- a/cirq-core/cirq/transformers/symbolize.py +++ b/cirq-core/cirq/transformers/symbolize.py @@ -41,9 +41,9 @@ def symbolize_single_qubit_gates_by_indexed_tags( cirq.Y(q0).with_tags("phxz_1"),\ cirq.X(q0)) >>> print(c) - 0: ───X["phxz_0"]───@───Y["phxz_1"]───X─── - │ - 1: ─────────────────@───────────────────── + 0: ───X[phxz_0]───@───Y[phxz_1]───X─── + │ + 1: ───────────────@─────────────────── >>> new_circuit = cirq.symbolize_single_qubit_gates_by_indexed_tags(\ c, tag_prefix="phxz") >>> print(new_circuit) diff --git a/cirq-core/cirq/transformers/tag_transformers.py b/cirq-core/cirq/transformers/tag_transformers.py index c4429cb1b46..bf47d269a35 100644 --- a/cirq-core/cirq/transformers/tag_transformers.py +++ b/cirq-core/cirq/transformers/tag_transformers.py @@ -27,7 +27,6 @@ def index_tags( *, context: Optional['cirq.TransformerContext'] = None, target_tags: Optional[set[Hashable]] = None, - index_if: Callable[[Hashable], bool] = lambda _: True, ) -> 'cirq.Circuit': """Indexes all the tags in target_tags tag_0, tag_1, .... @@ -35,7 +34,6 @@ def index_tags( circuit: Input circuit to apply the transformations on. The input circuit is not mutated. context: `cirq.TransformerContext` storing common configurable options for transformers. target_tags: Tags to be indexed. - index_if: A callable that returns True if its tags should be indexed. Defaults to True. Returns: Copy of the transformed input circuit. @@ -45,8 +43,6 @@ def index_tags( def _map_func(op: 'cirq.Operation', _) -> 'cirq.OP_TREE': tag_set = set(op.tags) - if not index_if(op): - return op nonlocal tag_iter_by_tags for tag in target_tags.intersection(op.tags): tag_set.remove(tag) From 69ad9eb4e955e8cac3995a886723fec37af871eb Mon Sep 17 00:00:00 2001 From: Renyi Chen Date: Tue, 22 Apr 2025 22:34:01 -0700 Subject: [PATCH 18/19] Small fixes of comments and docstrings. --- .../cirq/transformers/merge_single_qubit_gates.py | 10 +++++----- cirq-core/cirq/transformers/symbolize.py | 2 +- cirq-core/cirq/transformers/tag_transformers.py | 4 ++-- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/cirq-core/cirq/transformers/merge_single_qubit_gates.py b/cirq-core/cirq/transformers/merge_single_qubit_gates.py index 51279d3bbbc..11eb60fa3e3 100644 --- a/cirq-core/cirq/transformers/merge_single_qubit_gates.py +++ b/cirq-core/cirq/transformers/merge_single_qubit_gates.py @@ -232,8 +232,8 @@ def merge_single_qubit_gates_to_phxz_symbolized( sweep: Sweep, atol: float = 1e-8, ) -> Tuple['cirq.Circuit', Sweep]: - """Merge consecutive single qubit gates as PhasedXZ Gates. Symbolize if any of the consecutive - gates is symbolized. + """Merges consecutive single qubit gates as PhasedXZ Gates. Symbolizes if any of + the consecutive gates is symbolized. Example: >>> q0, q1 = cirq.LineQubit.range(2) @@ -258,9 +258,9 @@ def merge_single_qubit_gates_to_phxz_symbolized( Args: circuit: Input circuit to transform. It will not be modified. + context: `cirq.TransformerContext` storing common configurable options for transformers. sweep: Sweep of the symbols in the input circuit, updated Sweep will be returned based on the transformation. - context: `cirq.TransformerContext` storing common configurable options for transformers. atol: Absolute tolerance to angle error. Larger values allow more negligible gates to be dropped, smaller values increase accuracy. @@ -282,7 +282,7 @@ def merge_single_qubit_gates_to_phxz_symbolized( deep=deep, ) - # Step 0, isolate single qubit symbolized symbols and resolve the circuit on them. + # Step 0, isolate single qubit symbols and resolve the circuit on them. single_qubit_gate_symbols: set[sympy.Symbol] = set().union( *[ protocols.parameter_symbols(op) if symbolized_single_tag in op.tags else set() @@ -294,7 +294,7 @@ def merge_single_qubit_gates_to_phxz_symbolized( if not single_qubit_gate_symbols: return (merge_single_qubit_gates_to_phxz(circuit, context=context, atol=atol), sweep) sweep_of_single: Sweep = _sweep_on_symbols(sweep, single_qubit_gate_symbols) - # Get all resolved circuits from all sets of resolvers in the sweep. + # Get all resolved circuits from all sets of resolvers in sweep_of_single. resolved_circuits = [ protocols.resolve_parameters(circuit_tagged, resolver) for resolver in sweep_of_single ] diff --git a/cirq-core/cirq/transformers/symbolize.py b/cirq-core/cirq/transformers/symbolize.py index 72459692f5b..c5f1a831a10 100644 --- a/cirq-core/cirq/transformers/symbolize.py +++ b/cirq-core/cirq/transformers/symbolize.py @@ -31,7 +31,7 @@ def symbolize_single_qubit_gates_by_indexed_tags( context: Optional['cirq.TransformerContext'] = None, tag_prefix: Optional[str] = "TO-PHXZ", ) -> 'cirq.Circuit': - """Symbolize single qubit operations by indexed tags prefixed by tag_prefix. + """Symbolizes single qubit operations by indexed tags prefixed by tag_prefix. Example: >>> q0, q1 = cirq.LineQubit.range(2) diff --git a/cirq-core/cirq/transformers/tag_transformers.py b/cirq-core/cirq/transformers/tag_transformers.py index bf47d269a35..28cb52c2916 100644 --- a/cirq-core/cirq/transformers/tag_transformers.py +++ b/cirq-core/cirq/transformers/tag_transformers.py @@ -28,7 +28,7 @@ def index_tags( context: Optional['cirq.TransformerContext'] = None, target_tags: Optional[set[Hashable]] = None, ) -> 'cirq.Circuit': - """Indexes all the tags in target_tags tag_0, tag_1, .... + """Indexes tags in target_tags as tag_0, tag_1, ... per tag. Args: circuit: Input circuit to apply the transformations on. The input circuit is not mutated. @@ -66,7 +66,7 @@ def remove_tags( target_tags: Optional[set[Hashable]] = None, remove_if: Callable[[Hashable], bool] = lambda _: False, ) -> 'cirq.Circuit': - """Remove tags from the operations based on the input args. + """Removes tags from the operations based on the input args. Note: context.tags_to_ignore has higher priority than target_tags and remove_if. From e67e6611d3aeb0beaac554bab4fb645ad581cc0c Mon Sep 17 00:00:00 2001 From: Renyi Chen Date: Wed, 23 Apr 2025 12:39:02 -0700 Subject: [PATCH 19/19] Fix docstring, rm empty lines. --- cirq-core/cirq/transformers/merge_single_qubit_gates.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/cirq-core/cirq/transformers/merge_single_qubit_gates.py b/cirq-core/cirq/transformers/merge_single_qubit_gates.py index 11eb60fa3e3..ec57c794ee9 100644 --- a/cirq-core/cirq/transformers/merge_single_qubit_gates.py +++ b/cirq-core/cirq/transformers/merge_single_qubit_gates.py @@ -87,7 +87,6 @@ def merge_single_qubit_gates_to_phxz( Args: circuit: Input circuit to transform. It will not be modified. context: `cirq.TransformerContext` storing common configurable options for transformers. - merge_tag: If provided, tag merged PhXZ gate with it. merge_tags_fn: A callable returns the tags to be added to the merged operation. atol: Absolute tolerance to angle error. Larger values allow more negligible gates to be dropped, smaller values increase accuracy. @@ -97,11 +96,9 @@ def merge_single_qubit_gates_to_phxz( """ def rewriter(circuit_op: 'cirq.CircuitOperation') -> 'cirq.OP_TREE': - u = protocols.unitary(circuit_op) if protocols.num_qubits(circuit_op) == 0: return ops.GlobalPhaseGate(u[0, 0]).on() - gate = single_qubit_decompositions.single_qubit_matrix_to_phxz(u, atol) or ops.I phxz_op = gate.on(circuit_op.qubits[0]) return phxz_op.with_tags(*merge_tags_fn(circuit_op)) if merge_tags_fn else phxz_op