Skip to content

Commit

Permalink
Weight Utils fix (#1372)
Browse files Browse the repository at this point in the history
When Torch sets zeros or ones, it does this on the CPU. The metagraph is loaded on GPU. This results in errors—proposed fix.
  • Loading branch information
mrseeker authored Jun 26, 2023
1 parent b704234 commit 157801f
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions bittensor/utils/weight_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,14 +182,14 @@ def process_weights_for_netuid(
non_zero_weights = weights[ non_zero_weight_idx ]
if non_zero_weights.numel() == 0 or metagraph.n < min_allowed_weights:
bittensor.logging.warning( 'No non-zero weights returning all ones.' )
final_weights = torch.ones( ( metagraph.n ) ) / metagraph.n
final_weights = torch.ones( ( metagraph.n ) ).to( metagraph.n ) / metagraph.n
bittensor.logging.debug( 'final_weights', final_weights )
return torch.tensor( list( range( len( final_weights ) ) ) ), final_weights

elif non_zero_weights.numel() < min_allowed_weights:
bittensor.logging.warning( 'No non-zero weights less then min allowed weight, returning all ones.' )
# ( const ): Should this be torch.zeros( ( metagraph.n ) ) to reset everyone to build up weight?
weights = torch.ones( ( metagraph.n ) ) * 1e-5 # creating minimum even non-zero weights
weights = torch.ones( ( metagraph.n ) ).to( metagraph.n ) * 1e-5 # creating minimum even non-zero weights
weights[non_zero_weight_idx] += non_zero_weights
bittensor.logging.debug( 'final_weights', weights )
normalized_weights = bittensor.utils.weight_utils.normalize_max_weight(
Expand Down Expand Up @@ -221,4 +221,4 @@ def process_weights_for_netuid(
)
bittensor.logging.debug( 'final_weights', normalized_weights )

return non_zero_weight_uids, normalized_weights
return non_zero_weight_uids, normalized_weights

0 comments on commit 157801f

Please sign in to comment.