Skip to content

Commit

Permalink
Fix declaration bug introduced by PR #930
Browse files Browse the repository at this point in the history
  • Loading branch information
rmshaffer committed Apr 4, 2024
1 parent 790ede5 commit 1ad3e66
Show file tree
Hide file tree
Showing 7 changed files with 37 additions and 27 deletions.
2 changes: 1 addition & 1 deletion src/braket/experimental/autoqasm/operators/assignments.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def assign_stmt(target_name: str, value: Any) -> Any:
# int[32] a = 10;
# where `a` is at the root scope of the function (not inside any if/for/while block).
target.init_expression = value_init_expression
oqpy_program._add_var(target)
oqpy_program.declare(target)
else:
# Set to `value_init_expression` to avoid declaring an unnecessary variable.
# The variable will be set in the current scope and auto-declared at the root scope.
Expand Down
11 changes: 11 additions & 0 deletions src/braket/experimental/autoqasm/program/program.py
Original file line number Diff line number Diff line change
Expand Up @@ -519,6 +519,17 @@ def add_io_declarations(self) -> None:
# Verify that we didn't find it in both lists
assert popped_undeclared is None or popped_declared is None

# Remove the existing declaration statement, if any
if popped_declared is not None:
declarations = [
stmt
for stmt in root_oqpy_program._state.body
if isinstance(stmt, ast.ClassicalDeclaration)
and stmt.identifier.name == parameter_name
]
assert len(declarations) == 1
root_oqpy_program._state.body.remove(declarations[0])

popped = popped_undeclared if popped_undeclared is not None else popped_declared
if popped is not None and popped.init_expression is not None:
# Add an assignment statement to the beginning of the program to initialize
Expand Down
16 changes: 10 additions & 6 deletions test/unit_tests/braket/experimental/autoqasm/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,8 +251,8 @@ def bell_measurement_declared() -> None:

def test_bell_measurement_declared(bell_measurement_declared) -> None:
expected = """OPENQASM 3.0;
bit[2] c = "00";
qubit[2] __qubits__;
bit[2] c = "00";
h __qubits__[0];
cnot __qubits__[0], __qubits__[1];
bit[2] __bit_1__ = "00";
Expand Down Expand Up @@ -863,13 +863,16 @@ def classical_variables_types() -> None:

def test_classical_variables_types(classical_variables_types):
expected = """OPENQASM 3.0;
bit a = 1;
bit a = 0;
a = 1;
int[32] i = 1;
bit[2] a_array = "00";
int[32] b = 15;
float[64] c = 3.4;
a_array[0] = 0;
a_array[i] = 1;"""
a_array[i] = 1;
int[32] b = 10;
b = 15;
float[64] c = 1.2;
c = 3.4;"""
assert classical_variables_types.build().to_ir() == expected


Expand All @@ -886,8 +889,9 @@ def prog() -> None:
a = b # declared target, declared value # noqa: F841

expected = """OPENQASM 3.0;
int[32] a = 2;
int[32] b;
int[32] a = 1;
a = 2;
b = a;
a = b;"""
assert prog.build().to_ir() == expected
Expand Down
10 changes: 6 additions & 4 deletions test/unit_tests/braket/experimental/autoqasm/test_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def test_assignment(program_ctx: ag_ctx.ControlStatusCtx) -> None:
def fn() -> None:
"""user program to test"""
a = aq.IntVar(5) # noqa: F841
b = aq.FloatVar(1.2) # noqa: F841
b = a # noqa: F841
c = 123 # noqa: F841
d = (0.123, "foo") # noqa: F841
a = aq.IntVar(1) # noqa: F841
Expand All @@ -54,12 +54,14 @@ def fn() -> None:

qasm = program_conversion_context.make_program().to_ir()
expected_qasm = """OPENQASM 3.0;
int[32] a = 1;
float[64] b = 1.2;
int[32] b;
int[32] e;
int[32] a = 5;
b = a;
a = 1;
e = a;
bool f = false;
bool g = true;
e = a;
g = f;"""
assert qasm == expected_qasm

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -162,8 +162,8 @@ def branch_assignment_declared():
a = aq.IntVar(7) # noqa: F841

expected = """OPENQASM 3.0;
int[32] a = 5;
bool __bool_1__ = true;
int[32] a = 5;
if (__bool_1__) {
a = 6;
} else {
Expand All @@ -184,8 +184,8 @@ def iterative_assignment():
rx(0, val)

expected = """OPENQASM 3.0;
float[64] val = 0.5;
qubit[3] __qubits__;
float[64] val = 0.5;
for int q in [0:3 - 1] {
bit __bit_1__;
__bit_1__ = measure __qubits__[q];
Expand Down Expand Up @@ -701,9 +701,9 @@ def measure_to_slice():
b0[3] = c

expected = """OPENQASM 3.0;
bit[10] b0 = "0000000000";
bit c;
qubit[1] __qubits__;
bit[10] b0 = "0000000000";
bit __bit_1__;
__bit_1__ = measure __qubits__[0];
c = __bit_1__;
Expand Down
11 changes: 2 additions & 9 deletions test/unit_tests/braket/experimental/autoqasm/test_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,20 +529,13 @@ def parametric_explicit():
with pytest.raises(RuntimeError, match="conflicting variables with name alpha"):
parametric_explicit.build()


def test_assignment_to_input_variable_name():
"""Test assigning to overwrite an input variable within the program."""

@aq.main
def parametric(alpha):
alpha = aq.FloatVar(1.2)
rx(0, alpha)

expected = """OPENQASM 3.0;
float[64] alpha = 1.2;
qubit[1] __qubits__;
rx(alpha) __qubits__[0];"""
assert parametric.build().to_ir() == expected
with pytest.raises(RuntimeError, match="conflicting variables with name alpha"):
parametric.build()


def test_binding_variable_fails():
Expand Down
8 changes: 4 additions & 4 deletions test/unit_tests/braket/experimental/autoqasm/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,9 +159,9 @@ def ret_test() -> int:
def add(int[32] a, int[32] b) -> int[32] {
return a + b;
}
output int[32] return_value;
int[32] a = 5;
int[32] b = 6;
output int[32] return_value;
int[32] __int_2__;
__int_2__ = add(a, b);
return_value = __int_2__;"""
Expand Down Expand Up @@ -194,8 +194,8 @@ def declare_array():

expected = """OPENQASM 3.0;
array[int[32], 3] a = {1, 2, 3};
array[int[32], 3] b = {4, 5, 6};
a[0] = 11;
array[int[32], 3] b = {4, 5, 6};
b[2] = 14;
b = a;"""

Expand Down Expand Up @@ -517,9 +517,9 @@ def main():

expected_qasm = """OPENQASM 3.0;
def retval_recursive() -> int[32] {
int[32] retval_ = 1;
int[32] __int_1__;
__int_1__ = retval_recursive();
int[32] retval_ = 1;
return retval_;
}
int[32] __int_3__;
Expand All @@ -543,10 +543,10 @@ def main():
expected_qasm = """OPENQASM 3.0;
def retval_recursive() -> int[32] {
int[32] a;
int[32] retval_ = 1;
int[32] __int_1__;
__int_1__ = retval_recursive();
a = __int_1__;
int[32] retval_ = 1;
return retval_;
}
int[32] __int_3__;
Expand Down

0 comments on commit 1ad3e66

Please sign in to comment.