diff --git a/test/test_weyl_chamber.jl b/test/test_weyl_chamber.jl index 991b340..ffecda0 100644 --- a/test/test_weyl_chamber.jl +++ b/test/test_weyl_chamber.jl @@ -248,19 +248,16 @@ end ] J_T = gate_functional(D_PE; unitarity_weight=0.0) - chi_pe! = make_chi(J_T, trajectories) - ϕ = [Ψ1, Ψ2, Ψ3, Ψ4] - χ_zygote = [similar(ϕₖ) for ϕₖ in ϕ] - chi_pe!(χ_zygote, ϕ, trajectories) + chi_pe = make_chi(J_T, trajectories) + Ψ = [Ψ1, Ψ2, Ψ3, Ψ4] + χ_zygote = chi_pe(Ψ, trajectories) - gate_chi_pe! = make_gate_chi(D_PE, trajectories; unitarity_weight=0.0) - χ_gate_zygote = [similar(ϕₖ) for ϕₖ in ϕ] - gate_chi_pe!(χ_gate_zygote, ϕ, trajectories) + gate_chi_pe = make_gate_chi(D_PE, trajectories; unitarity_weight=0.0) + χ_gate_zygote = gate_chi_pe(Ψ, trajectories) - gate_chi_pe_fd! = + gate_chi_pe_fd = make_gate_chi(D_PE, trajectories; automatic=FiniteDifferences, unitarity_weight=0.0) - χ_gate_fd = [similar(ϕₖ) for ϕₖ in ϕ] - gate_chi_pe_fd!(χ_gate_fd, ϕ, trajectories) + χ_gate_fd = gate_chi_pe_fd(Ψ, trajectories) vec_angle(v⃗, w⃗) = acos((v⃗ ⋅ w⃗) / (norm(v⃗) * norm(w⃗))) @@ -293,7 +290,7 @@ end J_T_val = 0.0 local J_T_U local J_T - local ϕ + local Ψ basis = [ket(lbl; N=2) for lbl in ("00", "01", "10", "11")] H = random_matrix(4; hermitian=true, complex=false) @@ -306,18 +303,16 @@ end end U = canonical_gate(c1, c2, c3) - ϕ = U' * basis + Ψ = U' * basis J_T_U = U -> 1 - gate_concurrence(U) J_T = gate_functional(J_T_U) - gate_chi_pe! = make_gate_chi(J_T_U, trajectories) - χ_gate_zygote = [similar(ϕₖ) for ϕₖ in ϕ] - gate_chi_pe!(χ_gate_zygote, ϕ, trajectories) + gate_chi_pe = make_gate_chi(J_T_U, trajectories) + χ_gate_zygote = gate_chi_pe(Ψ, trajectories) - gate_chi_pe_fd! = make_gate_chi(J_T_U, trajectories; automatic=FiniteDifferences) - χ_gate_fd = [similar(ϕₖ) for ϕₖ in ϕ] - gate_chi_pe_fd!(χ_gate_fd, ϕ, trajectories) + gate_chi_pe_fd = make_gate_chi(J_T_U, trajectories; automatic=FiniteDifferences) + χ_gate_fd = gate_chi_pe_fd(Ψ, trajectories) # Does Zygote gate gradient and FD gate gradient match? @test norm(χ_gate_zygote[1] - χ_gate_fd[1]) < 1e-8