RMSNorm Forward Segmentation fault (core dumped) #1092

Open
faaany opened this issue Nov 18, 2024 · 7 comments
@faaany

faaany commented Nov 18, 2024

🐛 Describe the bug

When I run the example code below on XPU:

import torch 
import torch.nn as nn

from liger_kernel.transformers.rms_norm import LigerRMSNorm

class BaseRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

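    def forward(self, hidden_states):
        # Computes in the input dtype (no float32 upcast), matching casting_mode="none";
        # the final .to(input_dtype) is therefore a no-op here
        input_dtype = hidden_states.dtype
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)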

bs = 2     # batch size
sl = 128   # sequence length
hd = 512   # hidden size
device = "xpu"
dtype = torch.bfloat16
offset = 0.0
casting_mode = "none"
in_place = False

_tensor = torch.randn(bs, sl, hd, device=device, dtype=dtype)

h1 = _tensor.clone().requires_grad_(True)
h2 = _tensor.clone().requires_grad_(True)

# do: upstream gradient passed to backward()
do = torch.randn(bs, sl, hd, device=device, dtype=dtype)

# reference (llama or gemma)
ref_rms = BaseRMSNorm(hidden_size=hd).to(device).to(dtype)
ref_o = ref_rms(h1)
ref_o.backward(do, retain_graph=True)

# triton
triton_rms = (
    LigerRMSNorm(
        hidden_size=hd, offset=offset, casting_mode=casting_mode, in_place=in_place
    )
    .to(device)
    .to(dtype)
)
triton_o = triton_rms(h2)
triton_o.backward(do, retain_graph=True)

This crashes with a segmentation fault. Here is the gdb output:

Thread 1 "python" received signal SIGSEGV, Segmentation fault.
mlir::triton::intel::convertFp32ToBf16 (loc=loc@entry=..., rewriter=..., v=..., rounding=rounding@entry=mlir::triton::RoundingMode::RTNE) at ../../../third_party/intel/lib/TritonIntelGPUToLLVM/BF16Casts.cpp:100
100     ../../../third_party/intel/lib/TritonIntelGPUToLLVM/BF16Casts.cpp: No such file or directory.
(gdb) tb
Temporary breakpoint 1 at 0x7ffe61ee5016: file /root/.triton/llvm/llvm-ce80c80d-centos-x64/include/mlir/IR/Operation.h, line 234.
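The crash site, convertFp32ToBf16 with RoundingMode::RTNE, is the lowering of an fp32-to-bf16 truncation using round-to-nearest-even. For reference, here is a minimal pure-Python sketch of what that conversion computes numerically. It is not Triton's implementation; fp32_to_bf16_bits and bf16_bits_to_fp32 are hypothetical helper names, and the NaN handling (setting a quiet-NaN mantissa bit) is an assumption of this sketch.

```python
import math
import struct


def fp32_to_bf16_bits(x: float) -> int:
    """Round x (which must fit in float32) to bfloat16 with round-to-nearest-even.

    Returns the 16-bit bfloat16 bit pattern as an int.
    """
    # Reinterpret the float32 encoding of x as an unsigned 32-bit integer
    (bits,) = struct.unpack("<I", struct.pack("<f", x))
    if math.isnan(x):
        # Keep sign/exponent and set a mantissa bit so the result stays a (quiet) NaN
        return (bits >> 16) | 0x0040
    # Add 0x7FFF plus the lowest bit that survives, so exact ties round to even
    lsb = (bits >> 16) & 1
    bits += 0x7FFF + lsb
    # Drop the low 16 bits; what remains is the bfloat16 encoding
    return (bits >> 16) & 0xFFFF


def bf16_bits_to_fp32(b: int) -> float:
    """Widen a bfloat16 bit pattern back to a float (exact: append 16 zero bits)."""
    (x,) = struct.unpack("<f", struct.pack("<I", (b & 0xFFFF) << 16))
    return x


if __name__ == "__main__":
    for value in (1.0, 1.0 + 2**-8, 1.0 + 3 * 2**-8, 0.1):
        b = fp32_to_bf16_bits(value)
        print(f"{value!r:>22} -> 0x{b:04X} -> {bf16_bits_to_fp32(b)!r}")
```

Adding 0x7FFF plus the lowest kept bit before dropping the low 16 bits is the usual bit trick for round-to-nearest-even: exact ties such as 1 + 2**-8 round to the neighbour with an even mantissa (here 1.0).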

Versions

Collecting environment information...
PyTorch version: 2.5.0a0+gite84e33f
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.4 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.35

Python version: 3.10.9 | packaged by conda-forge | (main, Feb 2 2023, 20:20:04) [GCC 11.3.0] (64-bit runtime)
Python platform: Linux-5.15.0-125-generic-x86_64-with-glibc2.35
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 52 bits physical, 57 bits virtual
Byte Order: Little Endian
CPU(s): 224
On-line CPU(s) list: 0-223
Vendor ID: GenuineIntel
Model name: Intel(R) Xeon(R) Platinum 8480+
CPU family: 6
Model: 143
Thread(s) per core: 2
Core(s) per socket: 56
Socket(s): 2
Stepping: 6
CPU max MHz: 3800.0000
CPU min MHz: 800.0000
BogoMIPS: 4000.00
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb cat_l3 cat_l2 cdp_l3 invpcid_single intel_ppin cdp_l2 ssbd mba ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb intel_pt avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local split_lock_detect avx_vnni avx512_bf16 wbnoinvd dtherm ida arat pln pts hwp hwp_act_window hwp_epp hwp_pkg_req avx512vbmi umip pku ospke waitpkg avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg tme avx512_vpopcntdq la57 rdpid bus_lock_detect cldemote movdiri movdir64b enqcmd fsrm md_clear serialize tsxldtrk pconfig arch_lbr amx_bf16 avx512_fp16 amx_tile amx_int8 flush_l1d arch_capabilities
Virtualization: VT-x
L1d cache: 5.3 MiB (112 instances)
L1i cache: 3.5 MiB (112 instances)
L2 cache: 224 MiB (112 instances)
L3 cache: 210 MiB (2 instances)
NUMA node(s): 2
NUMA node0 CPU(s): 0-55,112-167
NUMA node1 CPU(s): 56-111,168-223
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Enhanced / Automatic IBRS; IBPB conditional; RSB filling; PBRSB-eIBRS SW sequence; BHI BHI_DIS_S
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected

Versions of relevant libraries:
[pip3] clip-anytorch==2.6.0
[pip3] dctorch==0.1.2
[pip3] intel_extension_for_pytorch==2.5.10+git9d489a8
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.26.4
[pip3] nvidia-cublas-cu12==12.1.3.1
[pip3] nvidia-cuda-cupti-cu12==12.1.105
[pip3] nvidia-cuda-nvrtc-cu12==12.1.105
[pip3] nvidia-cuda-runtime-cu12==12.1.105
[pip3] nvidia-cudnn-cu12==9.1.0.70
[pip3] nvidia-cufft-cu12==11.0.2.54
[pip3] nvidia-curand-cu12==10.3.2.106
[pip3] nvidia-cusolver-cu12==11.4.5.107
[pip3] nvidia-cusparse-cu12==12.1.0.106
[pip3] nvidia-nccl-cu12==2.20.5
[pip3] nvidia-nvjitlink-cu12==12.6.77
[pip3] nvidia-nvtx-cu12==12.1.105
[pip3] pytorch-triton-xpu==3.1.0+91b14bf559
[pip3] torch==2.5.0a0+gite84e33f
[pip3] torchaudio==2.5.0a0+56bc006
[pip3] torchdata==0.9.0
[pip3] torchdiffeq==0.2.4
[pip3] torchpippy==0.2.0
[pip3] torchsde==0.2.6
[pip3] torchvision==0.20.0a0+8e8a208
[pip3] triton==3.1.0
[conda] clip-anytorch 2.6.0 pypi_0 pypi
[conda] dctorch 0.1.2 pypi_0 pypi
[conda] intel-extension-for-pytorch 2.5.10+git9d489a8 pypi_0 pypi
[conda] numpy 1.26.4 pypi_0 pypi
[conda] nvidia-cublas-cu12 12.1.3.1 pypi_0 pypi
[conda] nvidia-cuda-cupti-cu12 12.1.105 pypi_0 pypi
[conda] nvidia-cuda-nvrtc-cu12 12.1.105 pypi_0 pypi
[conda] nvidia-cuda-runtime-cu12 12.1.105 pypi_0 pypi
[conda] nvidia-cudnn-cu12 9.1.0.70 pypi_0 pypi
[conda] nvidia-cufft-cu12 11.0.2.54 pypi_0 pypi
[conda] nvidia-curand-cu12 10.3.2.106 pypi_0 pypi
[conda] nvidia-cusolver-cu12 11.4.5.107 pypi_0 pypi
[conda] nvidia-cusparse-cu12 12.1.0.106 pypi_0 pypi
[conda] nvidia-nccl-cu12 2.20.5 pypi_0 pypi
[conda] nvidia-nvjitlink-cu12 12.6.77 pypi_0 pypi
[conda] nvidia-nvtx-cu12 12.1.105 pypi_0 pypi
[conda] pytorch-triton-xpu 3.1.0+91b14bf559 pypi_0 pypi
[conda] torch 2.5.0a0+gite84e33f pypi_0 pypi
[conda] torchaudio 2.5.0a0+56bc006 pypi_0 pypi
[conda] torchdata 0.9.0 pypi_0 pypi
[conda] torchdiffeq 0.2.4 pypi_0 pypi
[conda] torchpippy 0.2.0 pypi_0 pypi
[conda] torchsde 0.2.6 pypi_0 pypi
[conda] torchvision 0.20.0a0+8e8a208 pypi_0 pypi
[conda] triton 3.1.0 pypi_0 pypi

@Stonepia
Contributor

From your environment, it looks like you have two Triton packages installed, and they conflict:

[pip3] pytorch-triton-xpu==3.1.0+91b14bf559
[pip3] triton==3.1.0

Please try the following:

pip uninstall triton

Only pytorch-triton-xpu is needed; the triton package is not.

Stonepia self-assigned this Nov 19, 2024
@faaany
Author

faaany commented Nov 21, 2024

After several attempts, I found that I need to uninstall both triton and pytorch-triton-xpu from my environment, and then install pytorch-triton-xpu alone.

If I do this instead:

pip install triton
pip install --pre pytorch-triton-xpu==3.1.0+91b14bf559 --index-url https://download.pytorch.org/whl/nightly/xpu
pip uninstall triton

it does not work.
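One plausible explanation, assuming both distributions install into the same top-level triton/ directory and that pip uninstall removes every path listed in the removed distribution's RECORD without checking whether another distribution still uses it: uninstalling triton last also deletes the shared files that pytorch-triton-xpu had just written. The toy model below replays both orders with hypothetical file names (not the real wheel contents); run and missing_files are illustrative helpers.

```python
# Hypothetical file lists: both distributions ship the same top-level `triton/`
# package, each with some files of its own. These are NOT the real wheel contents.
RECORDS = {
    "triton": {"triton/__init__.py", "triton/compiler.py", "triton/backends/nvidia.py"},
    "pytorch-triton-xpu": {"triton/__init__.py", "triton/compiler.py", "triton/backends/intel.py"},
}


def run(steps):
    """Replay (action, distribution) steps against a simulated site-packages.

    Returns the set of file paths present afterwards.
    """
    site = set()
    for action, dist in steps:
        if action == "install":
            # Installing writes (or overwrites) every file in the distribution's RECORD
            site |= RECORDS[dist]
        elif action == "uninstall":
            # Modelled pip behaviour: delete every path in RECORD, even if another
            # installed distribution also provides it
            site -= RECORDS[dist]
        else:
            raise ValueError(f"unknown action: {action}")
    return site


def missing_files(site, dist):
    """Files that `dist` expects but that are no longer present."""
    return sorted(RECORDS[dist] - site)


if __name__ == "__main__":
    failing = [("install", "triton"), ("install", "pytorch-triton-xpu"), ("uninstall", "triton")]
    # Start from the same state, then remove both and reinstall only pytorch-triton-xpu
    working = failing + [("uninstall", "pytorch-triton-xpu"), ("install", "pytorch-triton-xpu")]
    print("failing order, missing:", missing_files(run(failing), "pytorch-triton-xpu"))
    print("working order, missing:", missing_files(run(working), "pytorch-triton-xpu"))
```

Under this model, the only safe sequence is the one described above: remove both distributions, then install pytorch-triton-xpu on its own.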

@faaany
Author

faaany commented Nov 21, 2024

However, even with Triton installed correctly, I still get "Segmentation fault (core dumped)".

@Stonepia
Contributor

Stonepia commented Nov 21, 2024

I investigated further, and this appears to be an issue in Triton. Here is a reduced version of the kernel:

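import torch
import triton
import triton.language as tl
from triton.language.extra.libdevice import rsqrt  # lowers to __imf_rsqrtf on XPU (see IR below)

# Casting-mode sentinel; mirrors Liger-Kernel's _CASTING_MODE_NONE
_CASTING_MODE_NONE = tl.constexpr(-1)


@triton.jit
def _rms_norm_forward_kernel(
    Y_ptr,
    Y_row_stride,
    X_ptr,
    X_row_stride,
    W_ptr,
    W_row_stride,
    RSTD_ptr,
    RSTD_row_stride,
    n_cols,
    eps,
    offset,
    casting_mode: tl.constexpr,  # constexpr so the `if` blocks can be optimized out
    BLOCK_SIZE: tl.constexpr,
):
    # Reduced kernel: only the rstd computation and its store are kept
    row_idx = tl.program_id(0)
    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < n_cols

    Y_ptr += row_idx * Y_row_stride
    X_ptr += row_idx * X_row_stride
    RSTD_ptr += row_idx * RSTD_row_stride

    X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0)
    X_row_dtype = X_row.dtype
    W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0)

    if casting_mode == _CASTING_MODE_NONE:
        eps = eps.to(X_row_dtype)
        offset = offset.to(X_row_dtype)

    mean_square = tl.sum(X_row * X_row, axis=0) / n_cols
    rstd = rsqrt(mean_square + eps)

    # Crashes here: rstd is fp32 but RSTD_ptr points to bf16, so the store needs
    # an fp32 -> bf16 truncation, which segfaults during lowering to LLVM
    tl.store(RSTD_ptr, rstd)


# Host-side setup matching the first reproducer (bs=2, sl=128, hd=512, bf16, casting_mode="none")
n_rows, n_cols = 2 * 128, 512
eps, offset = 1e-6, 0.0
casting_mode = _CASTING_MODE_NONE.value
BLOCK_SIZE = triton.next_power_of_2(n_cols)  # 512
num_warps = 4

X = torch.randn(n_rows, n_cols, device="xpu", dtype=torch.bfloat16)
W = torch.ones(n_cols, device="xpu", dtype=torch.bfloat16)
Y = torch.empty_like(X)
# With casting_mode "none", RSTD keeps the input dtype (bf16), as in the IR dump
RSTD = torch.empty(n_rows, device="xpu", dtype=torch.bfloat16)

_rms_norm_forward_kernel[(n_rows,)](
    Y,
    Y.stride(0),
    X,
    X.stride(0),
    W,
    W.stride(0),
    RSTD,
    RSTD.stride(0),
    n_cols,
    eps,
    offset,
    casting_mode,
    BLOCK_SIZE=BLOCK_SIZE,
    num_warps=num_warps,
)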
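As a host-side cross-check of what the kernel should store, here is a pure-Python sketch of the per-row math, following BaseRMSNorm from the first reproducer: rstd = 1 / sqrt(mean(x^2) + eps). The reduced kernel above stores only rstd; the full-row formula y = x * rstd * (offset + w) is an assumption based on the kernel's offset and W_ptr parameters (with offset = 0.0 it matches BaseRMSNorm). The function names are hypothetical, and Python floats are double precision, so bf16 rounding is not reproduced.

```python
import math


def rms_norm_rstd(row, eps=1e-6):
    """Reciprocal root-mean-square of one row: 1 / sqrt(mean(x^2) + eps)."""
    mean_square = sum(x * x for x in row) / len(row)
    return 1.0 / math.sqrt(mean_square + eps)


def rms_norm_row(row, weight, eps=1e-6, offset=0.0):
    """Normalize one row: y = x * rstd * (offset + w); offset=0.0 matches BaseRMSNorm."""
    if len(row) != len(weight):
        raise ValueError("row and weight must have the same length")
    rstd = rms_norm_rstd(row, eps)
    return [x * rstd * (offset + w) for x, w in zip(row, weight)]


if __name__ == "__main__":
    row = [2.0, -2.0, 2.0, -2.0]
    print(rms_norm_rstd(row, eps=0.0))
    print(rms_norm_row(row, [1.0, 1.0, 3.0, 3.0], eps=0.0))
```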

It segfaults while converting to LLVM. Here is the IR dumped right before the ConvertTritonIntelGPUToLLVM pass:

// -----// IR Dump Before ConvertTritonIntelGPUToLLVM (convert-triton-intel-gpu-to-llvm) ('builtin.module' operation) //----- //
#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
#loc = loc("/home/pt-gpu/4T-4652/tongsu/torch-xpu-test/debug.py":23:0)
#loc1 = loc(unknown)
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.shared = 16 : i32, triton_gpu.target = "xpu", "triton_gpu.threads-per-warp" = 32 : i32, triton_intel_gpu.min_sg_size = 16 : i32, triton_intel_gpu.support_bf16_conversion, triton_intel_gpu.support_dpas, triton_intel_gpu.support_sg_2d_block} {
  tt.func public @_rms_norm_forward_kernel(%arg0: !tt.ptr<bf16> {tt.divisibility = 16 : i32} loc("/home/pt-gpu/4T-4652/tongsu/torch-xpu-test/debug.py":23:0), %arg1: i32 {tt.divisibility = 16 : i32} loc("/home/pt-gpu/4T-4652/tongsu/torch-xpu-test/debug.py":23:0), %arg2: !tt.ptr<bf16> {tt.divisibility = 16 : i32} loc("/home/pt-gpu/4T-4652/tongsu/torch-xpu-test/debug.py":23:0), %arg3: i32 {tt.divisibility = 16 : i32} loc("/home/pt-gpu/4T-4652/tongsu/torch-xpu-test/debug.py":23:0), %arg4: !tt.ptr<bf16> {tt.divisibility = 16 : i32} loc("/home/pt-gpu/4T-4652/tongsu/torch-xpu-test/debug.py":23:0), %arg5: !tt.ptr<bf16> {tt.divisibility = 16 : i32} loc("/home/pt-gpu/4T-4652/tongsu/torch-xpu-test/debug.py":23:0), %arg6: i32 {tt.divisibility = 16 : i32} loc("/home/pt-gpu/4T-4652/tongsu/torch-xpu-test/debug.py":23:0), %arg7: f32 loc("/home/pt-gpu/4T-4652/tongsu/torch-xpu-test/debug.py":23:0), %arg8: f32 loc("/home/pt-gpu/4T-4652/tongsu/torch-xpu-test/debug.py":23:0), %arg9: !llvm.ptr<3> loc("/home/pt-gpu/4T-4652/tongsu/torch-xpu-test/debug.py":23:0)) attributes {noinline = false} {
    %cst = arith.constant dense<0.000000e+00> : tensor<512xbf16, #blocked> loc(#loc1)
    %0 = tt.get_program_id x : i32 loc(#loc2)
    %1 = tt.make_range {end = 512 : i32, start = 0 : i32} : tensor<512xi32, #blocked> loc(#loc3)
    %2 = tt.splat %arg6 : i32 -> tensor<512xi32, #blocked> loc(#loc4)
    %3 = arith.cmpi slt, %1, %2 : tensor<512xi32, #blocked> loc(#loc4)
    %4 = arith.muli %0, %arg3 : i32 loc(#loc5)
    %5 = tt.addptr %arg2, %4 : !tt.ptr<bf16>, i32 loc(#loc6)
    %6 = tt.addptr %arg5, %0 : !tt.ptr<bf16>, i32 loc(#loc7)
    %7 = tt.splat %5 : !tt.ptr<bf16> -> tensor<512x!tt.ptr<bf16>, #blocked> loc(#loc8)
    %8 = tt.addptr %7, %1 : tensor<512x!tt.ptr<bf16>, #blocked>, tensor<512xi32, #blocked> loc(#loc8)
    %9 = tt.load %8, %3, %cst : tensor<512x!tt.ptr<bf16>, #blocked> loc(#loc9)
    %10 = arith.truncf %arg7 : f32 to bf16 loc(#loc10)
    %11 = arith.mulf %9, %9 : tensor<512xbf16, #blocked> loc(#loc11)
    %12 = arith.extf %11 : tensor<512xbf16, #blocked> to tensor<512xf32, #blocked> loc(#loc21)
    %13 = "tt.reduce"(%12) <{axis = 0 : i32}> ({
    ^bb0(%arg10: f32 loc(unknown), %arg11: f32 loc(unknown)):
      %20 = arith.addf %arg10, %arg11 : f32 loc(#loc24)
      tt.reduce.return %20 : f32 loc(#loc22)
    }) {allocation.offset = 0 : i32} : (tensor<512xf32, #blocked>) -> f32 loc(#loc22)
    %14 = arith.sitofp %arg6 : i32 to f32 loc(#loc16)
    %15 = arith.divf %13, %14 : f32 loc(#loc16)
    %16 = arith.extf %10 : bf16 to f32 loc(#loc17)
    %17 = arith.addf %15, %16 : f32 loc(#loc17)
    %18 = tt.extern_elementwise %17 {libname = "", libpath = "", pure = true, symbol = "__imf_rsqrtf"} : (f32) -> f32 loc(#loc18)
    %19 = arith.truncf %18 : f32 to bf16 loc(#loc19)
    tt.store %6, %19 : !tt.ptr<bf16> loc(#loc19)
    tt.return loc(#loc20)
  } loc(#loc)
} loc(#loc)


Thread 1 "python" received signal SIGSEGV, Segmentation fault.
mlir::triton::intel::convertFp32ToBf16 (loc=loc@entry=..., rewriter=..., v=..., rounding=rounding@entry=mlir::triton::RoundingMode::RTNE) at ../../../third_party/intel/lib/TritonIntelGPUToLLVM/BF16Casts.cpp:100
100     ../../../third_party/intel/lib/TritonIntelGPUToLLVM/BF16Casts.cpp: No such file or directory.
Fails with: pytorch-triton-xpu 3.1.0+91b14bf559
Passes with Triton commit: 1764e542bbefe6e2cfafad882c61a5a7f7abd1cc

This issue is already fixed on Triton's latest branch.
@whitneywhtsang Could I ask if you know which commit fixed the issue?
@guangyey @etaf Are you aware of this Triton change? Do we plan to update the Triton commit pin?

@Stonepia
Contributor

Attaching the reproducer for anyone who may be interested in the future:
repro.zip

@guangyey
Contributor

@Stonepia We plan to update the Triton commit pin soon; see pytorch/pytorch#137886.

@whitneywhtsang

@whitneywhtsang Could I ask if you know which commit fixed the issue?

It is fixed by intel/intel-xpu-backend-for-triton@9b5b553.
