[Dy2Stat] Support Mixed Precision training in @to_static #34562

Merged: 4 commits, Aug 5, 2021
11 changes: 11 additions & 0 deletions python/paddle/fluid/dygraph/amp/auto_cast.py
@@ -90,6 +90,17 @@ def _update_list(custom_white_list, custom_black_list):
return _white_list, _black_list


def _in_amp_guard():
"""
Return True if the current code block is inside an `amp_guard` context.
"""
tracer = _dygraph_tracer()
if tracer:
return tracer._enable_autocast
else:
return False


@signature_safe_contextmanager
@dygraph_only
def amp_guard(enable=True, custom_white_list=None, custom_black_list=None):
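For reference, a minimal usage sketch of the new `_in_amp_guard()` helper. It assumes the public `paddle.amp.auto_cast` API, which wraps `amp_guard` in dygraph mode and toggles the tracer flag read above:

import paddle
from paddle.fluid.dygraph.amp.auto_cast import _in_amp_guard

print(_in_amp_guard())        # False: no autocast context is active
with paddle.amp.auto_cast():  # public wrapper around amp_guard
    print(_in_amp_guard())    # True: tracer._enable_autocast is set
print(_in_amp_guard())        # False again after leaving the context
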
56 changes: 51 additions & 5 deletions python/paddle/fluid/dygraph/dygraph_to_static/partial_program.py
@@ -17,7 +17,7 @@
import six

import paddle
from paddle.fluid import framework, backward, core
from paddle.fluid import framework, backward, core, program_guard
from paddle.fluid.dygraph import layers
from paddle.fluid.dygraph.base import switch_to_static_graph
from paddle.fluid.dygraph.dygraph_to_static import logging_utils
@@ -26,6 +26,9 @@
from paddle.fluid.layers.utils import pack_sequence_as
from paddle.fluid.layers.utils import _hash_with_id
from paddle.fluid.compiler import BuildStrategy
from paddle.fluid.contrib.mixed_precision.decorator import AutoMixedPrecisionLists
from paddle.fluid.contrib.mixed_precision.fp16_utils import rewrite_program
from paddle.fluid.dygraph.amp.auto_cast import _in_amp_guard
import paddle.compat as cpt
from paddle import _C_ops

@@ -149,6 +152,9 @@ def __init__(self, main_program, inputs, outputs, parameters=None,
self._double_grads = self._get_double_grads(self._origin_main_program)
self.training = True

# For AMP training
self._amp_list = AutoMixedPrecisionLists()

@LazyInitialized
def _infer_program(self):
"""
@@ -168,6 +174,25 @@ def _train_program(self):

return train_program

@LazyInitialized
@switch_to_static_graph
def _infer_amp_program(self):
"""
Lazy initialized property of infer_amp_program.
"""
infer_amp_program = self._origin_main_program.clone()
with program_guard(infer_amp_program):
rewrite_program(infer_amp_program, self._amp_list)

return infer_amp_program

@LazyInitialized
def _train_amp_program(self):
"""
Lazy initialized property of train_amp_program.
"""
return self._append_backward_desc(self._infer_amp_program)

@LazyInitialized
def _infer_program_id(self):
return _hash_with_id(self._infer_program, self)
@@ -180,6 +205,14 @@ def _train_program_id(self):

return program_id

@LazyInitialized
def _train_amp_program_id(self):
program_id = _hash_with_id(self._train_amp_program, self)
core._set_cached_executor_build_strategy(program_id,
self._build_strategy)

return program_id

def _verify_program(self, main_program):
"""
Verify that the program parameter is initialized, prune some unused params,
@@ -241,12 +274,17 @@ def _get_double_grads(self, program):
double_grads.append(var_base)
return self._valid_vars(double_grads)

def _get_end_op_index(self):
infer_program = self._infer_amp_program if _in_amp_guard(
) else self._infer_program
return infer_program.desc.block(0).op_size()

def __call__(self, inputs):
in_vars, out_vars = self._prepare(inputs)

attrs = ('global_block', self.program.desc.block(0), 'start_op_index',
0, 'end_op_index', self._infer_program.desc.block(0).op_size(),
'is_test', not self.training, 'program_id', self.program_id)
0, 'end_op_index', self._get_end_op_index(), 'is_test',
not self.training, 'program_id', self.program_id)
_C_ops.run_program(
self._valid_vars(in_vars),
self._valid_vars(self._params),
@@ -258,11 +296,19 @@ def __call__(self, inputs):

@property
def program(self):
return self._train_program if self.training else self._infer_program
if self.training:
return self._train_amp_program if _in_amp_guard(
) else self._train_program
else:
return self._infer_program

@property
def program_id(self):
return self._train_program_id if self.training else self._infer_program_id
if self.training:
return self._train_amp_program_id if _in_amp_guard(
) else self._train_program_id
else:
return self._infer_program_id

def _prepare(self, inputs):
"""
5 changes: 5 additions & 0 deletions python/paddle/fluid/framework.py
@@ -2035,6 +2035,11 @@ def __init__(self,
del op_attrs[role_var_name]

if len(self.desc.type()) != 0:
# NOTE(Aurelius84): prog.clone() causes var.op to always be None;
# we add this to fix the problem.
for arg in self.desc.output_arg_names():
if block.has_var(arg) and block.var(arg).op is None:
block.var(arg).op = self
return
if type is None:
raise ValueError(
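
The framework.py change addresses the problem described in the NOTE: after `Program.clone()`, an output variable can lose the back-reference to the operator that produced it. A small, hypothetical reproduction (the network is invented for illustration only):

import paddle
from paddle.fluid import program_guard

paddle.enable_static()
main_prog = paddle.static.Program()
with program_guard(main_prog):
    x = paddle.static.data(name='x', shape=[None, 4], dtype='float32')
    y = paddle.static.nn.fc(x, size=2)

cloned = main_prog.clone()
out_var = cloned.block(0).var(y.name)
# Without the fix, out_var.op may be None here; with it, the Operator
# constructor re-links each output variable to its producing op.
print(out_var.op)
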
python/paddle/fluid/tests/unittests/dygraph_to_static/test_mnist.py
@@ -32,6 +32,9 @@

SEED = 2020

if paddle.fluid.is_compiled_with_cuda():
paddle.fluid.set_flags({'FLAGS_cudnn_deterministic': True})


class SimpleImgConvPool(fluid.dygraph.Layer):
def __init__(self,
@@ -48,7 +51,7 @@ def __init__(self,
conv_dilation=1,
conv_groups=1,
act=None,
use_cudnn=False,
use_cudnn=True,
param_attr=None,
bias_attr=None):
super(SimpleImgConvPool, self).__init__()
@@ -101,7 +104,6 @@ def __init__(self):
loc=0.0, scale=scale)),
act="softmax")

@paddle.jit.to_static
def forward(self, inputs, label=None):
x = self.inference(inputs)
if label is not None:
@@ -167,14 +169,14 @@ def test_mnist_declarative_cpu_vs_mkldnn(self):
dygraph_loss_cpu, dygraph_loss_mkldnn))

def train(self, to_static=False):
prog_trans = ProgramTranslator()
prog_trans.enable(to_static)

loss_data = []
with fluid.dygraph.guard(self.place):
fluid.default_main_program().random_seed = SEED
fluid.default_startup_program().random_seed = SEED
mnist = MNIST()
if to_static:
mnist = paddle.jit.to_static(mnist)
adam = AdamOptimizer(
learning_rate=0.001, parameter_list=mnist.parameters())
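
The test change above drops the `@paddle.jit.to_static` decorator on `forward` and instead converts the whole layer on demand, so the same model class can be trained in either dygraph or static mode. Both styles are interchangeable; a minimal sketch (not taken from this PR):

import paddle

class Net(paddle.nn.Layer):
    def __init__(self):
        super(Net, self).__init__()
        self.fc = paddle.nn.Linear(16, 4)

    # Style 1 would decorate forward directly with @paddle.jit.to_static.
    def forward(self, x):
        return self.fc(x)

net = Net()
net = paddle.jit.to_static(net)  # Style 2: convert on demand, as the test now does
out = net(paddle.randn([2, 16]))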

94 changes: 94 additions & 0 deletions python/paddle/fluid/tests/unittests/dygraph_to_static/test_mnist_amp.py
@@ -0,0 +1,94 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import unittest
import numpy as np
from time import time
from test_mnist import MNIST, TestMNIST, SEED
from paddle.jit import ProgramTranslator
from paddle.fluid.optimizer import AdamOptimizer

if paddle.fluid.is_compiled_with_cuda():
paddle.fluid.set_flags({'FLAGS_cudnn_deterministic': True})


class TestAMP(TestMNIST):
def train_static(self):
return self.train(to_static=True)

def train_dygraph(self):
return self.train(to_static=False)

def test_mnist_to_static(self):
dygraph_loss = self.train_dygraph()
static_loss = self.train_static()
# NOTE(Aurelius84): Static-graph AMP training uses a gray_list that dygraph
# AMP does not, so the number of inserted cast ops differs, which makes
# the loss differ slightly.
self.assertTrue(
np.allclose(
dygraph_loss, static_loss, atol=1e-3),
msg='dygraph is {}\n static_res is \n{}'.format(dygraph_loss,
static_loss))

def train(self, to_static=False):
paddle.seed(SEED)
mnist = MNIST()

if to_static:
print("Successfully to apply @to_static.")
mnist = paddle.jit.to_static(mnist)

adam = AdamOptimizer(
learning_rate=0.001, parameter_list=mnist.parameters())

scaler = paddle.amp.GradScaler(init_loss_scaling=1024)

loss_data = []
for epoch in range(self.epoch_num):
start = time()
for batch_id, data in enumerate(self.train_reader()):
dy_x_data = np.array([x[0].reshape(1, 28, 28)
for x in data]).astype('float32')
y_data = np.array(
[x[1] for x in data]).astype('int64').reshape(-1, 1)

img = paddle.to_tensor(dy_x_data)
label = paddle.to_tensor(y_data)
label.stop_gradient = True

with paddle.amp.auto_cast():
prediction, acc, avg_loss = mnist(img, label=label)

scaled = scaler.scale(avg_loss)
scaled.backward()
scaler.minimize(adam, scaled)

loss_data.append(avg_loss.numpy()[0])
# clear gradients before the next batch
mnist.clear_gradients()
if batch_id % 10 == 0:
print(
"Loss at epoch {} step {}: loss: {:}, acc: {}, cost: {}"
.format(epoch, batch_id,
avg_loss.numpy(), acc.numpy(), time() - start))
start = time()
if batch_id == 50:
break
return loss_data


if __name__ == '__main__':
unittest.main()
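
Taken together, the new test exercises the workflow this PR enables: @to_static conversion combined with dygraph AMP. A condensed, hedged sketch of that pattern (toy model; a CUDA device is assumed for real fp16 kernels):

import paddle

net = paddle.jit.to_static(paddle.nn.Linear(16, 10))  # dy2stat conversion
opt = paddle.optimizer.Adam(parameters=net.parameters())
scaler = paddle.amp.GradScaler(init_loss_scaling=1024)

x = paddle.randn([4, 16], dtype='float32')
with paddle.amp.auto_cast():       # runs the AMP-rewritten train program
    loss = paddle.mean(net(x))

scaled = scaler.scale(loss)        # scale the loss to avoid fp16 underflow
scaled.backward()
scaler.minimize(opt, scaled)       # unscale gradients and apply the update
net.clear_gradients()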