diff --git a/tests/compiler/test_opcodes.py b/tests/compiler/test_opcodes.py index f36fcfac6f..3c595dee44 100644 --- a/tests/compiler/test_opcodes.py +++ b/tests/compiler/test_opcodes.py @@ -45,11 +45,14 @@ def test_version_check(evm_version): def test_get_opcodes(evm_version): ops = opcodes.get_opcodes() - if evm_version in ("paris", "berlin", "shanghai"): + if evm_version in ("paris", "berlin", "shanghai", "cancun"): assert "CHAINID" in ops assert ops["SLOAD"][-1] == 2100 - if evm_version in ("shanghai",): + if evm_version in ("shanghai", "cancun"): assert "PUSH0" in ops + if evm_version in ("cancun",): + assert "TLOAD" in ops + assert "TSTORE" in ops elif evm_version == "istanbul": assert "CHAINID" in ops assert ops["SLOAD"][-1] == 800 diff --git a/tests/parser/ast_utils/test_ast_dict.py b/tests/parser/ast_utils/test_ast_dict.py index 214af50f9f..f483d0cbe8 100644 --- a/tests/parser/ast_utils/test_ast_dict.py +++ b/tests/parser/ast_utils/test_ast_dict.py @@ -73,6 +73,7 @@ def test_basic_ast(): "is_constant": False, "is_immutable": False, "is_public": False, + "is_transient": False, } diff --git a/tests/parser/features/decorators/test_nonreentrant.py b/tests/parser/features/decorators/test_nonreentrant.py index 0577313b88..ac73b35bec 100644 --- a/tests/parser/features/decorators/test_nonreentrant.py +++ b/tests/parser/features/decorators/test_nonreentrant.py @@ -3,6 +3,8 @@ from vyper.exceptions import FunctionDeclarationException +# TODO test functions in this module across all evm versions +# once we have cancun support. def test_nonreentrant_decorator(get_contract, assert_tx_failed): calling_contract_code = """ interface SpecialContract: diff --git a/tests/parser/features/test_transient.py b/tests/parser/features/test_transient.py new file mode 100644 index 0000000000..53354beca8 --- /dev/null +++ b/tests/parser/features/test_transient.py @@ -0,0 +1,61 @@ +import pytest + +from vyper.compiler import compile_code +from vyper.evm.opcodes import EVM_VERSIONS +from vyper.exceptions import StructureException + +post_cancun = {k: v for k, v in EVM_VERSIONS.items() if v >= EVM_VERSIONS["cancun"]} + + +@pytest.mark.parametrize("evm_version", list(EVM_VERSIONS.keys())) +def test_transient_blocked(evm_version): + # test transient is blocked on pre-cancun and compiles post-cancun + code = """ +my_map: transient(HashMap[address, uint256]) + """ + if EVM_VERSIONS[evm_version] >= EVM_VERSIONS["cancun"]: + assert compile_code(code, evm_version=evm_version) is not None + else: + with pytest.raises(StructureException): + compile_code(code, evm_version=evm_version) + + +@pytest.mark.parametrize("evm_version", list(post_cancun.keys())) +def test_transient_compiles(evm_version): + # test transient keyword at least generates TLOAD/TSTORE opcodes + getter_code = """ +my_map: public(transient(HashMap[address, uint256])) + """ + t = compile_code(getter_code, evm_version=evm_version, output_formats=["opcodes_runtime"]) + t = t["opcodes_runtime"].split(" ") + + assert "TLOAD" in t + assert "TSTORE" not in t + + setter_code = """ +my_map: transient(HashMap[address, uint256]) + +@external +def setter(k: address, v: uint256): + self.my_map[k] = v + """ + t = compile_code(setter_code, evm_version=evm_version, output_formats=["opcodes_runtime"]) + t = t["opcodes_runtime"].split(" ") + + assert "TLOAD" not in t + assert "TSTORE" in t + + getter_setter_code = """ +my_map: public(transient(HashMap[address, uint256])) + +@external +def setter(k: address, v: uint256): + self.my_map[k] = v + """ + t = compile_code( + getter_setter_code, evm_version=evm_version, output_formats=["opcodes_runtime"] + ) + t = t["opcodes_runtime"].split(" ") + + assert "TLOAD" in t + assert "TSTORE" in t diff --git a/vyper/ast/nodes.py b/vyper/ast/nodes.py index 03f2d713c1..7c907b4d08 100644 --- a/vyper/ast/nodes.py +++ b/vyper/ast/nodes.py @@ -1344,7 +1344,15 @@ class VariableDecl(VyperNode): If true, indicates that the variable is an immutable variable. """ - __slots__ = ("target", "annotation", "value", "is_constant", "is_public", "is_immutable") + __slots__ = ( + "target", + "annotation", + "value", + "is_constant", + "is_public", + "is_immutable", + "is_transient", + ) def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -1352,6 +1360,7 @@ def __init__(self, *args, **kwargs): self.is_constant = False self.is_public = False self.is_immutable = False + self.is_transient = False def _check_args(annotation, call_name): # do the same thing as `validate_call_args` @@ -1369,9 +1378,10 @@ def _check_args(annotation, call_name): # unwrap one layer self.annotation = self.annotation.args[0] - if self.annotation.get("func.id") in ("immutable", "constant"): - _check_args(self.annotation, self.annotation.func.id) - setattr(self, f"is_{self.annotation.func.id}", True) + func_id = self.annotation.get("func.id") + if func_id in ("immutable", "constant", "transient"): + _check_args(self.annotation, func_id) + setattr(self, f"is_{func_id}", True) # unwrap one layer self.annotation = self.annotation.args[0] diff --git a/vyper/cli/vyper_compile.py b/vyper/cli/vyper_compile.py index 9ab884a6d0..4dfc87639a 100755 --- a/vyper/cli/vyper_compile.py +++ b/vyper/cli/vyper_compile.py @@ -101,7 +101,8 @@ def _parse_args(argv): ) parser.add_argument( "--evm-version", - help=f"Select desired EVM version (default {DEFAULT_EVM_VERSION})", + help=f"Select desired EVM version (default {DEFAULT_EVM_VERSION}). " + " note: cancun support is EXPERIMENTAL", choices=list(EVM_VERSIONS), default=DEFAULT_EVM_VERSION, dest="evm_version", diff --git a/vyper/codegen/context.py b/vyper/codegen/context.py index 6e8d02c9b3..e4b41adbc0 100644 --- a/vyper/codegen/context.py +++ b/vyper/codegen/context.py @@ -28,6 +28,7 @@ class VariableRecord: defined_at: Any = None is_internal: bool = False is_immutable: bool = False + is_transient: bool = False data_offset: Optional[int] = None def __hash__(self): diff --git a/vyper/codegen/core.py b/vyper/codegen/core.py index a9a91ec9d8..06140f3f52 100644 --- a/vyper/codegen/core.py +++ b/vyper/codegen/core.py @@ -1,6 +1,6 @@ from vyper import ast as vy_ast from vyper.codegen.ir_node import Encoding, IRnode -from vyper.evm.address_space import CALLDATA, DATA, IMMUTABLES, MEMORY, STORAGE +from vyper.evm.address_space import CALLDATA, DATA, IMMUTABLES, MEMORY, STORAGE, TRANSIENT from vyper.evm.opcodes import version_check from vyper.exceptions import CompilerPanic, StructureException, TypeCheckFailure, TypeMismatch from vyper.semantics.types import ( @@ -562,10 +562,10 @@ def _get_element_ptr_mapping(parent, key): key = unwrap_location(key) # TODO when is key None? - if key is None or parent.location != STORAGE: - raise TypeCheckFailure(f"bad dereference on mapping {parent}[{key}]") + if key is None or parent.location not in (STORAGE, TRANSIENT): + raise TypeCheckFailure("bad dereference on mapping {parent}[{key}]") - return IRnode.from_list(["sha3_64", parent, key], typ=subtype, location=STORAGE) + return IRnode.from_list(["sha3_64", parent, key], typ=subtype, location=parent.location) # Take a value representing a memory or storage location, and descend down to diff --git a/vyper/codegen/expr.py b/vyper/codegen/expr.py index 4a18a16e1b..ac7290836b 100644 --- a/vyper/codegen/expr.py +++ b/vyper/codegen/expr.py @@ -23,7 +23,7 @@ ) from vyper.codegen.ir_node import IRnode from vyper.codegen.keccak256_helper import keccak256_helper -from vyper.evm.address_space import DATA, IMMUTABLES, MEMORY, STORAGE +from vyper.evm.address_space import DATA, IMMUTABLES, MEMORY, STORAGE, TRANSIENT from vyper.evm.opcodes import version_check from vyper.exceptions import ( CompilerPanic, @@ -259,10 +259,12 @@ def parse_Attribute(self): # self.x: global attribute elif isinstance(self.expr.value, vy_ast.Name) and self.expr.value.id == "self": varinfo = self.context.globals[self.expr.attr] + location = TRANSIENT if varinfo.is_transient else STORAGE + ret = IRnode.from_list( varinfo.position.position, typ=varinfo.typ, - location=STORAGE, + location=location, annotation="self." + self.expr.attr, ) ret._referenced_variables = {varinfo} diff --git a/vyper/codegen/function_definitions/utils.py b/vyper/codegen/function_definitions/utils.py index 7129388c58..f524ec6e88 100644 --- a/vyper/codegen/function_definitions/utils.py +++ b/vyper/codegen/function_definitions/utils.py @@ -8,6 +8,10 @@ def get_nonreentrant_lock(func_type): nkey = func_type.reentrancy_key_position.position + LOAD, STORE = "sload", "sstore" + if version_check(begin="cancun"): + LOAD, STORE = "tload", "tstore" + if version_check(begin="berlin"): # any nonzero values would work here (see pricing as of net gas # metering); these values are chosen so that downgrading to the @@ -16,12 +20,12 @@ def get_nonreentrant_lock(func_type): else: final_value, temp_value = 0, 1 - check_notset = ["assert", ["ne", temp_value, ["sload", nkey]]] + check_notset = ["assert", ["ne", temp_value, [LOAD, nkey]]] if func_type.mutability == StateMutability.VIEW: return [check_notset], [["seq"]] else: - pre = ["seq", check_notset, ["sstore", nkey, temp_value]] - post = ["sstore", nkey, final_value] + pre = ["seq", check_notset, [STORE, nkey, temp_value]] + post = [STORE, nkey, final_value] return [pre], [post] diff --git a/vyper/evm/address_space.py b/vyper/evm/address_space.py index 855e98b5c8..85a75c3c23 100644 --- a/vyper/evm/address_space.py +++ b/vyper/evm/address_space.py @@ -48,6 +48,7 @@ def byte_addressable(self) -> bool: MEMORY = AddrSpace("memory", 32, "mload", "mstore") STORAGE = AddrSpace("storage", 1, "sload", "sstore") +TRANSIENT = AddrSpace("transient", 1, "tload", "tstore") CALLDATA = AddrSpace("calldata", 32, "calldataload") # immutables address space: "immutables" section of memory # which is read-write in deploy code but then gets turned into diff --git a/vyper/evm/opcodes.py b/vyper/evm/opcodes.py index 7ff56df772..c447fd863c 100644 --- a/vyper/evm/opcodes.py +++ b/vyper/evm/opcodes.py @@ -24,6 +24,7 @@ "berlin": 3, "paris": 4, "shanghai": 5, + "cancun": 6, # ETC Forks "atlantis": 0, "agharta": 1, @@ -184,6 +185,8 @@ "INVALID": (0xFE, 0, 0, 0), "DEBUG": (0xA5, 1, 0, 0), "BREAKPOINT": (0xA6, 0, 0, 0), + "TLOAD": (0xB3, 1, 1, 100), + "TSTORE": (0xB4, 2, 0, 100), } PSEUDO_OPCODES: OpcodeMap = { diff --git a/vyper/semantics/analysis/base.py b/vyper/semantics/analysis/base.py index 5065131f29..449e6ca338 100644 --- a/vyper/semantics/analysis/base.py +++ b/vyper/semantics/analysis/base.py @@ -162,6 +162,7 @@ class VarInfo: is_constant: bool = False is_public: bool = False is_immutable: bool = False + is_transient: bool = False is_local_var: bool = False decl_node: Optional[vy_ast.VyperNode] = None diff --git a/vyper/semantics/analysis/module.py b/vyper/semantics/analysis/module.py index 188005e365..cb8e93ff28 100644 --- a/vyper/semantics/analysis/module.py +++ b/vyper/semantics/analysis/module.py @@ -4,6 +4,7 @@ import vyper.builtins.interfaces from vyper import ast as vy_ast +from vyper.evm.opcodes import version_check from vyper.exceptions import ( CallViolation, CompilerPanic, @@ -189,10 +190,17 @@ def visit_VariableDecl(self, node): if node.is_immutable else DataLocation.UNSET if node.is_constant + # XXX: needed if we want separate transient allocator + # else DataLocation.TRANSIENT + # if node.is_transient else DataLocation.STORAGE ) type_ = type_from_annotation(node.annotation, data_loc) + + if node.is_transient and not version_check(begin="cancun"): + raise StructureException("`transient` is not available pre-cancun", node.annotation) + var_info = VarInfo( type_, decl_node=node, @@ -200,6 +208,7 @@ def visit_VariableDecl(self, node): is_constant=node.is_constant, is_public=node.is_public, is_immutable=node.is_immutable, + is_transient=node.is_transient, ) node.target._metadata["varinfo"] = var_info # TODO maybe put this in the global namespace node._metadata["type"] = type_ diff --git a/vyper/semantics/data_locations.py b/vyper/semantics/data_locations.py index 0ec374e42f..2f259b1766 100644 --- a/vyper/semantics/data_locations.py +++ b/vyper/semantics/data_locations.py @@ -7,3 +7,5 @@ class DataLocation(enum.Enum): STORAGE = 2 CALLDATA = 3 CODE = 4 + # XXX: needed for separate transient storage allocator + # TRANSIENT = 5 diff --git a/vyper/semantics/namespace.py b/vyper/semantics/namespace.py index d760f66972..82a5d5cf3e 100644 --- a/vyper/semantics/namespace.py +++ b/vyper/semantics/namespace.py @@ -176,6 +176,7 @@ def validate_identifier(attr): "nonpayable", "constant", "immutable", + "transient", "internal", "payable", "nonreentrant",