diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 6454b89..76640e3 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -81,23 +81,22 @@ jobs: - name: Run Tests run: pytest -m "not fuzzing" -n 0 -s --cov -# NOTE: uncomment this block after you've marked tests with @pytest.mark.fuzzing -# fuzzing: -# runs-on: ubuntu-latest -# -# strategy: -# fail-fast: true -# -# steps: -# - uses: actions/checkout@v3 -# -# - name: Setup Python -# uses: actions/setup-python@v4 -# with: -# python-version: 3.8 -# -# - name: Install Dependencies -# run: pip install .[test] -# -# - name: Run Tests -# run: pytest -m "fuzzing" --no-cov -s + fuzzing: + runs-on: ubuntu-latest + + strategy: + fail-fast: true + + steps: + - uses: actions/checkout@v3 + + - name: Setup Python + uses: actions/setup-python@v4 + with: + python-version: 3.8 + + - name: Install Dependencies + run: pip install .[test] + + - name: Run Tests + run: pytest -m "fuzzing" --no-cov -s diff --git a/README.md b/README.md index 2f98bce..2447744 100644 --- a/README.md +++ b/README.md @@ -41,10 +41,10 @@ class Person(EIP712Type): class Mail(EIP712Message): - _chainId_: "uint256" = 1 - _name_: "string" = "Ether Mail" - _verifyingContract_: "address" = "0xCcCCccccCCCCcCCCCCCcCcCccCcCCCcCcccccccC" - _version_: "string" = "1" + _chainId_ = 1 + _name_ = "Ether Mail" + _verifyingContract_ = "0xCcCCccccCCCCcCCCCCCcCcCccCcCCCcCcccccccC" + _version_ = "1" sender: Person receiver: Person diff --git a/docs/conf.py b/docs/conf.py index b929091..a701881 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -24,7 +24,6 @@ author = "ApeWorX Team" extensions = [ "myst_parser", - "sphinx_click", "sphinx.ext.autodoc", "sphinx.ext.autosummary", "sphinx.ext.napoleon", diff --git a/docs/index.md b/docs/index.md index b40c4b6..fb65bfd 100644 --- a/docs/index.md +++ b/docs/index.md @@ -5,8 +5,8 @@ :caption: User Guides :maxdepth: 1 - userguides/quickstart - + userguides/quickstart + ``` ```{eval-rst} @@ -14,7 +14,7 @@ :caption: Python Reference :maxdepth: 2 + methoddocs/common.md methoddocs/messages.md - methoddocs/validation.md ``` diff --git a/docs/methoddocs/validation.md b/docs/methoddocs/common.md similarity index 54% rename from docs/methoddocs/validation.md rename to docs/methoddocs/common.md index 4b7a901..aa11d21 100644 --- a/docs/methoddocs/validation.md +++ b/docs/methoddocs/common.md @@ -1,7 +1,7 @@ -# Validation +# Common ```{eval-rst} -.. automodule:: eip712.validation +.. automodule:: eip712.common :members: :show-inheritance: ``` diff --git a/eip712/common.py b/eip712/common.py new file mode 100644 index 0000000..b2bab47 --- /dev/null +++ b/eip712/common.py @@ -0,0 +1,111 @@ +# flake8: noqa F821 +# Collection of commonly-used EIP712 message type definitions +from typing import Optional, Type, Union + +from .messages import EIP712Message + + +class EIP2612(EIP712Message): + # NOTE: Subclass this w/ at least one header field + + owner: "address" # type: ignore + spender: "address" # type: ignore + value: "uint256" # type: ignore + nonce: "uint256" # type: ignore + deadline: "uint256" # type: ignore + + +class EIP4494(EIP712Message): + # NOTE: Subclass this w/ at least one header field + + spender: "address" # type: ignore + tokenId: "uint256" # type: ignore + nonce: "uint256" # type: ignore + deadline: "uint256" # type: ignore + + +def create_permit_def(eip=2612, **header_fields): + if eip == 2612: + + class Permit(EIP2612): + _name_ = header_fields.get("name", None) + _version_ = header_fields.get("version", None) + _chainId_ = header_fields.get("chainId", None) + _verifyingContract_ = header_fields.get("verifyingContract", None) + _salt_ = header_fields.get("salt", None) + + elif eip == 4494: + + class Permit(EIP4494): + _name_ = header_fields.get("name", None) + _version_ = header_fields.get("version", None) + _chainId_ = header_fields.get("chainId", None) + _verifyingContract_ = header_fields.get("verifyingContract", None) + _salt_ = header_fields.get("salt", None) + + else: + raise ValueError(f"Invalid eip {eip}, must use one of: {EIP2612}, {EIP4494}") + + return Permit + + +class SafeTxV1(EIP712Message): + # NOTE: Subclass this as `SafeTx` w/ at least one header field + to: "address" # type: ignore + value: "uint256" = 0 # type: ignore + data: "bytes" = b"" # type: ignore + operation: "uint8" = 0 # type: ignore + safeTxGas: "uint256" = 0 # type: ignore + dataGas: "uint256" = 0 # type: ignore + gasPrice: "uint256" = 0 # type: ignore + gasToken: "address" = "0x0000000000000000000000000000000000000000" # type: ignore + refundReceiver: "address" = "0x0000000000000000000000000000000000000000" # type: ignore + nonce: "uint256" # type: ignore + + +class SafeTxV2(EIP712Message): + # NOTE: Subclass this as `SafeTx` w/ at least one header field + to: "address" # type: ignore + value: "uint256" = 0 # type: ignore + data: "bytes" = b"" # type: ignore + operation: "uint8" = 0 # type: ignore + safeTxGas: "uint256" = 0 # type: ignore + baseGas: "uint256" = 0 # type: ignore + gasPrice: "uint256" = 0 # type: ignore + gasToken: "address" = "0x0000000000000000000000000000000000000000" # type: ignore + refundReceiver: "address" = "0x0000000000000000000000000000000000000000" # type: ignore + nonce: "uint256" # type: ignore + + +SafeTx = Union[SafeTxV1, SafeTxV2] +SAFE_VERSIONS = {"1.0.0", "1.1.0", "1.1.1", "1.2.0", "1.3.0"} + + +def create_safe_tx_def( + version: str = "1.3.0", + contract_address: Optional[str] = None, + chain_id: Optional[int] = None, +) -> Type[SafeTx]: + if not contract_address: + raise ValueError("Must define 'contract_address'") + + if version not in SAFE_VERSIONS: + raise ValueError(f"Unknown version {version}") + + major, minor, patch = map(int, version.split(".")) + + if minor < 3: + + class SafeTx(SafeTxV1): + _verifyingContract_ = contract_address + + elif not chain_id: + raise ValueError("Must supply 'chain_id=' for Safe versions 1.3.0 or later") + + else: + + class SafeTx(SafeTxV2): # type: ignore[no-redef] + _chainId_ = chain_id + _verifyingContract_ = contract_address + + return SafeTx diff --git a/eip712/hashing.py b/eip712/hashing.py deleted file mode 100644 index c311488..0000000 --- a/eip712/hashing.py +++ /dev/null @@ -1,234 +0,0 @@ -# copied under the MIT license from eth-account: -# https://github.com/ethereum/eth-account/blob/cc2feca919474203b0b23450ce7f2deed3ce985c/eth_account/_utils/structured_data/hashing.py -import json -from itertools import groupby -from operator import itemgetter - -from eth_abi import encode, is_encodable, is_encodable_type -from eth_abi.grammar import parse -from eth_utils import keccak, to_tuple - -from .validation import validate_structured_data - - -def get_dependencies(primary_type, types): - """ - Perform DFS to get all the dependencies of the primary_type. - """ - deps = set() - struct_names_yet_to_be_expanded = [primary_type] - - while len(struct_names_yet_to_be_expanded) > 0: - struct_name = struct_names_yet_to_be_expanded.pop() - - deps.add(struct_name) - fields = types[struct_name] - for field in fields: - field_type = field["type"] - - # Handle array types - if is_array_type(field_type): - field_type = field_type[: field_type.index("[")] - - if field_type not in types: - # We don't need to expand types that are not user defined (customized) - continue - elif field_type not in deps: - # Custom Struct Type - struct_names_yet_to_be_expanded.append(field_type) - elif field_type in deps: - # skip types that we have already encountered - continue - else: - raise TypeError(f"Unable to determine type dependencies with type `{field_type}`.") - # Don't need to make a struct as dependency of itself - deps.remove(primary_type) - - return tuple(deps) - - -def field_identifier(field): - """ - Convert a field dict into a typed-name string. - Given a ``field`` of the format {'name': NAME, 'type': TYPE}, - this function converts it to ``TYPE NAME`` - """ - return "{0} {1}".format(field["type"], field["name"]) - - -def encode_struct(struct_name, struct_field_types): - return "{0}({1})".format( - struct_name, - ",".join(map(field_identifier, struct_field_types)), - ) - - -def encode_type(primary_type, types): - """ - Serialize types into an encoded string. - The type of a struct is encoded as: - name ‖ "(" ‖ member₁ ‖ "," ‖ member₂ ‖ "," ‖ … ‖ memberₙ ")" - where each member is written as type ‖ " " ‖ name. - """ - # Getting the dependencies and sorting them alphabetically as per EIP712 - deps = get_dependencies(primary_type, types) - sorted_deps = (primary_type,) + tuple(sorted(deps)) - - result = "".join( - [encode_struct(struct_name, types[struct_name]) for struct_name in sorted_deps] - ) - return result - - -def hash_struct_type(primary_type, types): - return keccak(text=encode_type(primary_type, types)) - - -def is_array_type(type): - return type.endswith("]") - - -@to_tuple -def get_depths_and_dimensions(data, depth): - """ - Yields 2-length tuples of depth and dimension of each element at that depth. - """ - if not isinstance(data, (list, tuple)): - # Not checking for Iterable instance, because even Dictionaries and strings - # are considered as iterables, but that's not what we want the condition to be. - return () - - yield depth, len(data) - - for item in data: - # iterating over all 1 dimension less sub-data items - yield from get_depths_and_dimensions(item, depth + 1) - - -def get_array_dimensions(data): - """ - Given an array type data item, check that it is an array and return the dimensions - as a tuple, in order from inside to outside. - Ex: get_array_dimensions([[1, 2, 3], [4, 5, 6]]) returns (3, 2) - """ - depths_and_dimensions = get_depths_and_dimensions(data, 0) - - # re-form as a dictionary with `depth` as key, and all of the dimensions - # found at that depth. - grouped_by_depth = { - depth: tuple(dimension for depth, dimension in group) - for depth, group in groupby(depths_and_dimensions, itemgetter(0)) - } - - dimensions = tuple( - # check that all dimensions are the same, else use "dynamic" - dimensions[0] if all(dim == dimensions[0] for dim in dimensions) else "dynamic" - for _depth, dimensions in sorted(grouped_by_depth.items(), reverse=True) - ) - - return dimensions - - -def encode_field(types, name, field_type, value): - if value is None: - raise ValueError(f"Missing value for field {name} of type {field_type}") - - if field_type in types: - return ("bytes32", keccak(encode_data(field_type, types, value))) - - if field_type == "bytes": - if not isinstance(value, bytes): - raise TypeError( - f"Value of field `{name}` ({value}) is of the type `{type(value)}`, " - f"but expected bytes value" - ) - - return ("bytes32", keccak(value)) - - if field_type == "string": - if not isinstance(value, str): - raise TypeError( - f"Value of field `{name}` ({value}) is of the type `{type(value)}`, " - f"but expected string value" - ) - - return ("bytes32", keccak(text=value)) - - if is_array_type(field_type): - # Get the dimensions from the value - array_dimensions = get_array_dimensions(value) - # Get the dimensions from what was declared in the schema - parsed_field_type = parse(field_type) - - for i in range(len(array_dimensions)): - if len(parsed_field_type.arrlist[i]) == 0: - # Skip empty or dynamically declared dimensions - continue - if array_dimensions[i] != parsed_field_type.arrlist[i][0]: - # Dimensions should match with declared schema - raise TypeError( - f"Array data `{value}` has dimensions `{array_dimensions}`" - f" whereas the schema has dimensions " - f"`{tuple(map(lambda x: x[0] if x else 'dynamic', parsed_field_type.arrlist))}`" # noqa: E501 - ) - - field_type_of_inside_array = field_type[: field_type.rindex("[")] - field_type_value_pairs = [ - encode_field(types, name, field_type_of_inside_array, item) for item in value - ] - - # handle empty array - if value: - data_types, data_hashes = zip(*field_type_value_pairs) - else: - data_types, data_hashes = [], [] - - return ("bytes32", keccak(encode(data_types, data_hashes))) - - # First checking to see if field_type is valid as per abi - if not is_encodable_type(field_type): - raise TypeError(f"Received Invalid type `{field_type}` in field `{name}`") - - # Next, see if the value is encodable as the specified field_type - if is_encodable(field_type, value): - # field_type is a valid type and the provided value is encodable as that type - return (field_type, value) - else: - raise TypeError( - f"Value of `{name}` ({value}) is not encodable as type `{field_type}`. " - f"If the base type is correct, verify that the value does not " - f"exceed the specified size for the type." - ) - - -def encode_data(primary_type, types, data): - encoded_types = ["bytes32"] - encoded_values = [hash_struct_type(primary_type, types)] - - for field in types[primary_type]: - type, value = encode_field(types, field["name"], field["type"], data[field["name"]]) - encoded_types.append(type) - encoded_values.append(value) - - return encode(encoded_types, encoded_values) - - -def load_and_validate_structured_message(structured_json_string_data): - structured_data = json.loads(structured_json_string_data) - validate_structured_data(structured_data) - - return structured_data - - -def hash_domain(structured_data): - return keccak(encode_data("EIP712Domain", structured_data["types"], structured_data["domain"])) - - -def hash_message(structured_data): - return keccak( - encode_data( - structured_data["primaryType"], - structured_data["types"], - structured_data["message"], - ) - ) diff --git a/eip712/messages.py b/eip712/messages.py index e527a6e..6449ec2 100644 --- a/eip712/messages.py +++ b/eip712/messages.py @@ -2,200 +2,145 @@ Message classes for typed structured data hashing and signing in Ethereum. """ -from typing import Dict, NamedTuple +from typing import Any, Dict, Optional -from dataclassy import as_dict, dataclass, fields +from dataclassy import dataclass, fields from eth_abi import is_encodable_type -from eth_typing import Hash32 -from eth_utils.curried import ValidationError, keccak +from eth_account.messages import SignableMessage, hash_domain, hash_eip712_message +from eth_utils.curried import ValidationError from hexbytes import HexBytes -from eip712.hashing import hash_domain -from eip712.hashing import hash_message as hash_eip712_message - # ! Do not change the order of the fields in this list ! # To correctly encode and hash the domain fields, they # must be in this precise order. -EIP712_DOMAIN_FIELDS = [ - "name", - "version", - "chainId", - "verifyingContract", - "salt", +EIP712_DOMAIN_FIELDS = { + "name": "string", + "version": "string", + "chainId": "uint256", + "verifyingContract": "address", + "salt": "bytes32", +} + +EIP712_BODY_FIELDS = [ + "types", + "primaryType", + "domain", + "message", ] -HEADER_FIELDS = [f"_{field}_" for field in EIP712_DOMAIN_FIELDS] - - -# https://github.com/ethereum/eth-account/blob/f1d38e0/eth_account/messages.py#L39 -class SignableMessage(NamedTuple): - """ - These are the components of an `EIP-191 `__ - signable message. Other message formats can be encoded into this format for easy signing. - This data structure doesn't need to know about the original message format. - - In typical usage, you should never need to create these by hand. Instead, use - one of the available encode_* methods in this module, like: - - - :meth:`encode_structured_data` - - :meth:`encode_intended_validator` - - :meth:`encode_structured_data` - """ - - version: bytes # must be length 1 - header: bytes # aka "version specific data" - body: bytes # aka "data to sign" - -# https://github.com/ethereum/eth-account/blob/f1d38e0/eth_account/messages.py#L59 -def _hash_eip191_message(signable_message: SignableMessage) -> Hash32: - """ - Hash the given ``signable_message`` according to the EIP-191 Signed Data Standard. - """ - version = signable_message.version - if len(version) != 1: - raise ValidationError( - "The supplied message version is {version!r}. " - "The EIP-191 signable message standard only supports one-byte versions." - ) - joined = b"\x19" + version + signable_message.header + signable_message.body - return Hash32(keccak(joined)) - - -@dataclass(iter=True, slots=True) +@dataclass(iter=True, slots=True, kwargs=True, kw_only=True) class EIP712Type: """ Dataclass for `EIP-712 `__ structured data types (i.e. the contents of an :class:`EIP712Message`). """ - @property - def type(self) -> str: + def __repr__(self) -> str: return self.__class__.__name__ - def field_type(self, field: str) -> str: - """ - Looks up ``field`` via type annotations, returning the underlying ABI - type (e.g. ``"uint256"``) or :class:`EIP712Type`. Raises ``KeyError`` - if the field doesn't exist. - """ - typ = self.__annotations__[field] - - if isinstance(typ, str): - if not is_encodable_type(typ): - raise ValidationError(f"'{field}: {typ}' is not a valid ABI type") - - return typ - - elif issubclass(typ, EIP712Type): - return str(typ.type) - - else: - raise ValidationError( - f"'{field}' type annotation must either be a subclass of " - f"`EIP712Type` or valid ABI Type string, not {typ.__name__}" - ) - - def types(self) -> dict: + @property + def _types_(self) -> dict: """ Recursively built ``dict`` (name of type ``->`` list of subtypes) of the underlying fields' types. """ - types: Dict[str, list] = {self.type: []} + types: Dict[str, list] = {repr(self): []} for field in fields(self.__class__): value = getattr(self, field) if isinstance(value, EIP712Type): - types[self.type].append({"name": field, "type": value.type}) - types.update(value.types()) + types[repr(self)].append({"name": field, "type": repr(value)}) + types.update(value._types_) else: - types[self.type].append({"name": field, "type": self.field_type(field)}) + # TODO: Use proper ABI typing, not strings + field_type = self.__annotations__[field] + + if isinstance(field_type, str): + if not is_encodable_type(field_type): + raise ValidationError(f"'{field}: {field_type}' is not a valid ABI type") + + elif issubclass(field_type, EIP712Type): + field_type = repr(field_type) + + else: + raise ValidationError( + f"'{field}' type annotation must either be a subclass of " + f"`EIP712Type` or valid ABI Type string, not {field_type.__name__}" + ) + + types[repr(self)].append({"name": field, "type": field_type}) return types - @property - def data(self) -> dict: - """ - Recursively built ``dict`` of the underlying data, to be used for - serialization. - """ - d = as_dict(self) # NOTE: Handles recursion - return {k: v for (k, v) in d.items() if k not in HEADER_FIELDS} + def __getitem__(self, key: str) -> Any: + if (key.startswith("_") and key.endswith("_")) or key not in fields(self.__class__): + raise KeyError("Cannot look up header fields or other attributes this way") + + return getattr(self, key) -# TODO: Make type of EIP712Message a subtype of SignableMessage somehow class EIP712Message(EIP712Type): """ Container for EIP-712 messages with type information, domain separator parameters, and the message object. """ + # NOTE: Must override at least one of these fields + _name_: Optional[str] = None + _version_: Optional[str] = None + _chainId_: Optional[int] = None + _verifyingContract_: Optional[str] = None + _salt_: Optional[bytes] = None + def __post_init__(self): # At least one of the header fields must be in the EIP712 message header - if len(self.domain) == 0: + if not any(getattr(self, f"_{field}_") for field in EIP712_DOMAIN_FIELDS): raise ValidationError( - f"EIP712 Message definition '{self.type}' must define " - f"at least one of {EIP712_DOMAIN_FIELDS}" + f"EIP712 Message definition '{repr(self)}' must define " + f"at least one of: _{'_, _'.join(EIP712_DOMAIN_FIELDS)}_" ) @property - def domain(self) -> dict: - """The EIP-712 domain fields (built using ``HEADER_FIELDS``).""" - # Ensure that HEADER_FIELDS are in the following order: - # name, version, chainId, verifyingContract, salt + def _domain_(self) -> dict: + """The EIP-712 domain structure to be used for serialization and hashing.""" + domain_type = [ + {"name": field, "type": abi_type} + for field, abi_type in EIP712_DOMAIN_FIELDS.items() + if getattr(self, f"_{field}_") + ] return { - field.replace("_", ""): getattr(self, field) - for field in HEADER_FIELDS - if field in fields(self.__class__, internals=True) + "types": { + "EIP712Domain": domain_type, + }, + "domain": {field["name"]: getattr(self, f"_{field['name']}_") for field in domain_type}, } @property - def domain_type(self) -> list: - """The EIP-712 domain structure to be used for serialization.""" - return [{"name": field, "type": self.field_type(f"_{field}_")} for field in self.domain] - - @property - def version(self) -> bytes: - """ - The current major version of the signing domain. Signatures from - different versions are not compatible. - """ - return b"\x01" - - @property - def header(self) -> bytes: - """The EIP-712 message header.""" - return hash_domain( - { - "types": { - "EIP712Domain": self.domain_type, - }, - "domain": self.domain, - } - ) - - @property - def body_data(self) -> dict: + def _body_(self) -> dict: """The EIP-712 structured message to be used for serialization and hashing.""" - types = dict(self.types(), EIP712Domain=self.domain_type) - msg = { - "domain": self.domain, - "types": types, - "primaryType": self.type, - "message": self.data, + return { + "domain": self._domain_["domain"], + "types": dict(self._types_, **self._domain_["types"]), + "primaryType": repr(self), + "message": { + key: getattr(self, key) + for key in fields(self.__class__) + if not key.startswith("_") or not key.endswith("_") + }, } - return msg - @property - def body(self) -> bytes: - """The hash of the EIP-712 message (``body_data``).""" - return hash_eip712_message(self.body_data) + def __getitem__(self, key: str) -> Any: + if key in EIP712_BODY_FIELDS: + return self._body_[key] + + return super().__getitem__(key) @property def signable_message(self) -> SignableMessage: """The current message as a :class:`SignableMessage` named tuple instance.""" return SignableMessage( HexBytes(b"\x01"), - self.header, - self.body, + HexBytes(hash_domain(self._domain_)), + HexBytes(hash_eip712_message(self._body_)), ) diff --git a/eip712/validation.py b/eip712/validation.py deleted file mode 100644 index 55dfb2b..0000000 --- a/eip712/validation.py +++ /dev/null @@ -1,117 +0,0 @@ -""" -Functions for validating EIP-712 message structure (required fields, etc.). -""" - -# copied under the MIT license from the eth-account project: -# https://github.com/ethereum/eth-account/blob/cc2feca919474203b0b23450ce7f2deed3ce985c/eth_account/_utils/structured_data/validation.py -# flake8: noqa -import re - -from eth_utils import ValidationError - -# Regexes -IDENTIFIER_REGEX = r"^[a-zA-Z_$][a-zA-Z_$0-9]*$" -TYPE_REGEX = r"^[a-zA-Z_$][a-zA-Z_$0-9]*(\[([1-9]\d*\b)*\])*$" - - -def validate_has_attribute(attr_name, dict_data): - if attr_name not in dict_data: - raise ValidationError(f"Attribute `{attr_name}` not found in the JSON string") - - -def validate_types_attribute(structured_data): - # Check that the data has `types` attribute - validate_has_attribute("types", structured_data) - - # Check if all the `name` and the `type` attributes in each field of all the - # `types` attribute are valid (Regex Check) - for struct_name in structured_data["types"]: - # Check that `struct_name` is of the type string - if not isinstance(struct_name, str): - raise ValidationError( - "Struct Name of `types` attribute should be a string, " - f"but got type `{type(struct_name)}`" - ) - for field in structured_data["types"][struct_name]: - # Check that `field["name"]` is of the type string - if not isinstance(field["name"], str): - raise ValidationError( - f"Field Name `{field['name']}` of struct `{struct_name}` " - f"should be a string, but got type `{type(field['name'])}`" - ) - # Check that `field["type"]` is of the type string - if not isinstance(field["type"], str): - raise ValidationError( - f"Field Type `{field['type']}` of struct `{struct_name}` " - f"should be a string, but got type `{type(field['name'])}`" - ) - # Check that field["name"] matches with IDENTIFIER_REGEX - if not re.match(IDENTIFIER_REGEX, field["name"]): - raise ValidationError(f"Invalid Identifier `{field['name']}` in `{struct_name}`") - # Check that field["type"] matches with TYPE_REGEX - if not re.match(TYPE_REGEX, field["type"]): - raise ValidationError(f"Invalid Type `{field['type']}` in `{struct_name}`") - - -def validate_field_declared_only_once_in_struct(field_name, struct_data, struct_name): - if len([field for field in struct_data if field["name"] == field_name]) != 1: - raise ValidationError( - f"Attribute `{field_name}` not declared or declared more " f"than once in {struct_name}" - ) - - -EIP712_DOMAIN_FIELDS = [ - "name", - "version", - "chainId", - "verifyingContract", -] - - -def used_header_fields(EIP712Domain_data): - return [field["name"] for field in EIP712Domain_data if field["name"] in EIP712_DOMAIN_FIELDS] - - -def validate_EIP712Domain_schema(structured_data): - # Check that the `types` attribute contains `EIP712Domain` schema declaration - if "EIP712Domain" not in structured_data["types"]: - raise ValidationError("`EIP712Domain struct` not found in types attribute") - # Check that the names and types in `EIP712Domain` are what are mentioned in the - # EIP-712 and they are declared only once (if defined at all) - EIP712Domain_data = structured_data["types"]["EIP712Domain"] - header_fields = used_header_fields(EIP712Domain_data) - if len(header_fields) == 0: - raise ValidationError(f"One of {EIP712_DOMAIN_FIELDS} must be defined in {structured_data}") - for field in header_fields: - validate_field_declared_only_once_in_struct(field, EIP712Domain_data, "EIP712Domain") - - -def validate_primaryType_attribute(structured_data): - # Check that `primaryType` attribute is present - if "primaryType" not in structured_data: - raise ValidationError("The Structured Data needs to have a `primaryType` attribute") - # Check that `primaryType` value is a string - if not isinstance(structured_data["primaryType"], str): - raise ValidationError( - "Value of attribute `primaryType` should be `string`, " - f"but got type `{type(structured_data['primaryType'])}`" - ) - # Check that the value of `primaryType` is present in the `types` attribute - if not structured_data["primaryType"] in structured_data["types"]: - raise ValidationError( - f"The Primary Type `{structured_data['primaryType']}` is not " - "present in the `types` attribute" - ) - - -def validate_structured_data(structured_data): - # validate the `types` attribute - validate_types_attribute(structured_data) - # validate the `EIP712Domain` struct of `types` attribute - validate_EIP712Domain_schema(structured_data) - # validate the `primaryType` attribute - validate_primaryType_attribute(structured_data) - # Check that there is a `domain` attribute in the structured data - validate_has_attribute("domain", structured_data) - # Check that there is a `message` attribute in the structured data - validate_has_attribute("message", structured_data) diff --git a/setup.py b/setup.py index f439b9d..adcd4a1 100644 --- a/setup.py +++ b/setup.py @@ -25,14 +25,13 @@ "twine", # Package upload tool ], "doc": [ - "myst-parser>=0.17.0,<0.18", # Tools for parsing markdown files in the docs - "sphinx-click>=3.1.0,<4.0", # For documenting CLI - "Sphinx>=4.4.0,<5.0", # Documentation generator - "sphinx_rtd_theme>=1.0.0,<2", # Readthedocs.org theme + "myst-parser>=0.18.1,<0.19", # Tools for parsing markdown files in the docs + "Sphinx>=5.3.0,<6.0", # Documentation generator + "sphinx_rtd_theme>=1.2.0,<2", # Readthedocs.org theme "sphinxcontrib-napoleon>=0.7", # Allow Google-style documentation ], "dev": [ - "commitizen>=2.19,<2.20", # Manage commits and publishing releases + "commitizen>=2.42,<3.0", # Manage commits and publishing releases "pre-commit", # Ensure that linters are run prior to commiting "pytest-watch", # `ptw` test watcher/runner "IPython", # Console for interacting @@ -66,11 +65,12 @@ include_package_data=True, install_requires=[ "dataclassy>=0.8.2,<1", - "eth-utils>=2.1.0,<3", - "eth-abi>=3.0.1,<4", - "eth-typing>=2.3.0,<4", + "eth-abi>=2.2.0,<4", + "eth-account>0.4.0,<1.0.0", + "eth-hash[pycryptodome]", # NOTE: Pinned by eth-abi + "eth-typing>=2.3,<4", + "eth-utils>=1.9.5,<3", "hexbytes>=0.3.0,<1", - "pycryptodome>=3.16.0,<4", ], python_requires=">=3.8,<4", extras_require=extras_require, diff --git a/tests/conftest.py b/tests/conftest.py index 0c00a06..b1a6996 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,7 @@ import pytest from hexbytes import HexBytes +from eip712.common import create_permit_def from eip712.messages import EIP712Message, EIP712Type PERMIT_NAME = "Yearn Vault" @@ -20,63 +21,51 @@ class SubType(EIP712Type): class ValidMessageWithNameDomainField(EIP712Message): - _name_: "string" = "Valid Test Message" # type: ignore + _name_ = "Valid Test Message" value: "uint256" # type: ignore default_value: "address" = "0xFFfFfFffFFfffFFfFFfFFFFFffFFFffffFfFFFfF" # type: ignore sub: SubType class MessageWithNonCanonicalDomainFieldOrder(EIP712Message): - _name_: "string" = PERMIT_NAME # type: ignore - _salt_: "bytes32" = PERMIT_SALT # type: ignore - _chainId_: "uint256" = PERMIT_CHAIN_ID # type: ignore - _version_: "string" = PERMIT_VERSION # type: ignore - _verifyingContract_: "address" = PERMIT_VAULT_ADDRESS # type: ignore + _name_ = PERMIT_NAME + _salt_ = PERMIT_SALT + _chainId_ = PERMIT_CHAIN_ID + _version_ = PERMIT_VERSION + _verifyingContract_ = PERMIT_VAULT_ADDRESS class MessageWithCanonicalDomainFieldOrder(EIP712Message): - _name_: "string" = PERMIT_NAME # type: ignore - _version_: "string" = PERMIT_VERSION # type: ignore - _chainId_: "uint256" = PERMIT_CHAIN_ID # type: ignore - _verifyingContract_: "address" = PERMIT_VAULT_ADDRESS # type: ignore - _salt_: "bytes32" = PERMIT_SALT # type: ignore - - -class MessageWithInvalidNameType(EIP712Message): - _name_: str = "Invalid Test Message" # type: ignore + _name_ = PERMIT_NAME + _version_ = PERMIT_VERSION + _chainId_ = PERMIT_CHAIN_ID + _verifyingContract_ = PERMIT_VAULT_ADDRESS + _salt_ = PERMIT_SALT class InvalidMessageMissingDomainFields(EIP712Message): value: "uint256" # type: ignore -class Permit(EIP712Message): - _name_: "string" = PERMIT_NAME # type: ignore - _version_: "string" # type: ignore - _chainId_: "uint256" # type: ignore - _verifyingContract_: "address" # type: ignore - _salt_: "bytes32" # type: ignore - - owner: "address" # type: ignore - spender: "address" # type: ignore - value: "uint256" # type: ignore - nonce: "uint256" # type: ignore - deadline: "uint256" # type: ignore - - @pytest.fixture def valid_message_with_name_domain_field(): return ValidMessageWithNameDomainField(value=1, sub=SubType(inner=2)) @pytest.fixture -def permit(): +def Permit(): + return create_permit_def( + name=PERMIT_NAME, + version=PERMIT_VERSION, + chainId=PERMIT_CHAIN_ID, + verifyingContract=PERMIT_VAULT_ADDRESS, + salt=PERMIT_SALT, + ) + + +@pytest.fixture +def permit(Permit): return Permit( - _name_=PERMIT_NAME, - _version_=PERMIT_VERSION, - _chainId_=PERMIT_CHAIN_ID, - _verifyingContract_=PERMIT_VAULT_ADDRESS, - _salt_=PERMIT_SALT, owner=PERMIT_OWNER_ADDRESS, spender=PERMIT_SPENDER_ADDRESS, value=PERMIT_ALLOWANCE, diff --git a/tests/test_common.py b/tests/test_common.py new file mode 100644 index 0000000..d2d6427 --- /dev/null +++ b/tests/test_common.py @@ -0,0 +1,29 @@ +import pytest +from eth_account.messages import hash_eip712_message as hash_message + +from eip712.common import SAFE_VERSIONS, create_safe_tx_def + +MSIG_ADDRESS = "0xFEB4acf3df3cDEA7399794D0869ef76A6EfAff52" + + +@pytest.mark.parametrize("version", SAFE_VERSIONS) +def test_gnosis_safe_tx(version): + tx_def = create_safe_tx_def( + version=version, + contract_address=MSIG_ADDRESS, + chain_id=1, + ) + + msg = tx_def(to=MSIG_ADDRESS, nonce=0) + + assert msg.signable_message.header.hex() == ( + "0x88fbc465dedd7fe71b7baef26a1f46cdaadd50b95c77cbe88569195a9fe589ab" + if version in ("1.3.0",) + else "0x590e9c66b22ee4584cd655fda57748ce186b85f829a092c28209478efbe86a92" + ) + + assert hash_message(msg).hex() == ( + "3c2fdf2ea8af328a67825162e7686000787c5cc9f4b27cb6bfbcaa445b59e2c4" + if version in ("1.3.0",) + else "1b393826bed1f2297ffc01916f8339892f9a51dc7f35f477b9a5cdd651d28603" + ) diff --git a/tests/test_fuzzing.py b/tests/test_fuzzing.py new file mode 100644 index 0000000..4d2cfef --- /dev/null +++ b/tests/test_fuzzing.py @@ -0,0 +1,39 @@ +import string + +import pytest +from eth_abi.tools._strategies import get_abi_strategy +from hypothesis import given, settings +from hypothesis import strategies as st + +from eip712 import EIP712Message # noqa: F401 + +abi_types = ["address"] +abi_types.extend(f"int{i}" for i in range(8, 256 + 8, 8)) +abi_types.extend(f"uint{i}" for i in range(8, 256 + 8, 8)) +abi_types = ( + abi_types # all other types + + [f"{t}[]" for t in abi_types] # dynamic arrays + + [f"{t}[{i}]" for i, t in zip(range(1, 10), abi_types)] # static arrays +) + + +@settings(max_examples=5000) +@pytest.mark.fuzzing +@given(types=st.lists(st.sampled_from(abi_types), min_size=1, max_size=10), data=st.data()) +def test_random_message_def(types, data): + members = string.ascii_lowercase[: len(types)] + members_str = "\n ".join(f'{k}: "{t}"' for k, t in zip(members, types)) + + exec( + f"""class Msg(EIP712Message): + _name_="test def" + {members_str}""", + globals(), + ) # Creates `Msg` definition + + values = [data.draw(get_abi_strategy(t), label=t) for t in types] + msg_dict = dict(zip(members, values)) + instance = Msg(**msg_dict) # noqa: F821 + + for k, v in msg_dict.items(): + assert getattr(instance, k) == v diff --git a/tests/test_messages.py b/tests/test_messages.py index f376f60..e8839d3 100644 --- a/tests/test_messages.py +++ b/tests/test_messages.py @@ -1,20 +1,18 @@ import pytest - -from eip712.messages import ValidationError +from eth_account.messages import ValidationError from .conftest import ( InvalidMessageMissingDomainFields, MessageWithCanonicalDomainFieldOrder, - MessageWithInvalidNameType, MessageWithNonCanonicalDomainFieldOrder, ) def test_multilevel_message(valid_message_with_name_domain_field): - msg = valid_message_with_name_domain_field - assert msg.version.hex() == "01" - assert msg.header.hex() == "336a9d2b32d1ab7ea7bbbd2565eca1910e54b74843858dec7a81f772a3c17e17" - assert msg.body.hex() == "306af87567fa87e55d2bd925d9a3ed2b1ec2c3e71b142785c053dc60b6ca177b" + msg = valid_message_with_name_domain_field.signable_message + assert msg.version.hex() == "0x01" + assert msg.header.hex() == "0x336a9d2b32d1ab7ea7bbbd2565eca1910e54b74843858dec7a81f772a3c17e17" + assert msg.body.hex() == "0x306af87567fa87e55d2bd925d9a3ed2b1ec2c3e71b142785c053dc60b6ca177b" def test_invalid_message_without_domain_fields(): @@ -22,27 +20,20 @@ def test_invalid_message_without_domain_fields(): InvalidMessageMissingDomainFields(value=1) -def test_invalid_type(): - message = MessageWithInvalidNameType() - expected_error_message = ( - "'_name_' type annotation must either be a subclass of " - "`EIP712Type` or valid ABI Type string, not str" - ) - - with pytest.raises(ValidationError, match=expected_error_message): - message.field_type("_name_") - - def test_yearn_vaults_message(permit, permit_raw_data): """ Testing a real world EIP712 message for a "permit" call in yearn-vaults. """ - assert permit.body_data == permit_raw_data + assert permit._body_ == permit_raw_data def test_eip712_domain_field_order_is_invariant(): assert ( - MessageWithCanonicalDomainFieldOrder.domain - == MessageWithNonCanonicalDomainFieldOrder.domain + MessageWithCanonicalDomainFieldOrder._domain_ + == MessageWithNonCanonicalDomainFieldOrder._domain_ ) + + +def test_ux_tuple_and_starargs(permit, Permit): + assert tuple(Permit(*permit)) == tuple(permit)