From 0841eac734854a1353457f4d1d4c98e2ea90a49d Mon Sep 17 00:00:00 2001 From: tfrench Date: Thu, 7 Sep 2023 14:51:30 +0000 Subject: [PATCH] [go_sdk download] allow patches to standard library --- go/private/BUILD.bazel | 1 + go/private/extensions.bzl | 9 +++ go/private/sdk.bzl | 34 +++++--- tests/bcr/BUILD.bazel | 5 ++ tests/bcr/MODULE.bazel | 8 +- tests/bcr/sdk_patch_test.go | 13 +++ tests/bcr/test_go_sdk.patch | 13 +++ .../go_download_sdk/go_download_sdk_test.go | 81 +++++++++++++++++++ 8 files changed, 151 insertions(+), 13 deletions(-) create mode 100644 tests/bcr/sdk_patch_test.go create mode 100644 tests/bcr/test_go_sdk.patch diff --git a/go/private/BUILD.bazel b/go/private/BUILD.bazel index 16907b537a..92f3239194 100644 --- a/go/private/BUILD.bazel +++ b/go/private/BUILD.bazel @@ -107,6 +107,7 @@ bzl_library( "//go/private:nogo", "//go/private:platforms", "//go/private/skylib/lib:versions", + "@bazel_tools//tools/build_defs/repo:utils.bzl", ], ) diff --git a/go/private/extensions.bzl b/go/private/extensions.bzl index bcceaa7c3a..857cbb57d5 100644 --- a/go/private/extensions.bzl +++ b/go/private/extensions.bzl @@ -31,6 +31,13 @@ _download_tag = tag_class( ), "urls": attr.string_list(default = ["https://dl.google.com/go/{}"]), "version": attr.string(), + "patches": attr.label_list( + doc = "A list of patches to apply to the SDK after downloading it", + ), + "patch_strip": attr.int( + default = 0, + doc = "The number of leading path segments to be stripped from the file name in the patches.", + ), "strip_prefix": attr.string(default = "go"), }, ) @@ -93,6 +100,8 @@ def _go_sdk_impl(ctx): goarch = download_tag.goarch, sdks = download_tag.sdks, experiments = download_tag.experiments, + patches = download_tag.patches, + patch_strip = download_tag.patch_strip, urls = download_tag.urls, version = download_tag.version, strip_prefix = download_tag.strip_prefix, diff --git a/go/private/sdk.bzl b/go/private/sdk.bzl index 08c6e74c00..08fc1a9389 100644 --- a/go/private/sdk.bzl +++ b/go/private/sdk.bzl @@ -12,18 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -load( - "//go/private:common.bzl", - "executable_path", -) -load( - "//go/private:nogo.bzl", - "go_register_nogo", -) -load( - "//go/private/skylib/lib:versions.bzl", - "versions", -) +load("//go/private:common.bzl", "executable_path") +load("//go/private:nogo.bzl", "go_register_nogo") +load("//go/private/skylib/lib:versions.bzl", "versions") +load("@bazel_tools//tools/build_defs/repo:utils.bzl", "patch") MIN_SUPPORTED_VERSION = (1, 14, 0) @@ -75,6 +67,10 @@ def _go_download_sdk_impl(ctx): version = ctx.attr.version sdks = ctx.attr.sdks + if not version: + if ctx.attr.patches: + fail("a single version must be specified to apply patches") + if not sdks: # If sdks was unspecified, download a full list of files. # If version was unspecified, pick the latest version. @@ -115,7 +111,9 @@ def _go_download_sdk_impl(ctx): if platform not in sdks: fail("unsupported platform {}".format(platform)) filename, sha256 = sdks[platform] + _remote_sdk(ctx, [url.format(filename) for url in ctx.attr.urls], ctx.attr.strip_prefix, sha256) + patch(ctx, patch_args = _get_patch_args(ctx.attr.patch_strip)) detected_version = _detect_sdk_version(ctx, ".") _sdk_build_file(ctx, platform, detected_version, experiments = ctx.attr.experiments) @@ -146,6 +144,13 @@ go_download_sdk_rule = repository_rule( "urls": attr.string_list(default = ["https://dl.google.com/go/{}"]), "version": attr.string(), "strip_prefix": attr.string(default = "go"), + "patches": attr.label_list( + doc = "A list of patches to apply to the SDK after downloading it", + ), + "patch_strip": attr.int( + default = 0, + doc = "The number of leading path segments to be stripped from the file name in the patches.", + ), "_sdk_build_file": attr.label( default = Label("//go/private:BUILD.sdk.bazel"), ), @@ -175,6 +180,11 @@ def _to_constant_name(s): # Prefix with _ as identifiers are not allowed to start with numbers. return "_" + "".join([c if c.isalnum() else "_" for c in s.elems()]).upper() +def _get_patch_args(patch_strip): + if patch_strip: + return ["-p{}".format(patch_strip)] + return [] + def go_toolchains_single_definition(ctx, *, prefix, goos, goarch, sdk_repo, sdk_type, sdk_version): if not goos and not goarch: goos, goarch = detect_host_platform(ctx) diff --git a/tests/bcr/BUILD.bazel b/tests/bcr/BUILD.bazel index d5c5f6f1fd..851dd60c0e 100644 --- a/tests/bcr/BUILD.bazel +++ b/tests/bcr/BUILD.bazel @@ -19,6 +19,11 @@ go_test( embed = [":lib"], ) +go_test( + name = "sdk_patch_test", + srcs = ["sdk_patch_test.go"], +) + go_library( name = "mockable", srcs = [ diff --git a/tests/bcr/MODULE.bazel b/tests/bcr/MODULE.bazel index 8cd19d052a..024731c6de 100644 --- a/tests/bcr/MODULE.bazel +++ b/tests/bcr/MODULE.bazel @@ -22,7 +22,13 @@ bazel_dep(name = "gazelle", version = "0.32.0") bazel_dep(name = "protobuf", version = "3.19.6") go_sdk = use_extension("@my_rules_go//go:extensions.bzl", "go_sdk") -go_sdk.download(version = "1.19.5") +go_sdk.download( + version = "1.21.1", + patch_strip = 1, + patches = [ + "//:test_go_sdk.patch", + ], +) # Request an invalid SDK to verify that it isn't fetched since the first tag takes precedence. go_sdk.host(version = "3.0.0") diff --git a/tests/bcr/sdk_patch_test.go b/tests/bcr/sdk_patch_test.go new file mode 100644 index 0000000000..f3c1a64f78 --- /dev/null +++ b/tests/bcr/sdk_patch_test.go @@ -0,0 +1,13 @@ +package lib + +import ( + "os" + + "testing" +) + +func TestName(t *testing.T) { + if os.SayHello != "Hello" { + t.Fail() + } +} diff --git a/tests/bcr/test_go_sdk.patch b/tests/bcr/test_go_sdk.patch new file mode 100644 index 0000000000..83beb83b5f --- /dev/null +++ b/tests/bcr/test_go_sdk.patch @@ -0,0 +1,13 @@ +diff --git a/src/os/dir.go b/src/os/dir.go +index 5306bcb..d110a19 100644 +--- a/src/os/dir.go ++++ b/src/os/dir.go +@@ -17,6 +17,8 @@ const ( + readdirFileInfo + ) + ++const SayHello = "Hello" ++ + // Readdir reads the contents of the directory associated with file and + // returns a slice of up to n FileInfo values, as would be returned + // by Lstat, in directory order. Subsequent calls on the same file will yield diff --git a/tests/core/go_download_sdk/go_download_sdk_test.go b/tests/core/go_download_sdk/go_download_sdk_test.go index 61bd05b89f..12192f5d92 100644 --- a/tests/core/go_download_sdk/go_download_sdk_test.go +++ b/tests/core/go_download_sdk/go_download_sdk_test.go @@ -33,6 +33,11 @@ go_test( srcs = ["version_test.go"], ) +go_test( + name = "patch_test", + srcs = ["patch_test.go"], +) + -- version_test.go -- package version_test @@ -49,6 +54,19 @@ func Test(t *testing.T) { t.Errorf("got version %q; want %q", v, *want) } } +-- patch_test.go -- +package version_test + +import ( + "os" + "testing" +) + +func Test(t *testing.T) { + if v := os.SayHello; v != "Hello"{ + t.Errorf("got version %q; want \"Hello\"", v) + } +} `, }) } @@ -185,3 +203,66 @@ go_register_toolchains() }) } } + +func TestPatch(t *testing.T) { + origWorkspaceData, err := ioutil.ReadFile("WORKSPACE") + if err != nil { + t.Fatal(err) + } + + i := bytes.Index(origWorkspaceData, []byte("go_rules_dependencies()")) + if i < 0 { + t.Fatal("could not find call to go_rules_dependencies()") + } + + buf := &bytes.Buffer{} + buf.Write(origWorkspaceData[:i]) + buf.WriteString(` +load("@io_bazel_rules_go//go:deps.bzl", "go_download_sdk") + +go_download_sdk( + name = "go_sdk_patched", + version = "1.21.1", + patch_strip = 1, + patches = ["//:test.patch"], +) + +go_rules_dependencies() + +go_register_toolchains() +`) + if err := ioutil.WriteFile("WORKSPACE", buf.Bytes(), 0666); err != nil { + t.Fatal(err) + } + + patchContent := []byte(`diff --git a/src/os/dir.go b/src/os/dir.go +index 5306bcb..d110a19 100644 +--- a/src/os/dir.go ++++ b/src/os/dir.go +@@ -17,6 +17,8 @@ const ( + readdirFileInfo + ) + ++const SayHello = "Hello" ++ + // Readdir reads the contents of the directory associated with file and + // returns a slice of up to n FileInfo values, as would be returned + // by Lstat, in directory order. Subsequent calls on the same file will yield +`) + + if err := ioutil.WriteFile("test.patch", patchContent, 0666); err != nil { + t.Fatal(err) + } + defer func() { + if err := ioutil.WriteFile("WORKSPACE", origWorkspaceData, 0666); err != nil { + t.Errorf("error restoring WORKSPACE: %v", err) + } + }() + + if err := bazel_testing.RunBazel( + "test", + "//:patch_test", + ); err != nil { + t.Fatal(err) + } +}