
MindSpore preserves the gradients of layers that are not involved in forward propagation #309

Open
PhyllisJi opened this issue Oct 25, 2024 · 3 comments


@PhyllisJi

Environment

Hardware Environment (Ascend/GPU/CPU): GPU

Software Environment:

  • MindSpore version (source or binary): 2.2.14
  • Python version (e.g., Python 3.7.5): 3.8
  • OS platform and distribution (e.g., Linux Ubuntu 16.04): Ubuntu
  • GCC/Compiler version (if compiled from source):

Describe the current behavior

In MindSpore, layers defined in the __init__ method automatically register their parameters, even if those layers are never invoked in the construct method. As a result, the parameters of these unused layers still receive gradients and are included in the optimization process, leading to potential performance issues and unintended behavior.

Parameters of unused layers consume additional computation resources by participating in gradient calculations, even when these layers do not contribute to the model's forward pass.

Unused parameters occupy memory throughout the training process. For complex models with multiple unused branches or conditional logic, this can significantly impact memory usage and training efficiency.
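A minimal sketch of the behavior described above (TinyNet and its layer names are hypothetical, for illustration only; the full reproduction script is under "Steps to reproduce the issue" below):

import mindspore.nn as nn

class TinyNet(nn.Cell):
    def __init__(self):
        super().__init__()
        self.used = nn.Dense(4, 2)    # invoked in construct
        self.unused = nn.Dense(4, 2)  # defined but never invoked

    def construct(self, x):
        return self.used(x)

net = TinyNet()
# Both layers' parameters are registered and marked trainable, so the
# parameters of `unused` are also handed to value_and_grad / the optimizer.
print([p.name for p in net.trainable_params()])
# expected to include 'unused.weight' and 'unused.bias'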

Describe the expected behavior

  1. Lazy Parameter Registration:
    Register parameters only when layers are actually used within construct. This would align MindSpore with other popular frameworks such as TensorFlow and PyTorch, improving both performance and usability.

  2. Warning for Unused Layers:
    Introduce runtime warnings for layers that are defined but never invoked during the forward pass, helping developers identify potential design issues early on (a user-side heuristic is sketched after this list).
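Until something like point 2 exists in the framework, a rough user-side heuristic is sketched below. warn_unused_parameters is a hypothetical helper, not a MindSpore API: after one training step it flags parameters whose gradients are identically zero, which in this reproduction appears to be exactly the parameters of layers never invoked in construct (a used layer could in principle also produce zero gradients, so this is only a heuristic).

import numpy as np

def warn_unused_parameters(gradients):
    # gradients: dict of parameter name -> numpy gradient array, e.g. the dict
    # returned by mindspore_train() in the reproduction script below.
    for name, grad in gradients.items():
        if not np.any(grad):
            print(f"Warning: parameter '{name}' received an all-zero gradient; "
                  f"its layer may never be invoked in construct().")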

Steps to reproduce the issue

import mindspore
import paddle
import paddle.nn as nn
import numpy as np


class Model_UpigUILcfoLJTSACr2VkArnWgxwvqcCh(mindspore.nn.Cell):
    def __init__(self):
        super(Model_UpigUILcfoLJTSACr2VkArnWgxwvqcCh, self).__init__()
        self.conv1_mutated = mindspore.nn.Conv2dTranspose(in_channels=3, out_channels=3, kernel_size=(7, 7), stride=(2, 2), pad_mode="pad", padding=(0, 0, 0, 0), output_padding=(0, 0), dilation=(1, 1), group=1, has_bias=True)
        self.pool1_mutated = mindspore.nn.MaxPool2d(kernel_size=(8, 8), stride=(2, 2), pad_mode="pad", padding=(0, 0), dilation=(1, 1), return_indices=False, ceil_mode=False)
        self.conv2_mutated = mindspore.nn.Conv2dTranspose(in_channels=3, out_channels=2, kernel_size=(1, 1), stride=(1, 1), pad_mode="pad", padding=(0, 0, 0, 0), output_padding=(0, 0), dilation=(1, 1), group=1, has_bias=True)
        self.relu1_mutated = mindspore.nn.HSwish()
        self.conv3_mutated = mindspore.nn.Conv2dTranspose(in_channels=2, out_channels=3, kernel_size=(1, 1), stride=(1, 1), pad_mode="pad", padding=(0, 0, 0, 0), output_padding=(0, 0), dilation=(1, 1), group=1, has_bias=True)
        self.relu2_mutated = mindspore.nn.HSwish()
        self.conv4_mutated = mindspore.nn.Conv2dTranspose(in_channels=2, out_channels=3, kernel_size=(3, 3), stride=(1, 1), pad_mode="pad", padding=(1, 1, 1, 1), output_padding=(0, 0), dilation=(1, 1), group=1, has_bias=True)
        self.tail_flatten = mindspore.nn.Flatten(start_dim=1, end_dim=-1)
        self.tail_fc = mindspore.nn.Dense(in_channels=149187, out_channels=1000)

    def construct(self, x):
        x = self.conv1_mutated(x)
        x = self.pool1_mutated(x)
        x = self.conv2_mutated(x)
        x = self.relu1_mutated(x)
        # Note: conv3_mutated and relu2_mutated are defined in __init__ but never invoked here.
        y2 = self.conv4_mutated(x)
        tail_flatten_output = self.tail_flatten(y2)
        tail_fc_output = self.tail_fc(tail_flatten_output)
        return tail_fc_output


class Model_1729408030(nn.Layer):
    def __init__(self):
        super(Model_1729408030, self).__init__()
        self.conv1_mutated = paddle.nn.Conv2DTranspose(in_channels=3, out_channels=3, kernel_size=[7, 7], stride=[2, 2], padding=[0, 0], output_padding=[0, 0], dilation=[1, 1], groups=1, bias_attr=None)
        self.pool1_mutated = paddle.nn.MaxPool2D(kernel_size=[8, 8], stride=[2, 2], padding=[0, 0], ceil_mode=False)
        self.conv2_mutated = paddle.nn.Conv2DTranspose(in_channels=3, out_channels=2, kernel_size=[1, 1], stride=[1, 1], padding=[0, 0], output_padding=[0, 0], dilation=[1, 1], groups=1, bias_attr=None)
        self.relu1_mutated = paddle.nn.Hardswish()
        self.conv3_mutated = paddle.nn.Conv2DTranspose(in_channels=2, out_channels=3, kernel_size=[1, 1], stride=[1, 1], padding=[0, 0], output_padding=[0, 0], dilation=[1, 1], groups=1, bias_attr=None)
        self.relu2_mutated = paddle.nn.Hardswish()
        self.conv4_mutated = paddle.nn.Conv2DTranspose(in_channels=2, out_channels=3, kernel_size=[3, 3], stride=[1, 1], padding=[1, 1], output_padding=[0, 0], dilation=[1, 1], groups=1, bias_attr=None)
        self.tail_flatten = paddle.nn.Flatten()
        self.tail_fc = paddle.nn.Linear(in_features=149187, out_features=1000)

    def forward(self, x):
        x = self.conv1_mutated(x)
        x = self.pool1_mutated(x)
        x = self.conv2_mutated(x)
        x = self.relu1_mutated(x)
        # Note: conv3_mutated and relu2_mutated are likewise never invoked here.
        y2 = self.conv4_mutated(x)
        tail_flatten_output = self.tail_flatten(y2)
        tail_fc_output = self.tail_fc(tail_flatten_output)
        return tail_fc_output


def paddle_train(model, inp, label, is_gpu, dtype):
    import paddle

    gpu_str = 'gpu:1' if is_gpu else 'cpu'
    paddle.set_device(f'{gpu_str}')
    if dtype == "float32":
        my_input = paddle.to_tensor(inp).astype('float32')
    elif dtype == "float64":
        my_input = paddle.to_tensor(inp).astype('float64')
    else:
        my_input = paddle.to_tensor(inp).astype('float16')
    output = model(my_input)
    target = paddle.to_tensor(label, dtype='int64')
    loss = paddle.nn.CrossEntropyLoss()(output, target)
    loss.backward()
    gradients = {name: param.grad.to('cpu').numpy() for name, param in model.named_parameters() if param.grad is not None }
    # Paddle's Linear stores its weight as (in_features, out_features) while MindSpore's
    # Dense stores (out_channels, in_channels), so transpose 2-D gradients for comparison.
    for key in gradients.keys():
        if len(gradients[key].shape) == 2:
            gradients[key] = gradients[key].T
    return gradients, loss.item(), output.detach().to('cpu').numpy()

def mindspore_train(model, inp, label, is_gpu, dtype):
    import mindspore

    mindspore.context.set_context(device_target=('GPU' if is_gpu else 'CPU'))
    if dtype == "float32":
        my_input = mindspore.Tensor(inp.astype(np.float32))
    elif dtype == "float64":
        my_input = mindspore.Tensor(inp.astype(np.float64))
    else:
        my_input = mindspore.Tensor(inp.astype(np.float16))

    def forward_fn(label):
        ms_output = model(my_input)
        label = label.astype(np.int32)
        ms_targets = mindspore.Tensor(label)
        loss = mindspore.nn.CrossEntropyLoss(reduction='mean')(ms_output, ms_targets)
        return loss, ms_output

    (loss, output), ms_gradients = mindspore.value_and_grad(forward_fn, None, model.trainable_params(), has_aux=True)(
        label)
    gradients = {}
    for var, gradient in zip(model.trainable_params(), ms_gradients):
        gradients.setdefault(var.name, gradient.numpy())
    return gradients, loss.numpy().item(), output.numpy()

shape = [1, 3, 224, 224]
inp = np.random.random(shape).astype(np.float32)
label = np.random.random([1])
ms_model = Model_UpigUILcfoLJTSACr2VkArnWgxwvqcCh()
m_g, m_l, m_o = mindspore_train(ms_model, inp, label, False, "float32")
print(m_g.keys())

pd_model = Model_1729408030()
p_g, p_l, p_o = paddle_train(pd_model, inp, label, False, "float32")
print(p_g.keys())

Related log / screenshot

MindSpore:
dict_keys(['conv1_mutated.weight', 'conv1_mutated.bias', 'conv2_mutated.weight', 'conv2_mutated.bias', 'conv3_mutated.weight', 'conv3_mutated.bias', 'conv4_mutated.weight', 'conv4_mutated.bias', 'tail_fc.weight', 'tail_fc.bias'])
PaddlePaddle:
dict_keys(['conv1_mutated.weight', 'conv1_mutated.bias', 'conv2_mutated.weight', 'conv2_mutated.bias', 'conv4_mutated.weight', 'conv4_mutated.bias', 'tail_fc.weight', 'tail_fc.bias'])

Special notes for this issue

@zhouyifeng888

Hello, you are correct. In MindSpore, by default, layers defined in __init__ will still receive gradient updates even if they do not participate in forward propagation, because the parameters of each layer have requires_grad set to True by default. If you want to prevent gradient updates for layers that do not participate in forward propagation, you can try using the nn.Cell.set_grad(requires_grad=False) method to see if it achieves your expected result.
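A hedged sketch of how that suggestion might be applied to the reproduction model above. Whether Cell.set_grad alone removes the parameters from trainable_params() may depend on the MindSpore version, so an alternative that directly affects trainable_params() (freezing the parameters via requires_grad) is shown as well:

ms_model = Model_UpigUILcfoLJTSACr2VkArnWgxwvqcCh()

# Option 1: the suggested set_grad call on the sub-cells that never run in construct.
ms_model.conv3_mutated.set_grad(requires_grad=False)
ms_model.relu2_mutated.set_grad(requires_grad=False)

# Option 2: freeze the unused layer's parameters so trainable_params()
# (and therefore value_and_grad in the script above) skips them.
for p in ms_model.conv3_mutated.get_parameters():
    p.requires_grad = False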

@daheyinyin

Thanks for your suggestions about using Lazy Parameter Registration and Warning for Unused Layers during the Forward Pass. These are two very useful suggestions.

@PhyllisJi
Author

PhyllisJi commented Oct 28, 2024

Thanks for your suggestions about using Lazy Parameter Registration and Warning for Unused Layers during the Forward Pass. These are two very useful suggestions.

Thank you for your reply! Looking forward to a better MindSpore~
