
Bug Report: In-Place Operation Causes Gradient Error in conv1d_step Function #51

Open
WangYLon opened this issue Sep 23, 2024 · 1 comment


Issue Description:

While training my model, I encountered a runtime error during gradient computation, caused by an in-place operation in the conv1d_step function. The error message indicated that one of the variables needed for gradient computation had been modified by an in-place operation, which broke backpropagation.

Here is the detailed error message:

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [6, 4, 512]], which is output 0 of torch::autograd::CopySlices, is at version 38; expected version 36 instead.

Upon investigation, I found that the issue stems from two in-place operations within the conv1d_step function, specifically in the following lines:

conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=1))  # In-place modification
conv_state[:, -1:, :] = x  # In-place modification

These in-place operations break the gradient computation whenever conv_state is still needed for the backward pass; enabling torch.autograd.set_detect_anomaly(True) while debugging makes the offending operation easier to locate.
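For reference, here is a minimal, self-contained sketch (not taken from this repository; shapes and names are made up) that reproduces the same class of error: a tensor that autograd saved for the backward pass is modified in place afterwards, so its version counter no longer matches the saved version.

import torch

weight = torch.randn(4, 8, requires_grad=True)
conv_state = torch.zeros(2, 4, 8)

# Writing a grad-requiring value into the state turns conv_state into a
# non-leaf tensor tracked by autograd (output of CopySlices, as in the
# error message above).
conv_state[:, -1:, :] = torch.randn(2, 1, 8) * weight.sum()

# The multiplication saves conv_state for its backward pass.
y = (conv_state * weight).sum()

# The in-place update then modifies the tensor that was just saved,
# bumping its version counter.
conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=1))

# Raises: RuntimeError: one of the variables needed for gradient
# computation has been modified by an inplace operation ...
y.backward()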


Steps to Reproduce:

  1. Implement the conv1d_step function with the original code.
  2. Enable anomaly detection in PyTorch using torch.autograd.set_detect_anomaly(True) (see the snippet after this list).
  3. Run the training process and observe the runtime error during backpropagation.
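For step 2, anomaly detection can be enabled globally or as a context manager; both forms only affect how much diagnostic information accompanies the error, and the failure itself occurs regardless. A short sketch:

import torch

# Global switch: extends autograd errors with the traceback of the
# forward operation that produced the offending tensor.
torch.autograd.set_detect_anomaly(True)

# Equivalent scoped form around a single training step:
# with torch.autograd.detect_anomaly():
#     loss.backward()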

Proposed Fix:

The issue can be resolved by avoiding in-place modification of conv_state. Below is a modified version of conv1d_step that rebinds conv_state to the output of torch.roll() instead of copying the rolled values back into the original tensor with copy_().

import torch


def conv1d_step(
    x: torch.Tensor,
    conv_state: torch.Tensor,
    conv1d_weight: torch.Tensor,
    conv1d_bias: torch.Tensor = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    B: batch size
    S: sequence length
    D: feature dimension
    KS: kernel size
    Args:
        x (torch.Tensor): (B, S, D)
        conv_state (torch.Tensor): (B, KS, D)
        conv1d_weight (torch.Tensor): (KS, D)
        conv1d_bias (torch.Tensor, optional): (D,)
    Returns:
        y (torch.Tensor): (B, 1, D), the convolution output for this step
        conv_state (torch.Tensor): (B, KS, D), the updated state
    """
    assert (
        x.shape[0] == conv_state.shape[0]
    ), f"x has batch size {x.shape[0]} but conv_state has batch size {conv_state.shape[0]}"
    assert (
        x.shape[2] == conv_state.shape[2]
    ), f"x has feature dimension {x.shape[2]} but conv_state has feature dimension {conv_state.shape[2]}"
    assert x.shape[1] == 1, f"x has sequence length {x.shape[1]} but it should be 1"

    # Use non-in-place operation for rolling the state
    conv_state = torch.roll(conv_state, shifts=-1, dims=1)
    conv_state[:, -1:, :] = x  # This is now safe as conv_state is newly assigned

    y = torch.sum(conv_state * conv1d_weight, dim=1, keepdim=True)

    if conv1d_bias is not None:
        y += conv1d_bias

    return y, conv_state

This fix drops the in-place copy_() and rebinds conv_state to the new tensor returned by torch.roll(), so the tensor saved for the backward pass is never modified afterwards and gradient computation is no longer interrupted.
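To double-check the fix, here is a small usage sketch (hypothetical shapes, not from the repository) that runs a few recurrent steps with the patched function and backpropagates through them; x requires grad here, mimicking activations coming from earlier layers:

import torch

B, KS, D = 6, 4, 512  # hypothetical sizes, chosen to match the error message

conv1d_weight = torch.randn(KS, D, requires_grad=True)
conv1d_bias = torch.randn(D, requires_grad=True)
conv_state = torch.zeros(B, KS, D)

outputs = []
for _ in range(3):
    # In the real model, x comes from earlier layers and requires grad.
    x = torch.randn(B, 1, D, requires_grad=True)
    y, conv_state = conv1d_step(x, conv_state, conv1d_weight, conv1d_bias)
    outputs.append(y)

# With the out-of-place state update, backward() completes; with the
# original copy_() version, this is where the RuntimeError was raised.
loss = torch.cat(outputs, dim=1).sum()
loss.backward()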


Environment Details:

  • PyTorch Version: 2.2.2
  • CUDA Version: 12.2 (system); 12.1 (CUDA used by the PyTorch build)
  • Python Version: 3.12.3
  • Operating System: Ubuntu 20.04

Please let me know if any further details or testing are required!


Additional Context:

This error was encountered while training an xLSTM model for pedestrian trajectory prediction. The training process failed during the backward pass due to in-place operations within the convolutional component of the xLSTM block.


@IcarusWizard

Same problem and same solution here, great to see someone has already reported it.
