2025-02-06 nightly release (9a343a0)

pytorchbot committed Feb 6, 2025
1 parent 82ee098 commit a47adae

Showing 11 changed files with 934 additions and 4 deletions.
8 changes: 8 additions & 0 deletions .github/metrics/requirements.txt
@@ -0,0 +1,8 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

click
requests
178 changes: 178 additions & 0 deletions .github/metrics/scrape.py
@@ -0,0 +1,178 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import dataclasses
import os
from dataclasses import dataclass
from datetime import datetime, timedelta
from typing import Set

import click
import requests


@dataclass
class PullRequestInfo:
    title: str
    number: int
    closed_at: datetime
    labels: Set[str]
    base_ref: str

    def cleaned(self):
        # Drop noise labels; keep the field a set to match the annotation
        new_labels = {x for x in self.labels if x != "cla signed"}
        return dataclasses.replace(self, labels=new_labels)

    def merged(self):
        return "Merged" in self.labels

    def tostr(self):
        return f"{self.closed_at} {self.title} (#{self.number})"


class GitHub:
    # GitHub API base URL
    API_URL = "https://api.github.com/repos"

    def __init__(self) -> None:
        # Authentication requires a GitHub token in the GITHUB_TOKEN env var
        if "GITHUB_TOKEN" not in os.environ:
            raise Exception("GITHUB_TOKEN not set")
        self.token = os.environ["GITHUB_TOKEN"]

        # Headers for authentication
        self.headers = {
            "Authorization": f"token {self.token}",
            "Accept": "application/vnd.github.v3+json",
        }

    def fetch_closed_prs(self, repo, start_date, end_date, target_branch):
        """
        Fetches PRs closed within a specified date range that target a given branch.

        Args:
            repo: Repository in 'owner/repo' format.
            start_date: Start date in 'YYYY-MM-DD' format.
            end_date: End date in 'YYYY-MM-DD' format.
            target_branch: The base branch PRs were merged into.

        Returns:
            List of closed PRs with details, sorted by close time (newest first).
        """
        # Convert dates to datetime objects for comparison
        start_date = datetime.strptime(start_date, "%Y-%m-%d")
        end_date = datetime.strptime(end_date, "%Y-%m-%d")

        # Endpoint to fetch closed PRs
        url = f"{self.API_URL}/{repo}/pulls?state=closed&per_page=100&sort=updated&direction=desc"
        closed_prs = []

        while url:
            print(url)
            response = requests.get(url, headers=self.headers)
            if response.status_code != 200:
                print(f"Error: {response.status_code} - {response.text}")
                break

            prs = response.json()
            for pr in prs:
                if pr["closed_at"] is not None:  # Check if the PR was closed
                    closed_at = datetime.strptime(pr["closed_at"], "%Y-%m-%dT%H:%M:%SZ")
                    # Keep PRs closed within the range whose base branch matches
                    if (
                        start_date <= closed_at <= end_date
                        and pr["base"]["ref"] == target_branch
                    ):
                        # Fetch labels for the PR
                        labels = [label["name"] for label in pr["labels"]]
                        closed_prs.append(
                            PullRequestInfo(
                                pr["title"],
                                int(pr["number"]),
                                closed_at,
                                set(labels),
                                pr["base"]["ref"],
                            )
                        )

            # Check for pagination
            if "next" in response.links:
                url = response.links["next"]["url"]
            else:
                url = None

        return sorted(closed_prs, key=lambda x: x.closed_at, reverse=True)


@click.group()
def cli() -> None:
    pass


@cli.command()
@click.option(
    "--repo",
    default="pytorch/fbgemm",
    help="Repository in owner/repo format (default: pytorch/fbgemm).",
)
@click.option(
    "--last",
    default=7,
    type=int,
    help="Number of days to look back (e.g., 7 for the last 7 days).",
)
@click.option(
    "--branch", default="main", help="Target branch to filter by (default: main)."
)
@click.option(
    "--labels",
    help="Comma-separated list of labels to filter PRs by (e.g., bug,enhancement).",
)
def fetch(repo, last, branch, labels):
    """
    Fetches merged PRs in a given repository within a date range, filtered by target branch and labels.
    """
    # Calculate date range based on --last flag
    end_date = datetime.utcnow()
    start_date = end_date - timedelta(days=last)

    # Convert dates to strings
    start_date_str = start_date.strftime("%Y-%m-%d")
    end_date_str = end_date.strftime("%Y-%m-%d")

    # Parse labels
    labels = set(labels.split(",") if labels else [])

    # Fetch merged PRs
    merged_prs = [
        pr
        for pr in GitHub().fetch_closed_prs(repo, start_date_str, end_date_str, branch)
        if pr.merged()
    ]

    # Filter by labels
    filtered_prs = [
        pr for pr in merged_prs if ((pr.labels & labels) if labels else True)
    ]

    # Output results
    if filtered_prs:
        print(
            f"Found {len(filtered_prs)} merged PRs in '{repo}' between {start_date_str} and {end_date_str} into branch '{branch}':"
        )
        for pr in filtered_prs:
            print(pr.tostr())
    else:
        print(
            f"No merged PRs found in '{repo}' between {start_date_str} and {end_date_str} into branch '{branch}'."
        )


if __name__ == "__main__":
    cli()
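
A quick way to exercise the new scraper is Click's built-in test runner. The sketch below is illustrative only; it assumes `.github/metrics` is on the Python path (so `scrape` is importable) and that `GITHUB_TOKEN` is exported:

# Illustrative smoke test for scrape.py (assumes `scrape` is importable
# and GITHUB_TOKEN is set in the environment).
from click.testing import CliRunner

from scrape import cli

result = CliRunner().invoke(
    cli, ["fetch", "--repo", "pytorch/fbgemm", "--last", "7", "--branch", "main"]
)
print(result.output)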
5 changes: 4 additions & 1 deletion fbgemm_gpu/experimental/gen_ai/CMakeLists.txt
@@ -33,7 +33,10 @@ endif()
# CUDA-specific sources
 file(GLOB_RECURSE experimental_gen_ai_cpp_source_files_cuda
     src/quantize/cutlass_extensions/*.cu
-    src/quantize/cutlass_extensions/**/*.cu)
+    src/quantize/cutlass_extensions/**/*.cu
+    src/quantize/fast_gemv/*.cu
+    src/quantize/fast_gemv/**/*.cu
+    src/quantize/fast_gemv/**/*.cuh)

# HIP-specific sources
file(GLOB_RECURSE experimental_gen_ai_cpp_source_files_hip
32 changes: 32 additions & 0 deletions fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py
@@ -370,6 +370,38 @@ def cuda(self) -> bool:
return True


@register_quantize_op
class BF16OSSFastGemv(QuantizeOpBase):
    """
    BF16 OSS fast gemv kernel.
    """

    def quantize(self, x, w):
        # No quantization needed for BF16; pass inputs through unchanged.
        return x, w

    def compute(self, x, w):
        out = torch.ops.fbgemm.bf16_fast_gemv(x, w)
        return out

    def quantize_and_compute(self, x, w):
        x, w = self.quantize(x, w)
        return self.compute(x, w)

    @property
    def name(self) -> str:
        return "bf16_oss_fast_gemv"

    @property
    def hip(self) -> bool:
        # This kernel is CUDA-only; it is not built for HIP.
        return False

    @property
    def cuda(self) -> bool:
        return True


@register_quantize_op
class FP8CublasRowwiseGemm(QuantizeOpBase):
"""
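For reference, a minimal call to the new op could look like the sketch below. It assumes a CUDA build of fbgemm_gpu's gen_ai extension containing this commit (the import path is an assumption), and uses shapes from one of the kernel's tuned cases:

import torch
import fbgemm_gpu.experimental.gen_ai  # noqa: F401  # import path assumed; registers the op

x = torch.randn(1, 1024, device="cuda", dtype=torch.bfloat16)  # M x K
w = torch.randn(8192, 1024, device="cuda", dtype=torch.bfloat16)  # N x K
y = torch.ops.fbgemm.bf16_fast_gemv(x, w)  # M x N
print(y.shape)  # torch.Size([1, 8192])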
@@ -0,0 +1,138 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/core/ScalarType.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda_bf16.h>

#include "include/fast_gemv.cuh"

namespace fbgemm_gpu {

// The heuristics below were derived by sweeping over the four problem
// sizes we care about and selecting the best elapsed-time/bandwidth
// combination for each. See more in
// deeplearning/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize/fast_gemv/sweep_utils.py
dim3 get_best_block_dim(int m, int n, int k) {
if (m == 1 && n == 1280 && k == 8192) {
return dim3(128, 4);
} else if (m == 1 && n == 8192 && k == 1024) {
return dim3(32, 8);
} else if (m == 1 && n == 7168 && k == 8192) {
return dim3(256, 1);
} else if (m == 1 && n == 8192 && k == 3584) {
return dim3(64, 2);
} else {
// Default block dimensions
return dim3(32, 4);
}
}
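
// Worked example (illustrative): for m = 1, n = 8192, k = 1024 the heuristic
// above returns block_dim = (32, 8). bf16_fast_gemv below then launches with
// grid_dim = (1, n / block_dim.y) = (1, 1024) and
// num_per_thread = k / block_dim.x = 32, which satisfies all three
// divisibility checks (8192 % 8 == 0, 1024 % 32 == 0, 32 % 8 == 0).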

at::Tensor bf16_fast_gemv(at::Tensor X, at::Tensor W) {
// X: M x K
// W: N x K
auto m = X.size(0);
auto n = W.size(0);
auto k = W.size(1);

TORCH_CHECK(X.is_cuda() && X.is_contiguous());
TORCH_CHECK(W.is_cuda() && W.is_contiguous());

dim3 block_dim = get_best_block_dim(m, n, k);

TORCH_CHECK(
n % block_dim.y == 0,
"Invalid block dimensions: n (",
n,
") must be divisible by block_dim.y (",
block_dim.y,
"). Received n: ",
n,
", block_dim.y: ",
block_dim.y,
" Please either use a `n` which is divisible by `block_dim.y`, or update "
"`get_best_block_dim()` heuristics to choose another `block_dim.y`. "
" All current params - m: ",
m,
", n: ",
n,
", k: ",
k,
", block_dim.x: ",
block_dim.x,
", block_dim.y: ",
block_dim.y,
".");
TORCH_CHECK(
k % block_dim.x == 0,
"Invalid block dimensions: k (",
k,
") must be divisible by block_dim.x (",
block_dim.x,
"). Received k: ",
k,
", block_dim.x: ",
block_dim.x,
" Please either use a `k` which is divisible by `block_dim.x`, or update "
"`get_best_block_dim()` heuristics to choose another `block_dim.x`."
" All current params - m: ",
m,
", n: ",
n,
", k: ",
k,
", block_dim.x: ",
block_dim.x,
", block_dim.y: ",
block_dim.y,
".");
TORCH_CHECK(
(k / block_dim.x) % 8 == 0,
"Invalid num_per_thread: (",
k / block_dim.x,
") must be divisible by 8.",
" Received k: ",
k,
", block_dim.x: ",
block_dim.x,
" Please either use a `k` that `k / block_dim.x` that is divisble by 8, or update "
"`get_best_block_dim()` heuristics to choose another `block_dim.x`."
" All current params - m: ",
m,
", n: ",
n,
", k: ",
k,
", block_dim.x: ",
block_dim.x,
", block_dim.y: ",
block_dim.y,
".");

dim3 grid_dim(1, n / block_dim.y);
unsigned int num_per_thread = k / block_dim.x;

auto stream = at::cuda::getCurrentCUDAStream();

auto Y = at::empty({m, n}, X.options().dtype(at::kBFloat16));

gemv_bf16<<<grid_dim, block_dim, 0, stream>>>(
reinterpret_cast<__nv_bfloat16*>(W.data_ptr()), // mat
reinterpret_cast<__nv_bfloat16*>(X.data_ptr()), // vec
reinterpret_cast<__nv_bfloat16*>(Y.data_ptr()), // res
k,
num_per_thread);

C10_CUDA_KERNEL_LAUNCH_CHECK();

return Y;
}

} // namespace fbgemm_gpu