diff --git a/tests/kernels/test_attention_selector.py b/tests/kernels/test_attention_selector.py
index 04c56a1e6a3bd..40516320b0a46 100644
--- a/tests/kernels/test_attention_selector.py
+++ b/tests/kernels/test_attention_selector.py
@@ -4,7 +4,7 @@
 import torch
 
 from tests.kernels.utils import override_backend_env_variable
-from vllm.attention.selector import which_attn_to_use
+from vllm.attention.selector import get_attn_backend
 from vllm.platforms.cpu import CpuPlatform
 from vllm.platforms.cuda import CudaPlatform
 from vllm.platforms.openvino import OpenVinoPlatform
@@ -16,78 +16,75 @@
     "name", ["TORCH_SDPA", "ROCM_FLASH", "XFORMERS", "FLASHINFER", "OPENVINO"])
 @pytest.mark.parametrize("device", ["cpu", "openvino", "hip", "cuda"])
 def test_env(name: str, device: str, monkeypatch):
-    """Test that the attention selector can be set via environment variable."""
+    """Test that the attention selector can be set via environment variable.
+    Note that we do not test FlashAttn because it is the default backend.
+    """
     override_backend_env_variable(monkeypatch, name)
 
     if device == "cpu":
         with patch("vllm.attention.selector.current_platform", CpuPlatform()):
-            backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
-                                        False)
-        assert backend == "vllm.attention.backends.torch_sdpa.TorchSDPABackend"
+            backend = get_attn_backend(16, torch.float16, torch.float16, 16,
+                                       False)
+        assert backend.get_name() == "TORCH_SDPA"
     elif device == "hip":
         with patch("vllm.attention.selector.current_platform", RocmPlatform()):
-            backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
-                                        False)
-        assert backend == "vllm.attention.backends.rocm_flash_attn.ROCmFlashAttentionBackend"  # noqa: E501
+            backend = get_attn_backend(16, torch.float16, torch.float16, 16,
+                                       False)
+        assert backend.get_name() == "ROCM_FLASH"
     elif device == "openvino":
         with patch("vllm.attention.selector.current_platform",
                    OpenVinoPlatform()):
-            backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
-                                        False)
-        assert backend == "vllm.attention.backends.openvino.OpenVINOAttentionBackend"  # noqa: E501
+            backend = get_attn_backend(16, torch.float16, torch.float16, 16,
+                                       False)
+        assert backend.get_name() == "OPENVINO"
     else:
         with patch("vllm.attention.selector.current_platform", CudaPlatform()):
-            backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
-                                        False)
-        if name == "FLASHINFER":
-            assert backend == "vllm.attention.backends.flashinfer.FlashInferBackend"  # noqa: E501
-        if name == "XFORMERS":
-            assert backend == "vllm.attention.backends.xformers.XFormersBackend"
-        else:
-            assert backend == "vllm.attention.backends.flash_attn.FlashAttentionBackend"  # noqa: E501
+            backend = get_attn_backend(16, torch.float16, torch.float16, 16,
+                                       False)
+        assert backend.get_name() == name
 
 
 def test_flash_attn(monkeypatch):
     """Test FlashAttn validation."""
     # TODO: When testing for v1, pipe in `use_v1` as an argument to
-    # which_attn_to_use
+    # get_attn_backend
 
     override_backend_env_variable(monkeypatch, STR_FLASH_ATTN_VAL)
 
     # Unsupported CUDA arch
     with patch("torch.cuda.get_device_capability", return_value=(7, 5)):
-        backend = which_attn_to_use(16, torch.float16, None, 16, False)
-        assert backend != "vllm.attention.backends.flash_attn.FlashAttentionBackend"  # noqa: E501
+        backend = get_attn_backend(16, torch.float16, None, 16, False)
+        assert backend.get_name() != STR_FLASH_ATTN_VAL
 
     # Unsupported data type
-    backend = which_attn_to_use(16, torch.float8_e4m3fn, None, 16, False)
-    assert backend != "vllm.attention.backends.flash_attn.FlashAttentionBackend"
+    backend = get_attn_backend(16, torch.float8_e4m3fn, None, 16, False)
+    assert backend.get_name() != STR_FLASH_ATTN_VAL
 
     # Unsupported kv cache data type
-    backend = which_attn_to_use(16, torch.float16, "fp8", 16, False)
-    assert backend != "vllm.attention.backends.flash_attn.FlashAttentionBackend"
+    backend = get_attn_backend(16, torch.float16, "fp8", 16, False)
+    assert backend.get_name() != STR_FLASH_ATTN_VAL
 
     # Unsupported block size
-    backend = which_attn_to_use(16, torch.float16, None, 8, False)
-    assert backend != "vllm.attention.backends.flash_attn.FlashAttentionBackend"
+    backend = get_attn_backend(16, torch.float16, None, 8, False)
+    assert backend.get_name() != STR_FLASH_ATTN_VAL
 
     # flash-attn is not installed
     with patch.dict('sys.modules', {'vllm_flash_attn': None}):
-        backend = which_attn_to_use(16, torch.float16, None, 16, False)
-        assert backend != "vllm.attention.backends.flash_attn.FlashAttentionBackend"  # noqa: E501
+        backend = get_attn_backend(16, torch.float16, None, 16, False)
+        assert backend.get_name() != STR_FLASH_ATTN_VAL
 
     # Unsupported head size
-    backend = which_attn_to_use(17, torch.float16, None, 16, False)
-    assert backend != "vllm.attention.backends.flash_attn.FlashAttentionBackend"
+    backend = get_attn_backend(17, torch.float16, None, 16, False)
+    assert backend.get_name() != STR_FLASH_ATTN_VAL
 
     # Attention-free models should bypass env and use PlaceholderAttention
-    backend = which_attn_to_use(16, torch.float16, torch.float16, 16, True)
-    assert backend != "vllm.attention.backends.flash_attn.FlashAttentionBackend"
+    backend = get_attn_backend(16, torch.float16, torch.float16, 16, True)
+    assert backend.get_name() != STR_FLASH_ATTN_VAL
 
 
 def test_invalid_env(monkeypatch):
     """Throw an exception if the backend name is invalid."""
     override_backend_env_variable(monkeypatch, STR_INVALID_VAL)
     with pytest.raises(ValueError):
-        which_attn_to_use(16, torch.float16, None, 16, False)
+        get_attn_backend(16, torch.float16, None, 16, False)
diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py
index ca6e3f6b63253..0ff007c87b1c9 100644
--- a/vllm/attention/selector.py
+++ b/vllm/attention/selector.py
@@ -114,25 +114,12 @@ def _cached_get_attn_backend(
             BlocksparseFlashAttentionBackend)
         return BlocksparseFlashAttentionBackend
 
-    attention_cls = which_attn_to_use(head_size, dtype, kv_cache_dtype,
-                                      block_size, is_attention_free, use_v1)
-    assert attention_cls != "", (
-        f"Invalid attention backend for {current_platform.device_name}")
-
-    return resolve_obj_by_qualname(attention_cls)
-
-
-def which_attn_to_use(head_size: int,
-                      dtype: torch.dtype,
-                      kv_cache_dtype: Optional[str],
-                      block_size: int,
-                      is_attention_free: bool,
-                      use_v1: bool = False) -> str:
-    """Returns which flash attention backend to use."""
     # If there are no attention layers (e.g. we are running Mamba),
     # use the placeholder NO_ATTENTION
     if is_attention_free:
-        return "vllm.attention.backends.placeholder_attn.PlaceholderAttentionBackend"  # noqa: E501
+        from vllm.attention.backends.placeholder_attn import (
+            PlaceholderAttentionBackend)
+        return PlaceholderAttentionBackend
 
     # Check whether a particular choice of backend was
     # previously forced.
@@ -151,9 +138,12 @@ def which_attn_to_use(head_size: int,
         selected_backend = backend_name_to_enum(backend_by_env_var)
 
     # get device-specific attn_backend
-    return current_platform.get_attn_backend_cls(selected_backend, head_size,
-                                                 dtype, kv_cache_dtype,
-                                                 block_size, use_v1)
+    attention_cls = current_platform.get_attn_backend_cls(
+        selected_backend, head_size, dtype, kv_cache_dtype, block_size, use_v1)
+    if not attention_cls:
+        raise ValueError(
+            f"Invalid attention backend for {current_platform.device_name}")
+    return resolve_obj_by_qualname(attention_cls)
 
 
 @contextmanager
diff --git a/vllm/platforms/neuron.py b/vllm/platforms/neuron.py
index a4bbbd27c8a89..1da2d0f6c5cce 100644
--- a/vllm/platforms/neuron.py
+++ b/vllm/platforms/neuron.py
@@ -1,8 +1,10 @@
 from typing import TYPE_CHECKING, Optional
 
+import torch
+
 from vllm.logger import init_logger
 
-from .interface import Platform, PlatformEnum
+from .interface import Platform, PlatformEnum, _Backend
 
 if TYPE_CHECKING:
     from vllm.config import VllmConfig
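
Usage note (not part of the diff): a minimal sketch of how the reworked selector is called, mirroring the updated tests above. The argument order (head_size, dtype, kv_cache_dtype, block_size, is_attention_free) and the get_name() call are taken from this change; the device and the printed name depend on the local environment and are only illustrative.

    import torch

    from vllm.attention.selector import get_attn_backend

    # Resolve the attention backend class for the current platform. With the
    # reworked selector this returns the backend class directly (instead of a
    # qualname string) and raises ValueError if the platform provides none.
    backend_cls = get_attn_backend(16, torch.float16, None, 16, False)
    print(backend_cls.get_name())  # e.g. "TORCH_SDPA" on a CPU-only install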