diff --git a/references/classification/transforms.py b/references/classification/transforms.py index 5443437d29d..96236608eec 100644 --- a/references/classification/transforms.py +++ b/references/classification/transforms.py @@ -19,9 +19,9 @@ def get_mixup_cutmix(*, mixup_alpha, cutmix_alpha, num_classes, use_v2): ) if cutmix_alpha > 0: mixup_cutmix.append( - transforms_module.CutMix(alpha=mixup_alpha, num_classes=num_classes) + transforms_module.CutMix(alpha=cutmix_alpha, num_classes=num_classes) if use_v2 - else RandomCutMix(num_classes=num_classes, p=1.0, alpha=mixup_alpha) + else RandomCutMix(num_classes=num_classes, p=1.0, alpha=cutmix_alpha) ) if not mixup_cutmix: return None