Skip to content

Commit

Permalink
Upgrade pytorch to 2.5.0 (#347)
Browse files Browse the repository at this point in the history
  • Loading branch information
guocuimi authored Oct 17, 2024
1 parent 2550fba commit 7838ca5
Show file tree
Hide file tree
Showing 6 changed files with 20 additions and 26 deletions.
7 changes: 1 addition & 6 deletions .github/workflows/build_wheel.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,7 @@ jobs:
matrix:
python: ["3.8", "3.9", "3.10", "3.11", "3.12"]
cuda: ["11.8", "12.1", "12.4"]
torch: ["2.2.2", "2.3.1", "2.4.1"]
exclude:
- cuda: "12.4"
torch: "2.3.1"
- cuda: "12.4"
torch: "2.2.2"
torch: ["2.4.1", "2.5.0"]
runs-on: [self-hosted, linux, release]
env:
PYTHON_VERSION: ${{ matrix.python }}
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/package_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ jobs:
matrix:
python: ["3.12"]
cuda: ["12.4"]
torch: ["2.4.1"]
torch: ["2.5.0"]
runs-on: [self-hosted, linux, build]
env:
PYTHON_VERSION: ${{ matrix.python }}
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/publish_wheel.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ jobs:
fail-fast: false
matrix:
python: ["3.8", "3.9", "3.10", "3.11", "3.12"]
cuda: ["12.1"]
torch: ["2.4.1"]
cuda: ["12.4"]
torch: ["2.5.0"]
runs-on: [self-hosted, linux, release]
env:
PYTHON_VERSION: ${{ matrix.python }}
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/release_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ jobs:
fail-fast: false
matrix:
python: ["3.8", "3.9", "3.10", "3.11", "3.12"]
cuda: ["12.1"]
torch: ["2.4.1"]
cuda: ["12.4"]
torch: ["2.5.0"]
runs-on: [self-hosted, linux, release]
env:
PYTHON_VERSION: ${{ matrix.python }}
Expand Down
20 changes: 10 additions & 10 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -194,25 +194,25 @@ if (DEFINED ENV{LIBTORCH_ROOT})
else()
include(FetchContent)
if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL 12.4)
# download libtorch 2.4.1 with cuda 12.4 from pytorch.org
# download libtorch 2.5.0 with cuda 12.4 from pytorch.org
if (USE_CXX11_ABI)
set(LIBTORCH_URL "https://download.pytorch.org/libtorch/cu124/libtorch-cxx11-abi-shared-with-deps-2.4.1%2Bcu124.zip")
set(LIBTORCH_URL "https://download.pytorch.org/libtorch/cu124/libtorch-cxx11-abi-shared-with-deps-2.5.0%2Bcu124.zip")
else()
set(LIBTORCH_URL "https://download.pytorch.org/libtorch/cu124/libtorch-shared-with-deps-2.4.1%2Bcu124.zip")
set(LIBTORCH_URL "https://download.pytorch.org/libtorch/cu124/libtorch-shared-with-deps-2.5.0%2Bcu124.zip")
endif()
elseif(CUDAToolkit_VERSION VERSION_GREATER_EQUAL 12.1)
# download libtorch 2.4.1 with cuda 12.1 from pytorch.org
# download libtorch 2.5.0 with cuda 12.1 from pytorch.org
if (USE_CXX11_ABI)
set(LIBTORCH_URL "https://download.pytorch.org/libtorch/cu121/libtorch-cxx11-abi-shared-with-deps-2.4.1%2Bcu121.zip")
set(LIBTORCH_URL "https://download.pytorch.org/libtorch/cu121/libtorch-cxx11-abi-shared-with-deps-2.5.0%2Bcu121.zip")
else()
set(LIBTORCH_URL "https://download.pytorch.org/libtorch/cu121/libtorch-shared-with-deps-2.4.1%2Bcu121.zip")
set(LIBTORCH_URL "https://download.pytorch.org/libtorch/cu121/libtorch-shared-with-deps-2.5.0%2Bcu121.zip")
endif()
elseif(CUDAToolkit_VERSION VERSION_GREATER_EQUAL 11.8)
# download libtorch 2.4.1 with cuda 11.8 from pytorch.org
# download libtorch 2.5.0 with cuda 11.8 from pytorch.org
if (USE_CXX11_ABI)
set(LIBTORCH_URL "https://download.pytorch.org/libtorch/cu118/libtorch-cxx11-abi-shared-with-deps-2.4.1%2Bcu118.zip")
set(LIBTORCH_URL "https://download.pytorch.org/libtorch/cu118/libtorch-cxx11-abi-shared-with-deps-2.5.0%2Bcu118.zip")
else()
set(LIBTORCH_URL "https://download.pytorch.org/libtorch/cu118/libtorch-shared-with-deps-2.4.1%2Bcu118.zip")
set(LIBTORCH_URL "https://download.pytorch.org/libtorch/cu118/libtorch-shared-with-deps-2.5.0%2Bcu118.zip")
endif()
else()
# error out if cuda version is not supported
Expand All @@ -232,7 +232,7 @@ else()
FetchContent_MakeAvailable(libtorch)

find_package(Torch REQUIRED PATHS ${libtorch_SOURCE_DIR} NO_DEFAULT_PATH)
message(STATUS "Downloading and using libtorch 2.4.1 for cuda ${CUDA_VERSION} at ${libtorch_SOURCE_DIR}")
message(STATUS "Downloading and using libtorch 2.5.0 for cuda ${CUDA_VERSION} at ${libtorch_SOURCE_DIR}")
endif()

# check if USE_CXX11_ABI is set correctly
Expand Down
9 changes: 4 additions & 5 deletions src/memory/memory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,8 @@ int64_t max_memory_allocated(const torch::Device& device) {
const auto device_index =
device.has_index() ? device.index() : current_device();
const auto stats = CUDACachingAllocator::getDeviceStats(device_index);
return stats
.allocated_bytes[static_cast<size_t>(
CUDACachingAllocator::StatType::AGGREGATE)]
.peak;
// StatType::AGGREGATE
return stats.allocated_bytes[0].peak;
}

// returns the total memory in bytes of the device.
Expand All @@ -44,7 +42,8 @@ int64_t available_memory(const torch::Device& device) {
<< "Failed to set device to " << device_index;
size_t free = 0;
size_t total = 0;
CHECK(cudaMemGetInfo(&free, &total) == cudaSuccess) << "Failed to get memory info for " << device;
CHECK(cudaMemGetInfo(&free, &total) == cudaSuccess)
<< "Failed to get memory info for " << device;
return static_cast<int64_t>(free);
}

Expand Down

0 comments on commit 7838ca5

Please sign in to comment.