Bit 590 backward fix (#957)
* init

* no local forward and remote forward overlap

* clean up

* saving remote

* fix local size mismatch

* clean up

* fix

* hidden state and causalLM determinism

* rm backward

* default to have dendrite backward
isabella618033 authored Oct 25, 2022
1 parent 2df26db commit 1da158c
Showing 5 changed files with 72 additions and 37 deletions.
2 changes: 1 addition & 1 deletion bittensor/_axon/axon_impl.py
@@ -261,7 +261,7 @@ def finalize_codes_stats_and_logs( message = None):
code = synapse_codes[ index ],
call_time = synapse_call_times[ index ],
pubkey = request.hotkey,
inputs = synapse_inputs [index] ,
inputs = deserialized_forward_tensors [index].shape if deserialized_forward_tensors [index] != None else None ,
outputs = None if synapse_responses[index] == None else list( synapse_responses[index].shape ),
message = synapse_messages[ index ] if message == None else message,
synapse = synapse.synapse_type
32 changes: 21 additions & 11 deletions bittensor/_neuron/text/core_server/nucleus_impl.py
@@ -7,6 +7,7 @@
from types import SimpleNamespace
from typing import Tuple, Optional

import transformers
from transformers import AutoModel,AutoTokenizer,AutoConfig, AutoModelForCausalLM
from torch.nn.utils.rnn import pad_sequence
from bittensor.utils.tokenizer_utils import prep_tokenizer, get_translation_map, translate_logits_to_probs_std, \
@@ -115,14 +116,16 @@ def __init__(self,
self.outputs_cache = None
self.gradients_cache = None
self.best_loss = math.inf
self.best_remote_loss = math.inf

#checking if the parameters of the server makes sense
if self.checking and pretrained == True:
self.check()

# -- keeps track of gradients applied
self.backward_gradients_count = 0

self.remote_losses = []

def set_fine_tuning_params(self) -> Tuple[bool, str]:
r''' Set to tune only the parameter of the last layer
Returns:
@@ -205,7 +208,7 @@ def remapping_token(self, token_batch, std_tokenizer=None, return_offsets_mappin
result = translate_special_token_text(text_batch, std_tokenizer, self.tokenizer) # translate special tokens
to_text_batch, from_offsets_batch, to_offsets_batch, pad_offsets_batch = result

tokens = self.tokenizer(to_text_batch, padding=True, truncation=True, return_tensors='pt',
tokens = self.tokenizer(to_text_batch, padding=True, truncation=True, max_length=token_batch.size(1), return_tensors='pt',
add_special_tokens=False).to(self.device) # assume tokenizer.padding_side = 'left'

if return_offsets_mapping: # get offsets_mapping in tokenization to delineate token segment positions
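The `max_length=token_batch.size(1)` cap added to the tokenizer call above is what the "fix local size mismatch" bullet refers to: retokenized text can never come back wider than the standard-tokenized batch it has to line up with. A minimal sketch of the effect, using `gpt2` as a hypothetical stand-in for the server tokenizer:

```python
import torch
from transformers import AutoTokenizer

# Hypothetical stand-in for self.tokenizer; the real server tokenizer is model-specific.
tokenizer = AutoTokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token   # gpt2 ships without a pad token
tokenizer.padding_side = 'left'             # matches the "assume padding_side = 'left'" comment

token_batch = torch.zeros(2, 8, dtype=torch.long)  # pretend std-tokenized batch: [batch, seq_len]
to_text_batch = ["a short line", "a much longer line that would otherwise retokenize wider"]

tokens = tokenizer(to_text_batch, padding=True, truncation=True,
                   max_length=token_batch.size(1),  # cap retokenized width to the original batch
                   return_tensors='pt', add_special_tokens=False)

# Without max_length, padding to the longest retokenized sample can exceed seq_len,
# and the downstream tensors no longer line up with token_batch.
assert tokens['input_ids'].size(1) <= token_batch.size(1)
```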
@@ -235,7 +238,6 @@ def forward(self, inputs, tokenizer=None):
"""
message, model_output, decoded_targets = self.local_forward(inputs, tokenizer)

shift_logits = decoded_targets[..., :-1, :].contiguous()
shift_labels = inputs[..., 1:].contiguous()
loss = self.loss_fct( shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1) )
@@ -264,8 +266,7 @@ def local_forward(self, token_batch, tokenizer=None, encode_len=bittensor.__netw
logits (:obj:`torch.FloatTensor`):
The nucleus's logit outputs as a torch tensor of shape [batch_size, sequence_len, __vocab_size__]
"""
tokens = self.token_remap(token_batch, std_tokenizer=tokenizer) # remap to server tokenizer

tokens = self.token_remap(token_batch, std_tokenizer=tokenizer, return_offsets_mapping=True) # remap to server tokenizer
if model_output == None:
if self.config.neuron.local_train:
model_output = self.pre_model(input_ids=tokens['input_ids'],
@@ -298,6 +299,9 @@ def encode_forward(self,inputs,tokenizer=None, model_output = None):
encoded_hidden (:type:`torch.Tensor`, `required`)
The hidden layer output as a torch tensor of shape [batch_size, sequence_len, __network_dim__ ]
"""
transformers.set_seed(0)
transformers.enable_full_determinism(0)

sen_len = inputs.size()
tokens = self.token_remap(inputs, tokenizer) # remap to server tokenizer
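Each synapse forward now starts with `transformers.set_seed(0)` and `transformers.enable_full_determinism(0)`, pinning every RNG so repeated queries return bit-identical hidden states and logits (the "hidden state and causalLM determinism" bullet). A hedged sketch of why that matters, with `gpt2` as an assumed stand-in model; `enable_full_determinism` also seeds internally, so calling both mirrors the diff:

```python
import torch
import transformers
from transformers import AutoModel, AutoTokenizer

model = AutoModel.from_pretrained('gpt2')       # assumed small stand-in for self.pre_model
tokenizer = AutoTokenizer.from_pretrained('gpt2')
model.train()                                   # keep dropout active so the seeding visibly matters
inputs = tokenizer("the same prompt twice", return_tensors='pt')

hidden = []
for _ in range(2):
    transformers.set_seed(0)
    transformers.enable_full_determinism(0)     # also flips cudnn/cublas to deterministic kernels
    with torch.no_grad():
        hidden.append(model(**inputs).last_hidden_state)

assert torch.equal(hidden[0], hidden[1])        # bit-identical hidden states across calls
```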

@@ -352,6 +356,9 @@ def encode_forward_causallm(self, token_batch, tokenizer=None, encode_len=bitten
logits_std (:obj:`torch.FloatTensor`):
The nucleus's logit outputs as a torch tensor of shape [batch_size, sequence_len, __vocab_size__]
"""
transformers.set_seed(0)
transformers.enable_full_determinism(0)

tokens = self.token_remap(token_batch, std_tokenizer=tokenizer, return_offsets_mapping=True) # remap to server tokenizer

def _forward(_model_output=model_output):
@@ -374,10 +381,8 @@ def _forward(_model_output=model_output):
#removing the loss calculation for stablity testing
original_loss = self.get_loss_fct(pre_logits, tokens['input_ids']).item()
translated_loss = self.get_loss_fct(logits_std, token_batch).item()
#message = 'Success'
message = f'Loss: {original_loss:.2f} → {translated_loss:.2f}'
# logger.info(f'TextCausalLM \t| Server loss: {original_loss: .2f} \t| Translated loss: {translated_loss: .2f}')


return message, _model_output, logits_std

if self.config.neuron.remote_train:
@@ -421,10 +426,12 @@ def encode_forward_causallmnext(self, token_batch, std_tokenizer=None, topk: int
[prob_floor_b=1, ignore_index, ..., ignore_index]],
[...]]
"""
transformers.set_seed(0)
transformers.enable_full_determinism(0)

if std_tokenizer is None:
std_tokenizer = self.std_tokenizer

# remap to server tokenizer, expect right-aligned sequences so that last position keeps continuation prediction
tokens = self.token_remap(token_batch, std_tokenizer)

def _forward(_model_output=model_output):
@@ -442,8 +449,8 @@ def _forward(_model_output=model_output):

original_loss = self.get_loss_fct(_model_output.logits, tokens['input_ids']).item()
message = f'Loss: {original_loss:.2f}'
#message = 'Success'

_model_output.loss = original_loss
return message, _model_output, topk_tensor

if self.config.neuron.remote_train:
@@ -485,6 +492,7 @@ def save(self, path):
'pretrained_model': self.pre_model.state_dict(),
'decoder': self.decoder.state_dict(),
'best_loss': self.best_loss,
'best_remote_loss': self.best_remote_loss,
}
if self.padding == False:
state_dict['mapping'] = self.mapping.state_dict()
@@ -502,6 +510,7 @@ def load(self, path):
if self.padding == False:
self.mapping.load_state_dict(state_dict['mapping'])
self.best_loss = state_dict['best_loss']
self.best_remote_loss = state_dict['best_remote_loss']
bittensor.logging.success( prefix = 'Reloaded model', sufix = '<blue>{}/model.torch</blue>'.format( path ))


@@ -543,6 +552,7 @@ def config ():
parser.add_argument('--neuron.blacklist_allow_non_registered', action='store_true', help='''If true, allow non-registered peers''', default=False)
parser.add_argument('--neuron.disable_blacklist', action='store_true', help='Turns off blacklisting', default=False)
parser.add_argument('--neuron.disable_priority', action='store_true', help='Turns off priority threadpool', default=False)
parser.add_argument('--neuron.num_remote_loss', type=int, help='Number of past remote loss to keep in stat.', default=20)

# Synapse Arguements
parser.add_argument('--neuron.lasthidden', action='store_false', help='To turn off last hidden synapse', default=True)
67 changes: 44 additions & 23 deletions bittensor/_neuron/text/core_server/run.py
@@ -292,17 +292,22 @@ def backward_callback(inputs_x:torch.FloatTensor, grads_dy:torch.FloatTensor, sy
for index, synapse in enumerate(synapses):
try:
if synapse.synapse_type in axon.synapse_callbacks and axon.synapse_callbacks[synapse.synapse_type] != None:
model_output, response_tensor = axon.synapse_callbacks[synapse.synapse_type](inputs_x[index], synapse)
message, model_output, response_tensor = axon.synapse_callbacks[synapse.synapse_type](inputs_x[index], synapse)
grads_dy_norm = grads_dy[index]/(grads_dy[index].sum() + 0.00001)
torch.autograd.backward (
tensors = [ response_tensor ],
grad_tensors = [ grads_dy_norm ],
retain_graph=True
)
)
# Only consider loss from causal LM next.
if synapse.synapse_type == bittensor.proto.Synapse.SynapseType.TEXT_CAUSAL_LM_NEXT:
model.remote_losses.append(model_output.loss)
model.remote_losses = model.remote_losses[-config.neuron.num_remote_loss:] if len(model.remote_losses) > config.neuron.num_remote_loss else model.remote_losses
model.backward_gradients_count += inputs_x[index].size(0)
response_tensors.append(None)
response_codes.append(bittensor.proto.ReturnCode.Success)
response_messages.append('Success')

else:
response_tensors.append(None)
response_codes.append(bittensor.proto.ReturnCode.NotImplemented)
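The new `TEXT_CAUSAL_LM_NEXT` branch above appends the loss from each remote backward call and trims the list to the most recent `num_remote_loss` entries. A simplified sketch of that rolling window (class and variable names are illustrative, not the server's):

```python
from typing import List

class RemoteLossWindow:
    """Keep only the newest `maxlen` remote losses, as the backward callback does."""
    def __init__(self, maxlen: int = 20):      # 20 mirrors the --neuron.num_remote_loss default
        self.maxlen = maxlen
        self.losses: List[float] = []

    def record(self, loss: float) -> None:
        self.losses.append(loss)
        if len(self.losses) > self.maxlen:
            self.losses = self.losses[-self.maxlen:]   # same slice-trim as the diff

window = RemoteLossWindow()
for loss in [3.1, 2.9, 2.8] * 10:              # 30 fake losses from remote backward calls
    window.record(loss)
assert len(window.losses) == window.maxlen
```

A `collections.deque(maxlen=...)` would trim implicitly; the slice form shown here just mirrors the diff.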
@@ -356,7 +361,6 @@ def backward_callback(inputs_x:torch.FloatTensor, grads_dy:torch.FloatTensor, sy

# --- Run Forever.
while True:

iteration = 0
local_data = {}
nn = subtensor.neuron_for_pubkey(wallet.hotkey.ss58_address)
@@ -366,15 +370,19 @@ def backward_callback(inputs_x:torch.FloatTensor, grads_dy:torch.FloatTensor, sy
if config.neuron.local_train:
# --- Training step.
while end_block >= current_block:
if current_block != subtensor.get_current_block():
loss, _ = model( next( dataset ).to(model.device) )
if iteration > 0 :
losses += loss
else:
losses = loss
iteration += 1
current_block = subtensor.get_current_block()
logger.info(f'local training\titeration: {iteration}\tloss: {loss}')
if current_block != subtensor.get_current_block() and axon.priority_threadpool.is_empty:
with mutex:
logger.info(f'local training\titeration: {iteration}\tstart')
loss, _ = model( next(dataset).to(model.device) )
if iteration > 0 :
losses += loss
else:
losses = loss
iteration += 1
current_block = subtensor.get_current_block()
logger.info(f'local training\titeration: {iteration}\tloss: {loss}')
else:
time.sleep(1)

if iteration != 0:
(losses/iteration).backward()
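The reworked loop above only takes a local training step when the chain has advanced a block and the axon's priority threadpool has no queued requests, and it runs the step under the mutex shared with the backward callback, which is how the "no local forward and remote forward overlap" bullet is enforced. A simplified, hypothetical sketch of that gate:

```python
import queue
import threading

mutex = threading.Lock()                    # shared with the remote backward path in the real code
work_queue: queue.Queue = queue.Queue()     # stands in for axon.priority_threadpool._work_queue

def threadpool_is_empty() -> bool:
    return work_queue.empty()               # same check the new is_empty property exposes

def maybe_local_step(current_block: int, chain_block: int, train_step) -> bool:
    """Run one local training step only when no remote work is pending and a new block arrived."""
    if current_block != chain_block and threadpool_is_empty():
        with mutex:                         # exclude remote backward while training locally
            train_step()
        return True
    return False

# Usage: skipped because the fake chain has not advanced.
assert maybe_local_step(current_block=10, chain_block=10, train_step=lambda: None) is False
```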
@@ -384,7 +392,6 @@ def backward_callback(inputs_x:torch.FloatTensor, grads_dy:torch.FloatTensor, sy
time.sleep(12)
current_block = subtensor.get_current_block()


# --- Update parameters
if (config.neuron.local_train and iteration > 0) or (config.neuron.remote_train and model.backward_gradients_count > 0):
# Custom learning rate
@@ -393,18 +400,32 @@ def backward_callback(inputs_x:torch.FloatTensor, grads_dy:torch.FloatTensor, sy
else:
optimizer.param_groups[0]['lr'] = 0.1

logger.info('Backpropagation Started')
clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
optimizer.zero_grad()
model.backward_gradients = 0
logger.info('Backpropagation Successful: Model updated')
local_data = {'local/loss': losses.detach().item() / iteration}
logger.info('Optmization Started')
with mutex:
clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
optimizer.zero_grad()
logger.info('Optimization Successful: Model updated')

if (config.neuron.local_train and iteration > 0):
local_data = {'local/loss': losses.detach().item() / iteration}

if local_data['local/loss'] < model.best_loss:
model.best_loss = local_data['local/loss']
model.save(config.neuron.full_path)
if local_data['local/loss'] < model.best_loss:
model.best_loss = local_data['local/loss']
model.save(config.neuron.full_path)

# Save it only when it gives a low average loss over a large sample size (config.neuron.num_remote_loss), default to 20.
elif (config.neuron.remote_train and len(model.remote_losses) >= config.neuron.num_remote_loss):
local_data = {'local/remote_loss': sum(model.remote_losses) / len(model.remote_losses)}

if local_data['local/remote_loss'] < model.best_remote_loss:
model.best_remote_loss = local_data['local/remote_loss']
model.save(config.neuron.full_path)

model.remote_losses = []

model.backward_gradients_count = 0

wandb_data = {
'stake': nn.stake,
'rank': nn.rank,
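The `elif` branch above implements the save rule described in its comment: checkpoint on remote signal only after a full window of remote losses has accumulated and its average beats the best seen so far, then clear the window either way. A hedged sketch of that rule with illustrative names:

```python
from typing import Callable, List, Tuple

def maybe_save_on_remote(remote_losses: List[float], best_remote_loss: float,
                         num_remote_loss: int = 20,
                         save: Callable[[], None] = lambda: None) -> Tuple[float, List[float]]:
    """Return the updated best remote loss and the (possibly cleared) loss window."""
    if len(remote_losses) < num_remote_loss:
        return best_remote_loss, remote_losses          # not enough samples yet, keep collecting
    avg = sum(remote_losses) / len(remote_losses)
    if avg < best_remote_loss:
        best_remote_loss = avg
        save()                                          # e.g. model.save(config.neuron.full_path)
    return best_remote_loss, []                         # window is cleared after each evaluation

best = float('inf')
best, window = maybe_save_on_remote([2.5] * 20, best)
assert best == 2.5 and window == []
```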
4 changes: 2 additions & 2 deletions bittensor/_neuron/text/core_validator/__init__.py
@@ -823,7 +823,7 @@ def add_args( cls, parser ):
parser.add_argument('--nucleus.dropout', type=float, help='the dropout value', default=0.2)
parser.add_argument('--nucleus.importance', type=float, help='hyperparameter for the importance loss', default=3)
parser.add_argument('--nucleus.noise_multiplier', type=float, help='Standard deviation multipler on weights', default=2 )
parser.add_argument('--nucleus.dendrite_backward', action='store_true', help='Pass backward request to the server side or not', default=False )
parser.add_argument('--nucleus.no_dendrite_backward', action='store_true', help='Pass backward request to the server side or not', default=False )
parser.add_argument('--nucleus.scaling_law_power', type=float, help='Power for modified scaling law, powered down to improve dynamic range, e.g. 3 → 6 nats for 0.5. (default value: -1, pulling from subtensor directly)', default=-1)
parser.add_argument('--nucleus.synergy_scaling_law_power', type=float, help='Power for synergy modified scaling law, powered down to improve dynamic range, e.g. 3 → 6 nats for 0.5. (default value: -1, pulling from subtensor directly)', default=-1)

@@ -962,7 +962,7 @@ def forward(
timeout=bittensor.__blocktime__
)

if not self.config.nucleus.dendrite_backward:
if self.config.nucleus.no_dendrite_backward:
query_responses = [[syn.detach().to(self.device) for syn in res] for res in query_responses]
return_ops = [ops.detach().to(self.device) for ops in return_ops]
times = [t.detach().to(self.device) for t in times]
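In the validator, `--nucleus.dendrite_backward` becomes `--nucleus.no_dendrite_backward` and the detach condition is inverted, so backward requests now reach servers unless explicitly disabled (the "default to have dendrite backward" bullet). A minimal argparse sketch of the inverted flag, outside bittensor's config machinery:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--nucleus.no_dendrite_backward', action='store_true', default=False,
                    help='If set, do NOT pass backward requests to the server side')

config = parser.parse_args([])                                   # no flag passed on the CLI
no_backward = getattr(config, 'nucleus.no_dendrite_backward')    # dotted dest needs getattr

# Default behaviour after this commit: responses stay attached and gradients flow back to servers.
assert no_backward is False
```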
4 changes: 4 additions & 0 deletions bittensor/_threadpool/priority_thread_pool_impl.py
@@ -148,6 +148,10 @@ def __init__(self, maxsize = -1, max_workers=None, thread_name_prefix='',
self._initializer = initializer
self._initargs = initargs

@property
def is_empty(self):
return self._work_queue.empty()

def submit(self, fn, *args, **kwargs):
with self._shutdown_lock:
if self._broken:
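The new `is_empty` property gives the run loop a cheap way to ask whether the priority threadpool has pending work before it takes a local training step. A standalone sketch on a plain `ThreadPoolExecutor` as a hypothetical stand-in; note that `Queue.empty()` is a point-in-time check, so this is a heuristic gate rather than a guarantee:

```python
from concurrent.futures import ThreadPoolExecutor

class InspectablePool(ThreadPoolExecutor):
    """Hypothetical stand-in exposing the same is_empty check as the patched pool."""
    @property
    def is_empty(self) -> bool:
        return self._work_queue.empty()     # no pending items in the executor's work queue

pool = InspectablePool(max_workers=2)
print(pool.is_empty)                        # True: nothing submitted yet
future = pool.submit(sum, range(10))
print(future.result())                      # 45; the queued item has been picked up
pool.shutdown()
```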
