From b08b0b5086dd8de11162620bf7e7a9caa71f145f Mon Sep 17 00:00:00 2001 From: Minhua Chen Date: Tue, 4 Feb 2025 20:10:31 -0800 Subject: [PATCH] AdagradW (#3605) Summary: CounterWeightDecayMode.SQRT Differential Revision: D67625467 --- fbgemm_gpu/codegen/genscript/optimizers.py | 14 ++++++++++++-- .../split_table_batched_embeddings_ops_training.py | 1 + 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/fbgemm_gpu/codegen/genscript/optimizers.py b/fbgemm_gpu/codegen/genscript/optimizers.py index ff570b6e51..4387ff6f79 100644 --- a/fbgemm_gpu/codegen/genscript/optimizers.py +++ b/fbgemm_gpu/codegen/genscript/optimizers.py @@ -473,7 +473,11 @@ def rowwise_adagrad_with_counter() -> Dict[str, Any]: const auto iter_delta = prev_iter[idx] == 0 ? 1.0 : iter * 1.0 - prev_iter[idx]; prev_iter[idx] = iter * 1.0; const auto counter_log_rho = logf(2.0) / counter_halflife; - row_counter[idx] = 1.0 + expf(-iter_delta * counter_log_rho) * row_counter[idx]; + if (regularization_mode == 3 && weight_decay_mode == 3) { + tail_id_threshold_val = iter_delta; + } else { + row_counter[idx] = 1.0 + expf(-iter_delta * counter_log_rho) * row_counter[idx]; + } } else if (counter_halflife == 0) { // count only 1 (appear or not) row_counter[idx] = 1.0; } else { // count raw appearance without decaying @@ -552,7 +556,13 @@ def rowwise_adagrad_with_counter() -> Dict[str, Any]: exp_reg_correction = 1.0; if (regularization_mode == 3) { // counter-based regularization (regularization_mode=3) if (adjustment_enabled) { - if (weight_decay_mode == 2) { // Decoupled weight decay (weight_decay_mode=2) + if (weight_decay_mode == 3) { // AdagradW (weight_decay_mode=3) + freq = min(tail_id_threshold_val, iter*1.0 - adjustment_iter); + exp_reg_correction = 1.0 - weight_decay * learning_rate / sqrtf(iter*1.0); + freq = expf(- weight_decay * learning_rate * 2.0 * (sqrtf(iter*1.0) - sqrtf(iter*1.0 - freq + 1.0))); + adjusted_multiplier *= freq; // lazy update + exp_reg_correction *= freq; + } else if (weight_decay_mode == 2) { // Decoupled weight decay (weight_decay_mode=2) exp_reg_correction = 1.0 - freq * weight_decay * learning_rate; } else if (weight_decay_mode == 1) { // L2 regularization (coupled wd) exp_reg_correction = 1.0 - freq * weight_decay * multiplier; diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py index 70693def2c..9ad83df1b2 100644 --- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py +++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py @@ -99,6 +99,7 @@ class CounterWeightDecayMode(enum.IntEnum): NONE = 0 L2 = 1 DECOUPLE = 2 + SQRT = 3 class StepMode(enum.IntEnum):