From 504704ab41cd00118f64f445c42ba39523fa83a9 Mon Sep 17 00:00:00 2001 From: Adrian-Stefan Mares <36161392+adriansmares@users.noreply.github.com> Date: Wed, 7 Dec 2022 19:39:24 +0100 Subject: [PATCH] Add signature v2 format (#983) * Drop static header check * Add v2 signature format * Enforce nonce monotonicity * Store only wallet hotkey * Simplify signature check * Update neuron parameters on mismatch * Add receptor signature format test --- bittensor/__init__.py | 1 + bittensor/_axon/__init__.py | 101 +++++++++++------- bittensor/_receptor/receptor_impl.py | 28 +++-- bittensor/_subtensor/subtensor_impl.py | 31 +++--- tests/unit_tests/bittensor_tests/test_axon.py | 53 ++++++--- .../bittensor_tests/test_forward_backward.py | 27 +++-- .../bittensor_tests/test_receptor.py | 51 ++++++++- 7 files changed, 204 insertions(+), 88 deletions(-) diff --git a/bittensor/__init__.py b/bittensor/__init__.py index 09897e5bee..a341a21a03 100644 --- a/bittensor/__init__.py +++ b/bittensor/__init__.py @@ -27,6 +27,7 @@ version_split = __version__.split(".") __version_as_int__ = (100 * int(version_split[0])) + (10 * int(version_split[1])) + (1 * int(version_split[2])) +__new_signature_version__ = 360 # Turn off rich console locals trace. from rich.traceback import install diff --git a/bittensor/_axon/__init__.py b/bittensor/_axon/__init__.py index b2f2d30a22..2dd8c97d69 100644 --- a/bittensor/_axon/__init__.py +++ b/bittensor/_axon/__init__.py @@ -155,8 +155,9 @@ def __new__( if thread_pool == None: thread_pool = futures.ThreadPoolExecutor( max_workers = config.axon.max_workers ) if server == None: + receiver_hotkey = wallet.hotkey.ss58_address server = grpc.server( thread_pool, - interceptors=(AuthInterceptor(blacklist=blacklist),), + interceptors=(AuthInterceptor(receiver_hotkey=receiver_hotkey, blacklist=blacklist),), maximum_concurrent_rpcs = config.axon.maximum_concurrent_rpcs, options = [('grpc.keepalive_time_ms', 100000), ('grpc.keepalive_timeout_ms', 500000)] @@ -341,22 +342,26 @@ def check_forward_callback( forward_callback:Callable, synapses:list = []): class AuthInterceptor(grpc.ServerInterceptor): """Creates a new server interceptor that authenticates incoming messages from passed arguments.""" - def __init__(self, key: str = "Bittensor", blacklist: List = []): + def __init__( + self, + receiver_hotkey: str, + blacklist: Callable = None, + ): r"""Creates a new server interceptor that authenticates incoming messages from passed arguments. Args: - key (str, `optional`): - key for authentication header in the metadata (default = Bittensor) + receiver_hotkey(str): + the SS58 address of the hotkey which should be targeted by RPCs black_list (Function, `optional`): black list function that prevents certain pubkeys from sending messages """ super().__init__() - self.auth_header_value = key self.nonces = {} self.blacklist = blacklist + self.receiver_hotkey = receiver_hotkey def parse_legacy_signature( self, signature: str - ) -> Union[Tuple[int, str, str, str], None]: + ) -> Union[Tuple[int, str, str, str, int], None]: r"""Attempts to parse a signature using the legacy format, using `bitxx` as a separator""" parts = signature.split("bitxx") if len(parts) < 4: @@ -367,52 +372,71 @@ def parse_legacy_signature( except ValueError: return None receptor_uuid, parts = parts[-1], parts[:-1] - message, parts = parts[-1], parts[:-1] - pubkey = "".join(parts) - return (nonce, pubkey, message, receptor_uuid) + signature, parts = parts[-1], parts[:-1] + sender_hotkey = "".join(parts) + return (nonce, sender_hotkey, signature, receptor_uuid, 1) - def parse_signature(self, metadata: Dict[str, str]) -> Tuple[int, str, str, str]: + def parse_signature_v2( + self, signature: str + ) -> Union[Tuple[int, str, str, str, int], None]: + r"""Attempts to parse a signature using the v2 format""" + parts = signature.split(".") + if len(parts) != 4: + return None + try: + nonce = int(parts[0]) + except ValueError: + return None + sender_hotkey = parts[1] + signature = parts[2] + receptor_uuid = parts[3] + return (nonce, sender_hotkey, signature, receptor_uuid, 2) + + def parse_signature( + self, metadata: Dict[str, str] + ) -> Tuple[int, str, str, str, int]: r"""Attempts to parse a signature from the metadata""" signature = metadata.get("bittensor-signature") if signature is None: raise Exception("Request signature missing") - parts = self.parse_legacy_signature(signature) - if parts is not None: - return parts + for parser in [self.parse_signature_v2, self.parse_legacy_signature]: + parts = parser(signature) + if parts is not None: + return parts raise Exception("Unknown signature format") def check_signature( - self, nonce: int, pubkey: str, signature: str, receptor_uuid: str + self, + nonce: int, + sender_hotkey: str, + signature: str, + receptor_uuid: str, + format: int, ): r"""verification of signature in metadata. Uses the pubkey and nonce""" - keypair = Keypair(ss58_address=pubkey) + keypair = Keypair(ss58_address=sender_hotkey) # Build the expected message which was used to build the signature. - message = f"{nonce}{pubkey}{receptor_uuid}" + if format == 2: + message = f"{nonce}.{sender_hotkey}.{self.receiver_hotkey}.{receptor_uuid}" + elif format == 1: + message = f"{nonce}{sender_hotkey}{receptor_uuid}" + else: + raise Exception("Invalid signature version") # Build the key which uniquely identifies the endpoint that has signed # the message. - endpoint_key = f"{pubkey}:{receptor_uuid}" + endpoint_key = f"{sender_hotkey}:{receptor_uuid}" if endpoint_key in self.nonces.keys(): previous_nonce = self.nonces[endpoint_key] # Nonces must be strictly monotonic over time. - if nonce - previous_nonce <= -10: + if nonce <= previous_nonce: raise Exception("Nonce is too small") - if not keypair.verify(message, signature): - raise Exception("Signature mismatch") - self.nonces[endpoint_key] = nonce - return if not keypair.verify(message, signature): raise Exception("Signature mismatch") self.nonces[endpoint_key] = nonce - def version_checking(self, metadata: Dict[str, str]): - r"""Checks the header and version in the metadata""" - provided_value = metadata.get("rpc-auth-header") - if provided_value is None or provided_value != self.auth_header_value: - raise Exception("Unexpected caller metadata") - - def black_list_checking(self, pubkey: str, method: str): + def black_list_checking(self, hotkey: str, method: str): r"""Tries to call to blacklist function in the miner and checks if it should blacklist the pubkey""" if self.blacklist == None: return @@ -424,7 +448,7 @@ def black_list_checking(self, pubkey: str, method: str): if request_type is None: raise Exception("Unknown request type") - if self.blacklist(pubkey, request_type): + if self.blacklist(hotkey, request_type): raise Exception("Request type is blacklisted") def intercept_service(self, continuation, handler_call_details): @@ -433,16 +457,21 @@ def intercept_service(self, continuation, handler_call_details): metadata = dict(handler_call_details.invocation_metadata) try: - # version checking - self.version_checking(metadata) - - (nonce, pubkey, signature, receptor_uuid) = self.parse_signature(metadata) + ( + nonce, + sender_hotkey, + signature, + receptor_uuid, + signature_format, + ) = self.parse_signature(metadata) # signature checking - self.check_signature(nonce, pubkey, signature, receptor_uuid) + self.check_signature( + nonce, sender_hotkey, signature, receptor_uuid, signature_format + ) # blacklist checking - self.black_list_checking(pubkey, method) + self.black_list_checking(sender_hotkey, method) return continuation(handler_call_details) diff --git a/bittensor/_receptor/receptor_impl.py b/bittensor/_receptor/receptor_impl.py index 821a3f2bdf..c70d194afe 100644 --- a/bittensor/_receptor/receptor_impl.py +++ b/bittensor/_receptor/receptor_impl.py @@ -123,20 +123,32 @@ def __del__ ( self ): def __exit__ ( self ): self.__del__() - def sign ( self ): + def sign_v1( self ): r""" Uses the wallet pubkey to sign a message containing the pubkey and the time """ - nounce = self.nounce() - message = str(nounce) + str(self.wallet.hotkey.ss58_address) + str(self.receptor_uid) + nonce = self.nonce() + message = str(nonce) + str(self.wallet.hotkey.ss58_address) + str(self.receptor_uid) spliter = 'bitxx' - signature = spliter.join([ str(nounce), str(self.wallet.hotkey.ss58_address), "0x" + self.wallet.hotkey.sign(message).hex(), str(self.receptor_uid) ]) + signature = spliter.join([ str(nonce), str(self.wallet.hotkey.ss58_address), "0x" + self.wallet.hotkey.sign(message).hex(), str(self.receptor_uid) ]) return signature - - def nounce ( self ): + + def sign_v2(self): + nonce = f"{self.nonce()}" + sender_hotkey = self.wallet.hotkey.ss58_address + receiver_hotkey = self.endpoint.hotkey + message = f"{nonce}.{sender_hotkey}.{receiver_hotkey}.{self.receptor_uid}" + signature = f"0x{self.wallet.hotkey.sign(message).hex()}" + return ".".join([nonce, sender_hotkey, signature, self.receptor_uid]) + + def sign(self): + if self.endpoint.version >= bittensor.__new_signature_version__: + return self.sign_v2() + return self.sign_v1() + + def nonce ( self ): r"""creates a string representation of the time """ - nounce = int(clock.time() * 1000) - return nounce + return clock.monotonic_ns() def state ( self ): try: diff --git a/bittensor/_subtensor/subtensor_impl.py b/bittensor/_subtensor/subtensor_impl.py index fb01c7ce6f..f5970eb025 100644 --- a/bittensor/_subtensor/subtensor_impl.py +++ b/bittensor/_subtensor/subtensor_impl.py @@ -671,24 +671,29 @@ def serve ( # Decrypt hotkey wallet.hotkey - with bittensor.__console__.status(":satellite: Checking Axon..."): - neuron = self.neuron_for_pubkey( wallet.hotkey.ss58_address ) - if not neuron.is_null and neuron.ip == net.ip_to_int(ip) and neuron.port == port: - bittensor.__console__.print(":white_heavy_check_mark: [green]Already Served[/green]\n [bold white]ip: {}\n port: {}\n modality: {}\n hotkey: {}\n coldkey: {}[/bold white]".format(ip, port, modality, wallet.hotkey.ss58_address, wallet.coldkeypub.ss58_address)) - return True - - ip_as_int = net.ip_to_int(ip) - ip_version = net.ip_version(ip) - - # TODO(const): subscribe with version too. params = { 'version': bittensor.__version_as_int__, - 'ip': ip_as_int, - 'port': port, - 'ip_type': ip_version, + 'ip': net.ip_to_int(ip), + 'port': port, + 'ip_type': net.ip_version(ip), 'modality': modality, 'coldkey': wallet.coldkeypub.ss58_address, } + + with bittensor.__console__.status(":satellite: Checking Axon..."): + neuron = self.neuron_for_pubkey( wallet.hotkey.ss58_address ) + neuron_up_to_date = not neuron.is_null and params == { + 'version': neuron.version, + 'ip': neuron.ip, + 'port': neuron.port, + 'ip_type': neuron.ip_version, + 'modality': neuron.modality, + 'coldkey': neuron.coldkey + } + if neuron_up_to_date: + bittensor.__console__.print(":white_heavy_check_mark: [green]Already Served[/green]\n [bold white]ip: {}\n port: {}\n modality: {}\n hotkey: {}\n coldkey: {}[/bold white]".format(ip, port, modality, wallet.hotkey.ss58_address, wallet.coldkeypub.ss58_address)) + return True + if prompt: if not Confirm.ask("Do you want to serve axon:\n [bold white]ip: {}\n port: {}\n modality: {}\n hotkey: {}\n coldkey: {}[/bold white]".format(ip, port, modality, wallet.hotkey.ss58_address, wallet.coldkeypub.ss58_address)): return False diff --git a/tests/unit_tests/bittensor_tests/test_axon.py b/tests/unit_tests/bittensor_tests/test_axon.py index 58b6457e8b..072f26d549 100644 --- a/tests/unit_tests/bittensor_tests/test_axon.py +++ b/tests/unit_tests/bittensor_tests/test_axon.py @@ -34,19 +34,37 @@ wallet = bittensor.wallet.mock() axon = bittensor.axon(wallet = wallet) +sender_wallet = bittensor.wallet.mock() +def gen_nonce(): + return f"{time.monotonic_ns()}" -def sign(wallet): - nounce = str(int(time.time() * 1000)) - receptor_uid = str(uuid.uuid1()) - message = "{}{}{}".format(nounce, str(wallet.hotkey.ss58_address), receptor_uid) +def sign_v1(wallet): + nonce, receptor_uid = gen_nonce(), str(uuid.uuid1()) + message = "{}{}{}".format(nonce, str(wallet.hotkey.ss58_address), receptor_uid) spliter = 'bitxx' - signature = spliter.join([ nounce, str(wallet.hotkey.ss58_address), "0x" + wallet.hotkey.sign(message).hex(), receptor_uid]) + signature = spliter.join([ nonce, str(wallet.hotkey.ss58_address), "0x" + wallet.hotkey.sign(message).hex(), receptor_uid]) return signature -def test_sign(): - sign(wallet) - sign(axon.wallet) +def sign_v2(sender_wallet, receiver_wallet): + nonce, receptor_uid = gen_nonce(), str(uuid.uuid1()) + sender_hotkey = sender_wallet.hotkey.ss58_address + receiver_hotkey = receiver_wallet.hotkey.ss58_address + message = f"{nonce}.{sender_hotkey}.{receiver_hotkey}.{receptor_uid}" + signature = f"0x{sender_wallet.hotkey.sign(message).hex()}" + return ".".join([nonce, sender_hotkey, signature, receptor_uid]) + +def sign(sender_wallet, receiver_wallet, receiver_version): + if receiver_version >= bittensor.__new_signature_version__: + return sign_v2(sender_wallet, receiver_wallet) + return sign_v1(sender_wallet) + +def test_sign_v1(): + sign_v1(wallet) + sign_v1(axon.wallet) + +def test_sign_v2(): + sign_v2(sender_wallet, wallet) def test_forward_wandb(): inputs_raw = torch.rand(3, 3, bittensor.__network_dim__) @@ -902,7 +920,7 @@ def forward( inputs_x: torch.FloatTensor, synapses, model_output = None): assert code == bittensor.proto.ReturnCode.Success -def test_grpc_forward_works(): +def run_test_grpc_forward_works(receiver_version): def forward( inputs_x:torch.FloatTensor, synapse , model_output = None): return None, dict(), torch.zeros( [3, 3, bittensor.__network_dim__]) axon = bittensor.axon ( @@ -927,14 +945,14 @@ def forward( inputs_x:torch.FloatTensor, synapse , model_output = None): request = bittensor.proto.TensorMessage( version = bittensor.__version_as_int__, - hotkey = axon.wallet.hotkey.ss58_address, + hotkey = sender_wallet.hotkey.ss58_address, tensors = [inputs_serialized], synapses = [ syn.serialize_to_wire_proto() for syn in synapses ] ) response = stub.Forward(request, metadata = ( ('rpc-auth-header','Bittensor'), - ('bittensor-signature',sign(axon.wallet)), + ('bittensor-signature',sign(sender_wallet, wallet, receiver_version)), ('bittensor-version',str(bittensor.__version_as_int__)), )) @@ -943,8 +961,11 @@ def forward( inputs_x:torch.FloatTensor, synapse , model_output = None): assert response.return_code == bittensor.proto.ReturnCode.Success axon.stop() +def test_grpc_forward_works(): + for receiver_version in [341, bittensor.__new_signature_version__, bittensor.__version_as_int__]: + run_test_grpc_forward_works(receiver_version) -def test_grpc_backward_works(): +def run_test_grpc_backward_works(receiver_version): def forward( inputs_x:torch.FloatTensor, synapse , model_output = None): return None, dict(), torch.zeros( [3, 3, bittensor.__network_dim__], requires_grad=True) @@ -969,19 +990,23 @@ def forward( inputs_x:torch.FloatTensor, synapse , model_output = None): grads_serialized = synapses[0].serialize_backward_request_gradient(inputs_raw, grads_raw) request = bittensor.proto.TensorMessage( version = bittensor.__version_as_int__, - hotkey = '1092310312914', + hotkey = sender_wallet.hotkey.ss58_address, tensors = [inputs_serialized, grads_serialized], synapses = [ syn.serialize_to_wire_proto() for syn in synapses ] ) response = stub.Backward(request, metadata = ( ('rpc-auth-header','Bittensor'), - ('bittensor-signature',sign(axon.wallet)), + ('bittensor-signature',sign(sender_wallet, wallet, receiver_version)), ('bittensor-version',str(bittensor.__version_as_int__)), )) assert response.return_code == bittensor.proto.ReturnCode.Success axon.stop() +def test_grpc_backward_works(): + for receiver_version in [341, bittensor.__new_signature_version__, bittensor.__version_as_int__]: + run_test_grpc_backward_works(receiver_version) + def test_grpc_forward_fails(): def forward( inputs_x:torch.FloatTensor, synapse, model_output = None): return None, dict(), torch.zeros( [3, 3, bittensor.__network_dim__]) diff --git a/tests/unit_tests/bittensor_tests/test_forward_backward.py b/tests/unit_tests/bittensor_tests/test_forward_backward.py index 527b792ced..adb61848cd 100644 --- a/tests/unit_tests/bittensor_tests/test_forward_backward.py +++ b/tests/unit_tests/bittensor_tests/test_forward_backward.py @@ -223,24 +223,21 @@ def forward( inputs_x: torch.FloatTensor, synapse , model_output = None): ) axon.attach_synapse_callback( forward, synapse_type = bittensor.proto.Synapse.SynapseType.TEXT_LAST_HIDDEN_STATE) axon.start() - endpoints = [] - for i in range(20): - wallet.create_new_hotkey( use_password=False, overwrite = True) - endpoint = bittensor.endpoint( - version = bittensor.__version_as_int__, - uid = 0, - hotkey = wallet.hotkey.ss58_address, - ip = '0.0.0.0', - ip_type = 4, - port = axon_port, - modality = 2, - coldkey = wallet.coldkey.ss58_address - ) - endpoints += [endpoint] + endpoint = bittensor.endpoint( + version = bittensor.__version_as_int__, + uid = 0, + hotkey = wallet.hotkey.ss58_address, + ip = '0.0.0.0', + ip_type = 4, + port = axon_port, + modality = 2, + coldkey = wallet.coldkey.ss58_address + ) + endpoints = [endpoint] x = torch.zeros(3, 3) synapses = [bittensor.synapse.TextLastHiddenState()] - tensors, codes, times = dendrite.text( endpoints=endpoints, inputs=[x for i in endpoints], synapses=synapses) + tensors, codes, times = dendrite.text( endpoints=endpoints, inputs=[x for _ in endpoints], synapses=synapses) receptors_states = dendrite.receptor_pool.get_receptors_state() # TODO: Fails locally independent of multiprocessing. assert receptors_states[endpoint.hotkey] == receptors_states[endpoint.hotkey].READY diff --git a/tests/unit_tests/bittensor_tests/test_receptor.py b/tests/unit_tests/bittensor_tests/test_receptor.py index 829a6bfeee..b3ad9e50c8 100644 --- a/tests/unit_tests/bittensor_tests/test_receptor.py +++ b/tests/unit_tests/bittensor_tests/test_receptor.py @@ -431,9 +431,52 @@ def backward_break(): out, ops, time = receptor.backward(synapses, x, [hidden_grads, causal_grads, causallmnext_grads, seq_2_seq_grads], timeout=1) assert ops == [bittensor.proto.ReturnCode.UnknownException] * len(synapses) +def test_receptor_signature_output(): + def verify_v1(signature: str): + (nonce, sender_address, signature, receptor_uuid) = signature.split("bitxx") + assert nonce == "123" + assert sender_address == "5Ey8t8pBJSYqLYCzeC3HiPJu5DxzXy2Dzheaj29wRHvhjoai" + assert receptor_uuid == "6d8b8788-6b6a-11ed-916f-0242c0a85003" + message = f"{nonce}{sender_address}{receptor_uuid}" + assert wallet.hotkey.verify(message, signature) + + def verify_v2(signature: str): + (nonce, sender_address, signature, receptor_uuid) = signature.split(".") + assert nonce == "123" + assert sender_address == "5Ey8t8pBJSYqLYCzeC3HiPJu5DxzXy2Dzheaj29wRHvhjoai" + assert receptor_uuid == "6d8b8788-6b6a-11ed-916f-0242c0a85003" + message = f"{nonce}.{sender_address}.5CSbZ7wG456oty4WoiX6a1J88VUbrCXLhrKVJ9q95BsYH4TZ.{receptor_uuid}" + assert wallet.hotkey.verify(message, signature) + + matrix = { + bittensor.__new_signature_version__ - 1: verify_v1, + bittensor.__new_signature_version__: verify_v2, + } + + for (receiver_version, verify) in matrix.items(): + endpoint = bittensor.endpoint( + version=receiver_version, + uid=0, + ip="127.0.0.1", + ip_type=4, + port=65000, + hotkey="5CSbZ7wG456oty4WoiX6a1J88VUbrCXLhrKVJ9q95BsYH4TZ", + coldkey="5DD26kC2kxajmwfbbZmVmxhrY9VeeyR1Gpzy9i8wxLUg6zxm", + modality=2, + ) + + receptor = bittensor.receptor( + endpoint=endpoint, + wallet=wallet, + ) + receptor.receptor_uid = "6d8b8788-6b6a-11ed-916f-0242c0a85003" + receptor.nonce = lambda: 123 + + verify(receptor.sign()) + #-- axon receptor connection -- -def test_axon_receptor_connection_forward_works(): +def run_test_axon_receptor_connection_forward_works(receiver_version): def forward_generate( input, synapse, model_output = None): return None, None, torch.zeros( [3, 70]) @@ -459,7 +502,7 @@ def forward_casual_lm_next(input, synapse, model_output=None): axon.start() endpoint = bittensor.endpoint( - version = bittensor.__version_as_int__, + version = receiver_version, uid = 0, ip = '127.0.0.1', ip_type = 4, @@ -480,6 +523,10 @@ def forward_casual_lm_next(input, synapse, model_output=None): axon.stop() +def test_axon_receptor_connection_forward_works(): + for receiver_version in [341, bittensor.__new_signature_version__, bittensor.__version_as_int__]: + run_test_axon_receptor_connection_forward_works(receiver_version) + def test_axon_receptor_connection_forward_unauthenticated(): def forward_generate( input, synapse, model_output = None ): return None, None, torch.zeros( [3, 70])