From d6018f8b3531119ff81cbe002e9256999e452c58 Mon Sep 17 00:00:00 2001 From: Stefan Krastanov Date: Sat, 20 Apr 2024 15:17:33 -0400 Subject: [PATCH] Fix type instabilities in compactified methods (due to type mismatch in the uncompactified method) (#265) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ``` function x_diag_circuit_noisy_measurement(csize) circuit = [] for i in 1:csize push!(circuit, PauliError(i, 0.1)) push!(circuit, sHadamard(i)) push!(circuit, sCNOT(i, csize+1)) push!(circuit, sMZ(csize+1,i)) push!(circuit, ClassicalXOR((1,(i%6+6)),i)) end return circuit end @benchmark pftrajectories(state,circuit) setup=(state=PauliFrame(1000, 1001, 1001); circuit=compactify_circuit(x_diag_circuit_noisy_measurement(1000))) evals=1 Before: BenchmarkTools.Trial: 10 samples with 1 evaluation. Range (min … max): 2.885 ms … 2.962 ms ┊ GC (min … max): 0.00% … 0.00% Time (median): 2.900 ms ┊ GC (median): 0.00% Time (mean ± σ): 2.912 ms ± 30.387 μs ┊ GC (mean ± σ): 0.00% ± 0.00% █▁ ▁ ▁ ▁ ▁ ▁ ▁ ▁ ██▁█▁▁▁█▁▁▁▁▁▁█▁▁▁▁█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█▁▁▁▁▁▁▁▁▁▁█▁▁█ ▁ 2.89 ms Histogram: frequency by time 2.96 ms < Memory estimate: 187.50 KiB, allocs estimate: 4000. After: BenchmarkTools.Trial: 749 samples with 1 evaluation. Range (min … max): 2.929 ms … 3.097 ms ┊ GC (min … max): 0.00% … 0.00% Time (median): 2.948 ms ┊ GC (median): 0.00% Time (mean ± σ): 2.951 ms ± 16.854 μs ┊ GC (mean ± σ): 0.00% ± 0.00% ▃█▆▂ ▂▂▃▄▅▆█████▅▅▄▄▃▃▄▃▃▃▂▂▁▁▁▂▁▂▂▁▁▁▁▁▁▁▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▂▁▂▂ ▃ 2.93 ms Histogram: frequency by time 3.06 ms < Memory estimate: 0 bytes, allocs estimate: 0. ``` --- src/misc_ops.jl | 7 ++----- src/pauli_frames.jl | 1 + src/sumtypes.jl | 19 ++++++++++++++----- 3 files changed, 17 insertions(+), 10 deletions(-) diff --git a/src/misc_ops.jl b/src/misc_ops.jl index a79d5fc05..bf26489e0 100644 --- a/src/misc_ops.jl +++ b/src/misc_ops.jl @@ -133,9 +133,6 @@ struct ClassicalXOR{N} <: AbstractOperation bits::NTuple{N,Int} "The index of the classical bit that will store the results" store::Int - function ClassicalXOR(bits, store) - tbits = tuple(bits...) - n = length(tbits) - return new{n}(tbits, store) - end end + +ClassicalXOR(bits,store) = ClassicalXOR{length(bits)}(tuple(bits...),store) diff --git a/src/pauli_frames.jl b/src/pauli_frames.jl index dc4a54de5..0fde80f1c 100644 --- a/src/pauli_frames.jl +++ b/src/pauli_frames.jl @@ -65,6 +65,7 @@ function apply!(frame::PauliFrame, xor::ClassicalXOR) end frame.measurements[f, xor.store] = value end + return frame end function apply!(frame::PauliFrame, op::sMX) # TODO implement a faster direct version diff --git a/src/sumtypes.jl b/src/sumtypes.jl index 68ee0dc20..12d13bdaa 100644 --- a/src/sumtypes.jl +++ b/src/sumtypes.jl @@ -9,6 +9,7 @@ struct SymbolicDataType types#::Core.SimpleVector fieldnames originaltype + originaltype_parameterized end _header(s) = s _header(s::SymbolicDataType) = s.name @@ -19,6 +20,8 @@ _fieldnames(s) = fieldnames(s) _fieldnames(s::SymbolicDataType) = s.fieldnames _originaltype(s) = s _originaltype(s::SymbolicDataType) = s.originaltype +_originaltype_parameterized(s) = s +_originaltype_parameterized(s::SymbolicDataType) = s.originaltype_parameterized """ ``` @@ -38,7 +41,8 @@ julia> make_variant_deconstruct(sCNOT, :apply!, (:s,)) """ function make_variant_deconstruct(type::Union{DataType,SymbolicDataType}, call, preargs=(), postargs=()) variant = Expr(:call, _symbol(type), _fieldnames(type)...) - original = :(($(_originaltype(type)))($(_fieldnames(type)...))) + original = :(($(_originaltype_parameterized(type)))($(_fieldnames(type)...))) + #:($variant => begin $(Expr(:call, call, preargs..., original, postargs...)); nothing end) # useful when you are searching for type instabilities due to inconsistent output types for a method (usually also pointing to a method not following the conventions of the API) :($variant => $(Expr(:call, call, preargs..., original, postargs...))) end @@ -94,8 +98,7 @@ function make_sumtype_variant_constructor(type) if isa(type, DataType) || isa(type, SymbolicDataType) return :( CompactifiedGate(g::$(_header(type))) = CompactifiedGate'.$(_symbol(type))($([:(g.$n) for n in _fieldnames(type)]...)) ) else - #return :( CompactifiedGate(g::$(_header(type))) = (@warn "The operation is of a type that can not be unified, defaulting to slower runtime dispatch" typeof(g); return g) ) - return :() + return :() # this is taken care of by a default constructor that also warns about the failure to compactify end end @@ -135,15 +138,21 @@ function make_all_sumtype_infrastructure_expr(t::DataType, callsigs) push!(concrete_types, ut) # fallback end sumtype = make_sumtype(concrete_types) + @debug "compiling a total of $(length(concrete_types)) concrete types" constructors = make_sumtype_variant_constructor.(concrete_types) methods = [make_sumtype_method(concrete_types, call, preargs, postargs) for (call, preargs, postargs) in callsigs] + modulename = gensym(:CompactifiedGate) return quote + #module $(modulename) + #using QuantumClifford + #import QuantumClifford: CompactifiedGate, # todo $(concretifier_workarounds_types...) $(sumtype.args...) # defining the sum type $(constructors...) # creating constructors for the sumtype which turn our concrete types into instance of the sum type $(concretifier_additional_constructors...) # creating constructors for the newly generated "workaround" concrete types - :( CompactifiedGate(g::AbstractOperation) = (@warn "The operation is of a type that can not be unified, defaulting to slower runtime dispatch" typeof(g); return g) ) + :(CompactifiedGate(g::AbstractOperation) = (@warn "The operation is of a type that can not be unified, defaulting to slower runtime dispatch" typeof(g); return g) ) $(methods...) + #end end end @@ -161,7 +170,7 @@ function concretifier(t) parameterized_type = t{typeparams...} ftypes = parameterized_type.types fnames = fieldnames(t) - push!(names, SymbolicDataType(name, ftypes, fnames, t)) + push!(names, SymbolicDataType(name, ftypes, fnames, t, parameterized_type)) push!(generated_concretetypes, :( struct $(name) $([:($n::$t) for (n,t) in zip(fnames,ftypes)]...)