diff --git a/lib/mps_simulator.h b/lib/mps_simulator.h index ae053690..98d00310 100644 --- a/lib/mps_simulator.h +++ b/lib/mps_simulator.h @@ -35,11 +35,12 @@ namespace mps { /** * Truncated Matrix Product State (MPS) circuit simulator w/ vectorization. */ -template +template class MPSSimulator final { public: - using MPSStateSpace_ = MPSStateSpace; + using MPSStateSpace_ = MPSStateSpace; using State = typename MPSStateSpace_::MPS; + using fp_type = typename MPSStateSpace_::fp_type; using Complex = std::complex; using Matrix = diff --git a/lib/mps_statespace.h b/lib/mps_statespace.h index 888d4d58..9e59599c 100644 --- a/lib/mps_statespace.h +++ b/lib/mps_statespace.h @@ -51,10 +51,11 @@ inline void free(void* ptr) { * Class containing context and routines for fixed bond dimension * truncated Matrix Product State (MPS) simulation. */ -template +template class MPSStateSpace { private: public: + using fp_type = FP; using Pointer = std::unique_ptr; using Complex = std::complex; diff --git a/tests/BUILD b/tests/BUILD index 53f72ab4..47f90535 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -613,6 +613,7 @@ cc_test( "@com_google_googletest//:gtest_main", "//lib:gate_appl", "//lib:gates_cirq", + "//lib:gates_qsim", "//lib:mps_simulator", "//lib:formux", ], diff --git a/tests/mps_simulator_test.cc b/tests/mps_simulator_test.cc index 1aa7def7..605a29c0 100644 --- a/tests/mps_simulator_test.cc +++ b/tests/mps_simulator_test.cc @@ -17,6 +17,7 @@ #include "../lib/formux.h" #include "../lib/gate_appl.h" #include "../lib/gates_cirq.h" +#include "../lib/gates_qsim.h" #include "gtest/gtest.h" namespace qsim { @@ -802,6 +803,110 @@ TEST(MPSSimulator, OneTwoQubitFuzz) { */ } +TEST(MPSSimulator, ApplyFusedGateLeft) { + // Apply a fused gate matrix to the first two qubits. + // Compute the state vector of: + // | | | + // +-+-----+-+ | + // |FusedGate| | + // +-+-----+-+ | + // | | | + // +-+-+ +-+-+ +-+-+ + // | 0 +-+ 1 +-+ 2 | + // +---+ +---+ +---+ + auto sim = MPSSimulator(1); + using MPSStateSpace = MPSSimulator::MPSStateSpace_; + auto ss = MPSStateSpace(1); + + auto gate1 = GateCZ::Create(2, 0, 1); + auto gate2 = GateHd::Create(0, 0); + auto gate3 = GateHd::Create(0, 1); + + GateFused> fgate1{kGateCZ, 2, {0, 1}, &gate1, + {&gate2, &gate3}}; + auto mps = ss.Create(3, 4); + ss.SetStateZero(mps); + ApplyFusedGate(sim, fgate1, mps); + + float wf[32]; + float ground_truth[] = {0.5, 0., 0., 0., 0.5, 0., 0., 0., + 0.5, 0., 0., 0., 0.5, 0., 0., 0.}; + ss.ToWaveFunction(mps, wf); + for (int i = 0; i < 16; i++) { + EXPECT_NEAR(wf[i], ground_truth[i], 1e-4); + } +} + +TEST(MPSSimulator, ApplyFusedGateRight) { + // Apply a fused gate matrix to the last two qubits. + // Compute the state vector of: + // | | | + // | +-+-----+-+ + // | |FusedGate| + // | +-+-----+-+ + // | | | + // +-+-+ +-+-+ +-+-+ + // | 0 +-+ 1 +-+ 2 | + // +---+ +---+ +---+ + auto sim = MPSSimulator(1); + using MPSStateSpace = MPSSimulator::MPSStateSpace_; + auto ss = MPSStateSpace(1); + + auto gate1 = GateCZ::Create(2, 1, 2); + auto gate2 = GateHd::Create(0, 1); + auto gate3 = GateHd::Create(0, 2); + + GateFused> fgate1{kGateCZ, 2, {1, 2}, &gate1, + {&gate2, &gate3}}; + auto mps = ss.Create(3, 4); + ss.SetStateZero(mps); + ApplyFusedGate(sim, fgate1, mps); + + float wf[32]; + float ground_truth[] = {0.5, 0., 0.5, 0., 0.5, 0., 0.5, 0., + 0., 0., 0., 0., 0., 0., 0., 0.}; + ss.ToWaveFunction(mps, wf); + for (int i = 0; i < 16; i++) { + EXPECT_NEAR(wf[i], ground_truth[i], 1e-4); + } +} + +TEST(MPSSimulator, ApplyFusedGateMiddle) { + // Apply a fused gate matrix to the middle two qubits. + // Compute the state vector of: + // | | | | + // | +-+-----+-+ | + // | |FusedGate| | + // | +-+-----+-+ | + // | | | | + // +-+-+ +-+-+ +-+-+ +-+-+ + // | 0 +-+ 1 +-+ 2 |-| 3 | + // +---+ +---+ +---+ +-+-+ + auto sim = MPSSimulator(1); + using MPSStateSpace = MPSSimulator::MPSStateSpace_; + auto ss = MPSStateSpace(1); + + auto gate1 = GateCZ::Create(2, 1, 2); + auto gate2 = GateHd::Create(0, 1); + auto gate3 = GateHd::Create(0, 2); + + GateFused> fgate1{kGateCZ, 2, {1, 2}, &gate1, + {&gate2, &gate3}}; + auto mps = ss.Create(4, 4); + ss.SetStateZero(mps); + ApplyFusedGate(sim, fgate1, mps); + + float wf[64]; + float ground_truth[] = {0.5, 0., 0., 0., 0.5, 0., 0., 0., + 0.5, 0., 0., 0., 0.5, 0., 0., 0., + 0., 0., 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., 0., 0.}; + ss.ToWaveFunction(mps, wf); + for (int i = 0; i < 32; i++) { + EXPECT_NEAR(wf[i], ground_truth[i], 1e-4); + } +} + } // namespace } // namespace mps } // namespace qsim