Code for "Understanding and Improving Knowledge Distailltion for Quantization-Aware-Training of Large Transformer Encoders"
Proceeding, Arxiv
- This paper provides an in-depth analysis of the mechanism of knowledge distillation (KD) for attention recovery in quantized large Transformer encoders.
- Based on this analysis, we propose new sets of KD loss functions for better QAT at ultra-low bit precision (weight ternarization of Transformer encoders).
Our implementation is based on the Huawei Noah's Ark Lab TernaryBERT PyTorch code. (link)
```bash
pip install -r requirements.txt
```
First, you need task-specific fine-tuned full-precision BERT models to initialize the model for QAT. You can fine-tune a BERT-base/large pre-trained model using the Hugging Face example code at the following link.
Alternatively, you can download fine-tuned BERT-base/large models from the Google Cloud link. (SST-2 and RTE provided)
The fine-tuned BERT model files should be placed in the "models" folder under the corresponding GLUE task name (e.g., models/rte), as sketched below.
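A minimal sketch of the expected layout, assuming the standard Hugging Face checkpoint file names produced by the fine-tuning script:

```
models/
└── rte/                  # GLUE task name
    ├── config.json       # standard Hugging Face checkpoint files (assumed)
    ├── pytorch_model.bin
    └── vocab.txt
```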
This repository provides multiple KD options for Ternary QAT of BERT-base/large.
For attention map/output loss QAT, run:
```bash
bash run_kd_qat_map.sh $GPU_NUM     # map loss
bash run_kd_qat_output.sh $GPU_NUM  # output loss
```
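As a rough illustration of what the two options distill (not the repository's exact formulation; the scripts define the precise losses), here is an MSE-style sketch in PyTorch, where `student_probs`/`teacher_probs` are per-layer post-softmax attention maps and `student_outs`/`teacher_outs` are per-layer attention block outputs:

```python
import torch.nn.functional as F

def attention_map_loss(student_probs, teacher_probs):
    # Distill the post-softmax attention maps, layer by layer.
    # Each element: [batch, heads, seq_len, seq_len]
    return sum(F.mse_loss(s, t) for s, t in zip(student_probs, teacher_probs))

def attention_output_loss(student_outs, teacher_outs):
    # Distill the attention block outputs, layer by layer.
    # Each element: [batch, seq_len, hidden]
    return sum(F.mse_loss(s, t) for s, t in zip(student_outs, teacher_outs))
```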
To explore KD options for QAT with BERT-base/large over the GLUE tasks, use run_kd_qat_exploration.sh. For example, let's run attention-map loss QAT of BERT-base on the CoLA task:
```bash
task_name=cola
bert=base
map_coeff=1
output_coeff=0
bash run_kd_qat_exploration.sh $task_name $bert $map_coeff $output_coeff
```
To explore the mixing parameters for the attention-map/output losses, run run_mixing_param_sweep.sh.
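In spirit, the sweep grid-searches the two coefficients; a hypothetical Python equivalent (the coefficient grid below is made up for illustration):

```python
import itertools
import subprocess

task, bert = "cola", "base"
for map_coeff, output_coeff in itertools.product([0.0, 0.5, 1.0], repeat=2):
    # Each run mixes the two KD losses roughly as:
    #   loss = map_coeff * L_attention_map + output_coeff * L_attention_output
    subprocess.run(
        ["bash", "run_kd_qat_exploration.sh", task, bert,
         str(map_coeff), str(output_coeff)],
        check=True,
    )
```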
For the experimental notebooks, you need a QAT model file from the Training section. Note that fine-tuned full-precision model files should be placed in the models folder, and QAT model files in the output folder. For example:
```python
teacher_model_dir = "models/BERT_base/sst-2"
student_model_dir = "output/BERT_large/rte/exploration/$EXP_NAME"
```
Please set the full-precision/QAT model directory names properly in the notebook :)
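For instance, a minimal loading sketch, assuming Hugging Face-style checkpoints (the notebooks may use the repository's own modeling classes instead); the run name below is hypothetical, and teacher and student should correspond to the same task and model size:

```python
from transformers import BertForSequenceClassification

teacher_model_dir = "models/BERT_base/sst-2"                     # full-precision teacher
student_model_dir = "output/BERT_base/sst-2/exploration/my_exp"  # hypothetical QAT run

teacher = BertForSequenceClassification.from_pretrained(teacher_model_dir)
student = BertForSequenceClassification.from_pretrained(student_model_dir)
```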
Exp 1 shows how to measure the self-attention map distance between the teacher model (full-precision) and the student model (ternary quantized). This notebook produces an attention map distance plot as follows.
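A minimal sketch of one way to compute such a distance, assuming per-layer attention maps of shape [batch, heads, seq, seq] have already been extracted from both models (the notebook's exact metric may differ):

```python
import torch

@torch.no_grad()
def attention_map_distance(teacher_attn, student_attn):
    # One scalar per layer: L2 distance over the key axis, averaged
    # over batch, heads, and query positions.
    return [
        (t - s).norm(p=2, dim=-1).mean().item()
        for t, s in zip(teacher_attn, student_attn)
    ]
```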
Exp 2 provides the Hessian max eigenvalue spectra of the QAT model. This implementation is based on PyHessian and the repository of "Park et al., How Do Vision Transformers Work?, ICLR 2022".
- PyHessian: https://github.com/amirgholami/PyHessian
- how-do-vits-work: xxxnell/how-do-vits-work#12
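A minimal PyHessian usage sketch, with a toy model standing in for the QAT BERT model and a GLUE batch:

```python
import torch
import torch.nn as nn
from pyhessian import hessian  # pip install pyhessian

# Toy stand-in for the QAT model and a data batch.
model = nn.Linear(16, 2)
criterion = nn.CrossEntropyLoss()
inputs, targets = torch.randn(8, 16), torch.randint(0, 2, (8,))

hessian_comp = hessian(model, criterion, data=(inputs, targets), cuda=False)
top_eigenvalues, _ = hessian_comp.eigenvalues(top_n=1)  # max Hessian eigenvalue
```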
This experiment provides an analysis of the attention output's min-max dynamic range and of attention norms. Once you load the model file properly, you can inspect the model's attention output dynamic range per layer and conduct norm-based analysis per layer/head (see the sketch after the reference below).
Per-layer attention output min-max dynamic range (left); norm-based analysis per layer (right)
Per-head attention probability and transformed-output heat map visualization, and heat map visualization of the difference between student and teacher models (with attention-map/output loss QAT)
The attention-norm-based analysis follows "Kobayashi et al., Attention is Not Only a Weight: Analyzing Transformers with Vector Norms, EMNLP 2020". (code link)
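A minimal sketch of the norm-based quantity, assuming per-head value vectors and a per-head slice of the output projection are available (shapes are assumptions for illustration):

```python
import torch

@torch.no_grad()
def attention_norms(attn_probs, values, w_o):
    # ||alpha_ij * f(x_j)||: weight each transformed value vector by its
    # attention probability and take the norm (Kobayashi et al., EMNLP 2020).
    #   attn_probs: [batch, heads, seq, seq]
    #   values:     [batch, heads, seq, head_dim]
    #   w_o:        [heads, head_dim, hidden] (assumed per-head output projection)
    fx = torch.einsum("bhjd,hde->bhje", values, w_o)       # f(x_j) per head
    weighted = attn_probs.unsqueeze(-1) * fx.unsqueeze(2)  # alpha_ij * f(x_j)
    return weighted.norm(dim=-1)                           # [batch, heads, seq, seq]
```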
For further questions, contact me anytime ([email protected]) or kindly leave a question in the Issues tab.