Skip to content

Commit

Permalink
Add signature v2 format (#983)
Browse files Browse the repository at this point in the history
* Drop static header check

* Add v2 signature format

* Enforce nonce monotonicity

* Store only wallet hotkey

* Simplify signature check

* Update neuron parameters on mismatch

* Add receptor signature format test
  • Loading branch information
adriansmares authored Dec 7, 2022
1 parent df54ca9 commit 504704a
Show file tree
Hide file tree
Showing 7 changed files with 204 additions and 88 deletions.
1 change: 1 addition & 0 deletions bittensor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
version_split = __version__.split(".")
__version_as_int__ = (100 * int(version_split[0])) + (10 * int(version_split[1])) + (1 * int(version_split[2]))

__new_signature_version__ = 360

# Turn off rich console locals trace.
from rich.traceback import install
Expand Down
101 changes: 65 additions & 36 deletions bittensor/_axon/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,8 +155,9 @@ def __new__(
if thread_pool == None:
thread_pool = futures.ThreadPoolExecutor( max_workers = config.axon.max_workers )
if server == None:
receiver_hotkey = wallet.hotkey.ss58_address
server = grpc.server( thread_pool,
interceptors=(AuthInterceptor(blacklist=blacklist),),
interceptors=(AuthInterceptor(receiver_hotkey=receiver_hotkey, blacklist=blacklist),),
maximum_concurrent_rpcs = config.axon.maximum_concurrent_rpcs,
options = [('grpc.keepalive_time_ms', 100000),
('grpc.keepalive_timeout_ms', 500000)]
Expand Down Expand Up @@ -341,22 +342,26 @@ def check_forward_callback( forward_callback:Callable, synapses:list = []):
class AuthInterceptor(grpc.ServerInterceptor):
"""Creates a new server interceptor that authenticates incoming messages from passed arguments."""

def __init__(self, key: str = "Bittensor", blacklist: List = []):
def __init__(
self,
receiver_hotkey: str,
blacklist: Callable = None,
):
r"""Creates a new server interceptor that authenticates incoming messages from passed arguments.
Args:
key (str, `optional`):
key for authentication header in the metadata (default = Bittensor)
receiver_hotkey(str):
the SS58 address of the hotkey which should be targeted by RPCs
black_list (Function, `optional`):
black list function that prevents certain pubkeys from sending messages
"""
super().__init__()
self.auth_header_value = key
self.nonces = {}
self.blacklist = blacklist
self.receiver_hotkey = receiver_hotkey

def parse_legacy_signature(
self, signature: str
) -> Union[Tuple[int, str, str, str], None]:
) -> Union[Tuple[int, str, str, str, int], None]:
r"""Attempts to parse a signature using the legacy format, using `bitxx` as a separator"""
parts = signature.split("bitxx")
if len(parts) < 4:
Expand All @@ -367,52 +372,71 @@ def parse_legacy_signature(
except ValueError:
return None
receptor_uuid, parts = parts[-1], parts[:-1]
message, parts = parts[-1], parts[:-1]
pubkey = "".join(parts)
return (nonce, pubkey, message, receptor_uuid)
signature, parts = parts[-1], parts[:-1]
sender_hotkey = "".join(parts)
return (nonce, sender_hotkey, signature, receptor_uuid, 1)

def parse_signature(self, metadata: Dict[str, str]) -> Tuple[int, str, str, str]:
def parse_signature_v2(
self, signature: str
) -> Union[Tuple[int, str, str, str, int], None]:
r"""Attempts to parse a signature using the v2 format"""
parts = signature.split(".")
if len(parts) != 4:
return None
try:
nonce = int(parts[0])
except ValueError:
return None
sender_hotkey = parts[1]
signature = parts[2]
receptor_uuid = parts[3]
return (nonce, sender_hotkey, signature, receptor_uuid, 2)

def parse_signature(
self, metadata: Dict[str, str]
) -> Tuple[int, str, str, str, int]:
r"""Attempts to parse a signature from the metadata"""
signature = metadata.get("bittensor-signature")
if signature is None:
raise Exception("Request signature missing")
parts = self.parse_legacy_signature(signature)
if parts is not None:
return parts
for parser in [self.parse_signature_v2, self.parse_legacy_signature]:
parts = parser(signature)
if parts is not None:
return parts
raise Exception("Unknown signature format")

def check_signature(
self, nonce: int, pubkey: str, signature: str, receptor_uuid: str
self,
nonce: int,
sender_hotkey: str,
signature: str,
receptor_uuid: str,
format: int,
):
r"""verification of signature in metadata. Uses the pubkey and nonce"""
keypair = Keypair(ss58_address=pubkey)
keypair = Keypair(ss58_address=sender_hotkey)
# Build the expected message which was used to build the signature.
message = f"{nonce}{pubkey}{receptor_uuid}"
if format == 2:
message = f"{nonce}.{sender_hotkey}.{self.receiver_hotkey}.{receptor_uuid}"
elif format == 1:
message = f"{nonce}{sender_hotkey}{receptor_uuid}"
else:
raise Exception("Invalid signature version")
# Build the key which uniquely identifies the endpoint that has signed
# the message.
endpoint_key = f"{pubkey}:{receptor_uuid}"
endpoint_key = f"{sender_hotkey}:{receptor_uuid}"

if endpoint_key in self.nonces.keys():
previous_nonce = self.nonces[endpoint_key]
# Nonces must be strictly monotonic over time.
if nonce - previous_nonce <= -10:
if nonce <= previous_nonce:
raise Exception("Nonce is too small")
if not keypair.verify(message, signature):
raise Exception("Signature mismatch")
self.nonces[endpoint_key] = nonce
return

if not keypair.verify(message, signature):
raise Exception("Signature mismatch")
self.nonces[endpoint_key] = nonce

def version_checking(self, metadata: Dict[str, str]):
r"""Checks the header and version in the metadata"""
provided_value = metadata.get("rpc-auth-header")
if provided_value is None or provided_value != self.auth_header_value:
raise Exception("Unexpected caller metadata")

def black_list_checking(self, pubkey: str, method: str):
def black_list_checking(self, hotkey: str, method: str):
r"""Tries to call to blacklist function in the miner and checks if it should blacklist the pubkey"""
if self.blacklist == None:
return
Expand All @@ -424,7 +448,7 @@ def black_list_checking(self, pubkey: str, method: str):
if request_type is None:
raise Exception("Unknown request type")

if self.blacklist(pubkey, request_type):
if self.blacklist(hotkey, request_type):
raise Exception("Request type is blacklisted")

def intercept_service(self, continuation, handler_call_details):
Expand All @@ -433,16 +457,21 @@ def intercept_service(self, continuation, handler_call_details):
metadata = dict(handler_call_details.invocation_metadata)

try:
# version checking
self.version_checking(metadata)

(nonce, pubkey, signature, receptor_uuid) = self.parse_signature(metadata)
(
nonce,
sender_hotkey,
signature,
receptor_uuid,
signature_format,
) = self.parse_signature(metadata)

# signature checking
self.check_signature(nonce, pubkey, signature, receptor_uuid)
self.check_signature(
nonce, sender_hotkey, signature, receptor_uuid, signature_format
)

# blacklist checking
self.black_list_checking(pubkey, method)
self.black_list_checking(sender_hotkey, method)

return continuation(handler_call_details)

Expand Down
28 changes: 20 additions & 8 deletions bittensor/_receptor/receptor_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,20 +123,32 @@ def __del__ ( self ):
def __exit__ ( self ):
self.__del__()

def sign ( self ):
def sign_v1( self ):
r""" Uses the wallet pubkey to sign a message containing the pubkey and the time
"""
nounce = self.nounce()
message = str(nounce) + str(self.wallet.hotkey.ss58_address) + str(self.receptor_uid)
nonce = self.nonce()
message = str(nonce) + str(self.wallet.hotkey.ss58_address) + str(self.receptor_uid)
spliter = 'bitxx'
signature = spliter.join([ str(nounce), str(self.wallet.hotkey.ss58_address), "0x" + self.wallet.hotkey.sign(message).hex(), str(self.receptor_uid) ])
signature = spliter.join([ str(nonce), str(self.wallet.hotkey.ss58_address), "0x" + self.wallet.hotkey.sign(message).hex(), str(self.receptor_uid) ])
return signature

def nounce ( self ):

def sign_v2(self):
nonce = f"{self.nonce()}"
sender_hotkey = self.wallet.hotkey.ss58_address
receiver_hotkey = self.endpoint.hotkey
message = f"{nonce}.{sender_hotkey}.{receiver_hotkey}.{self.receptor_uid}"
signature = f"0x{self.wallet.hotkey.sign(message).hex()}"
return ".".join([nonce, sender_hotkey, signature, self.receptor_uid])

def sign(self):
if self.endpoint.version >= bittensor.__new_signature_version__:
return self.sign_v2()
return self.sign_v1()

def nonce ( self ):
r"""creates a string representation of the time
"""
nounce = int(clock.time() * 1000)
return nounce
return clock.monotonic_ns()

def state ( self ):
try:
Expand Down
31 changes: 18 additions & 13 deletions bittensor/_subtensor/subtensor_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -671,24 +671,29 @@ def serve (
# Decrypt hotkey
wallet.hotkey

with bittensor.__console__.status(":satellite: Checking Axon..."):
neuron = self.neuron_for_pubkey( wallet.hotkey.ss58_address )
if not neuron.is_null and neuron.ip == net.ip_to_int(ip) and neuron.port == port:
bittensor.__console__.print(":white_heavy_check_mark: [green]Already Served[/green]\n [bold white]ip: {}\n port: {}\n modality: {}\n hotkey: {}\n coldkey: {}[/bold white]".format(ip, port, modality, wallet.hotkey.ss58_address, wallet.coldkeypub.ss58_address))
return True

ip_as_int = net.ip_to_int(ip)
ip_version = net.ip_version(ip)

# TODO(const): subscribe with version too.
params = {
'version': bittensor.__version_as_int__,
'ip': ip_as_int,
'port': port,
'ip_type': ip_version,
'ip': net.ip_to_int(ip),
'port': port,
'ip_type': net.ip_version(ip),
'modality': modality,
'coldkey': wallet.coldkeypub.ss58_address,
}

with bittensor.__console__.status(":satellite: Checking Axon..."):
neuron = self.neuron_for_pubkey( wallet.hotkey.ss58_address )
neuron_up_to_date = not neuron.is_null and params == {
'version': neuron.version,
'ip': neuron.ip,
'port': neuron.port,
'ip_type': neuron.ip_version,
'modality': neuron.modality,
'coldkey': neuron.coldkey
}
if neuron_up_to_date:
bittensor.__console__.print(":white_heavy_check_mark: [green]Already Served[/green]\n [bold white]ip: {}\n port: {}\n modality: {}\n hotkey: {}\n coldkey: {}[/bold white]".format(ip, port, modality, wallet.hotkey.ss58_address, wallet.coldkeypub.ss58_address))
return True

if prompt:
if not Confirm.ask("Do you want to serve axon:\n [bold white]ip: {}\n port: {}\n modality: {}\n hotkey: {}\n coldkey: {}[/bold white]".format(ip, port, modality, wallet.hotkey.ss58_address, wallet.coldkeypub.ss58_address)):
return False
Expand Down
53 changes: 39 additions & 14 deletions tests/unit_tests/bittensor_tests/test_axon.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,19 +34,37 @@
wallet = bittensor.wallet.mock()
axon = bittensor.axon(wallet = wallet)

sender_wallet = bittensor.wallet.mock()

def gen_nonce():
return f"{time.monotonic_ns()}"

def sign(wallet):
nounce = str(int(time.time() * 1000))
receptor_uid = str(uuid.uuid1())
message = "{}{}{}".format(nounce, str(wallet.hotkey.ss58_address), receptor_uid)
def sign_v1(wallet):
nonce, receptor_uid = gen_nonce(), str(uuid.uuid1())
message = "{}{}{}".format(nonce, str(wallet.hotkey.ss58_address), receptor_uid)
spliter = 'bitxx'
signature = spliter.join([ nounce, str(wallet.hotkey.ss58_address), "0x" + wallet.hotkey.sign(message).hex(), receptor_uid])
signature = spliter.join([ nonce, str(wallet.hotkey.ss58_address), "0x" + wallet.hotkey.sign(message).hex(), receptor_uid])
return signature

def test_sign():
sign(wallet)
sign(axon.wallet)
def sign_v2(sender_wallet, receiver_wallet):
nonce, receptor_uid = gen_nonce(), str(uuid.uuid1())
sender_hotkey = sender_wallet.hotkey.ss58_address
receiver_hotkey = receiver_wallet.hotkey.ss58_address
message = f"{nonce}.{sender_hotkey}.{receiver_hotkey}.{receptor_uid}"
signature = f"0x{sender_wallet.hotkey.sign(message).hex()}"
return ".".join([nonce, sender_hotkey, signature, receptor_uid])

def sign(sender_wallet, receiver_wallet, receiver_version):
if receiver_version >= bittensor.__new_signature_version__:
return sign_v2(sender_wallet, receiver_wallet)
return sign_v1(sender_wallet)

def test_sign_v1():
sign_v1(wallet)
sign_v1(axon.wallet)

def test_sign_v2():
sign_v2(sender_wallet, wallet)

def test_forward_wandb():
inputs_raw = torch.rand(3, 3, bittensor.__network_dim__)
Expand Down Expand Up @@ -902,7 +920,7 @@ def forward( inputs_x: torch.FloatTensor, synapses, model_output = None):
assert code == bittensor.proto.ReturnCode.Success


def test_grpc_forward_works():
def run_test_grpc_forward_works(receiver_version):
def forward( inputs_x:torch.FloatTensor, synapse , model_output = None):
return None, dict(), torch.zeros( [3, 3, bittensor.__network_dim__])
axon = bittensor.axon (
Expand All @@ -927,14 +945,14 @@ def forward( inputs_x:torch.FloatTensor, synapse , model_output = None):

request = bittensor.proto.TensorMessage(
version = bittensor.__version_as_int__,
hotkey = axon.wallet.hotkey.ss58_address,
hotkey = sender_wallet.hotkey.ss58_address,
tensors = [inputs_serialized],
synapses = [ syn.serialize_to_wire_proto() for syn in synapses ]
)
response = stub.Forward(request,
metadata = (
('rpc-auth-header','Bittensor'),
('bittensor-signature',sign(axon.wallet)),
('bittensor-signature',sign(sender_wallet, wallet, receiver_version)),
('bittensor-version',str(bittensor.__version_as_int__)),
))

Expand All @@ -943,8 +961,11 @@ def forward( inputs_x:torch.FloatTensor, synapse , model_output = None):
assert response.return_code == bittensor.proto.ReturnCode.Success
axon.stop()

def test_grpc_forward_works():
for receiver_version in [341, bittensor.__new_signature_version__, bittensor.__version_as_int__]:
run_test_grpc_forward_works(receiver_version)

def test_grpc_backward_works():
def run_test_grpc_backward_works(receiver_version):
def forward( inputs_x:torch.FloatTensor, synapse , model_output = None):
return None, dict(), torch.zeros( [3, 3, bittensor.__network_dim__], requires_grad=True)

Expand All @@ -969,19 +990,23 @@ def forward( inputs_x:torch.FloatTensor, synapse , model_output = None):
grads_serialized = synapses[0].serialize_backward_request_gradient(inputs_raw, grads_raw)
request = bittensor.proto.TensorMessage(
version = bittensor.__version_as_int__,
hotkey = '1092310312914',
hotkey = sender_wallet.hotkey.ss58_address,
tensors = [inputs_serialized, grads_serialized],
synapses = [ syn.serialize_to_wire_proto() for syn in synapses ]
)
response = stub.Backward(request,
metadata = (
('rpc-auth-header','Bittensor'),
('bittensor-signature',sign(axon.wallet)),
('bittensor-signature',sign(sender_wallet, wallet, receiver_version)),
('bittensor-version',str(bittensor.__version_as_int__)),
))
assert response.return_code == bittensor.proto.ReturnCode.Success
axon.stop()

def test_grpc_backward_works():
for receiver_version in [341, bittensor.__new_signature_version__, bittensor.__version_as_int__]:
run_test_grpc_backward_works(receiver_version)

def test_grpc_forward_fails():
def forward( inputs_x:torch.FloatTensor, synapse, model_output = None):
return None, dict(), torch.zeros( [3, 3, bittensor.__network_dim__])
Expand Down
Loading

0 comments on commit 504704a

Please sign in to comment.