Merge pull request #12 from valohai/invalidate
Add support for CloudFront invalidations
akx authored Aug 7, 2024
2 parents c3521b8 + 1704eff commit 8eff4a7
Showing 9 changed files with 175 additions and 46 deletions.
6 changes: 3 additions & 3 deletions .pre-commit-config.yaml
@@ -1,13 +1,13 @@
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.3.3
rev: v0.5.6
hooks:
- id: ruff
args:
- --fix
- id: ruff-format
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.5.0
rev: v4.6.0
hooks:
- id: check-merge-conflict
- id: check-yaml
@@ -18,6 +18,6 @@ repos:
args:
- --fix=lf
- repo: https://github.com/crate-ci/typos
rev: v1.19.0
rev: v1.23.6
hooks:
- id: typos
37 changes: 37 additions & 0 deletions art/cloudfront.py
@@ -0,0 +1,37 @@
from __future__ import annotations

import logging
import time
from typing import Any

log = logging.getLogger(__name__)


# Separated for testing purposes
def get_cloudfront_client() -> Any:
    import boto3

    return boto3.client("cloudfront")


def execute_cloudfront_invalidations(invalidations: dict[str, set[str]]) -> None:
    cf_client = get_cloudfront_client()
    ts = int(time.time())
    for dist_id, paths in invalidations.items():
        log.info("Creating CloudFront invalidation for %s: %d paths", dist_id, len(paths))
        caller_reference = f"art-{dist_id}-{ts}"
        inv = cf_client.create_invalidation(
            DistributionId=dist_id,
            InvalidationBatch={
                "Paths": {
                    "Quantity": len(paths),
                    "Items": sorted(paths),
                },
                "CallerReference": caller_reference,
            },
        )
        log.info(
            "Created CloudFront invalidation with caller reference %s: %s",
            caller_reference,
            inv["Invalidation"]["Id"],
        )
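
A rough usage sketch of the new helper (not part of the diff; the distribution ID and paths are placeholders, and running this for real needs AWS credentials allowed to call cloudfront:CreateInvalidation). Paths arrive pre-grouped per distribution, and each distribution gets exactly one invalidation batch:

from art.cloudfront import execute_cloudfront_invalidations

# One create_invalidation call per distribution; all queued paths for that
# distribution go out in a single batch. Values below are made up.
execute_cloudfront_invalidations(
    {
        "E2EXAMPLE123": {"/releases/latest/.manifest.json", "/releases/latest/app.tar.gz"},
    }
)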
42 changes: 25 additions & 17 deletions art/command.py
@@ -5,10 +5,11 @@
import os
import shutil
import tempfile
from typing import Any, Dict, List, Optional
from typing import List, Optional

from art.config import ArtConfig, FileMapEntry
from art.consts import DEFAULT_CONFIG_FILENAME
from art.context import ArtContext
from art.excs import Problem
from art.git import git_clone
from art.manifest import Manifest
@@ -95,35 +96,38 @@ def run_command(argv: Optional[List[str]] = None) -> None:
args = Args(**vars(ap.parse_args(argv)))
logging.basicConfig(level=(args.log_level or logging.INFO))

config_args: Dict[str, Any] = {"dests": list(args.dests), "name": ""}
is_git = False
if args.git_source:
config_args.update(
work_dir = tempfile.mkdtemp(prefix="art-git-")
atexit.register(shutil.rmtree, work_dir)
config = ArtConfig(
dests=list(args.dests),
name="",
repo_url=args.git_source,
ref=args.git_ref,
work_dir=tempfile.mkdtemp(prefix="art-git-"),
work_dir=work_dir,
)
is_git = True
git_clone(config)
elif args.local_source:
work_dir = os.path.abspath(args.local_source)
config_args.update(
config = ArtConfig(
dests=list(args.dests),
name="",
repo_url=work_dir,
work_dir=work_dir,
)
else:
ap.error("Either a git source or a local source must be defined")

config = ArtConfig(**config_args)

if is_git:
git_clone(config)
atexit.register(shutil.rmtree, config.work_dir)
return
context = ArtContext(
dry_run=bool(args.dry_run),
)

for forked_config in fork_configs_from_work_dir(config, filename=args.config_file):
try:
process_config_postfork(args, forked_config)
process_config_postfork(context, args, forked_config)
except Problem as p:
ap.error(f"config {forked_config.name}: {p}")
context.execute_post_run_tasks()


def clean_dest(dest: str) -> str:
@@ -132,7 +136,11 @@ def clean_dest(dest: str) -> str:
return dest


def process_config_postfork(args: Args, config: ArtConfig) -> None:
def process_config_postfork(
context: ArtContext,
args: Args,
config: ArtConfig,
) -> None:
if not config.dests:
raise Problem("No destination(s) specified (on command line or in config in source)")
config.dests = [clean_dest(dest) for dest in config.dests]
@@ -152,12 +160,12 @@ def process_config_postfork(args: Args, config: ArtConfig) -> None:
for dest in config.dests:
for suffix in suffixes:
write(
config,
context=context,
config=config,
dest=dest,
path_suffix=suffix,
manifest=manifest,
wrap_filename=wrap_temp,
dry_run=args.dry_run,
)
if wrap_temp:
os.unlink(wrap_temp)
19 changes: 19 additions & 0 deletions art/context.py
@@ -0,0 +1,19 @@
from __future__ import annotations

import dataclasses

from art.cloudfront import execute_cloudfront_invalidations


@dataclasses.dataclass(frozen=True)
class ArtContext:
    dry_run: bool = False
    _cloudfront_invalidations: dict[str, set[str]] = dataclasses.field(default_factory=dict)

    def add_cloudfront_invalidation(self, dist_id: str, path: str) -> None:
        self._cloudfront_invalidations.setdefault(dist_id, set()).add(path)

    def execute_post_run_tasks(self) -> None:
        if self._cloudfront_invalidations:
            execute_cloudfront_invalidations(self._cloudfront_invalidations)
            self._cloudfront_invalidations.clear()
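
A minimal sketch of the accumulate-then-flush pattern ArtContext introduces (identifiers and paths are illustrative; execute_post_run_tasks would contact AWS unless get_cloudfront_client is stubbed out, as the tests below do):

from art.context import ArtContext

context = ArtContext(dry_run=False)
# Writers record paths as they upload; duplicates collapse into the per-distribution set.
context.add_cloudfront_invalidation("E2EXAMPLE123", "/releases/latest/.manifest.json")
context.add_cloudfront_invalidation("E2EXAMPLE123", "/releases/latest/app.tar.gz")
# After all destinations are written, one invalidation per distribution is created.
context.execute_post_run_tasks()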
26 changes: 18 additions & 8 deletions art/s3.py
@@ -1,21 +1,27 @@
import logging
from functools import cache
from typing import IO, Any, Dict
from urllib.parse import urlparse

_s3_client = None
from art.context import ArtContext

log = logging.getLogger(__name__)


@cache
def get_s3_client() -> Any:
global _s3_client
if not _s3_client:
import boto3
import boto3

_s3_client = boto3.client("s3")
return _s3_client
return boto3.client("s3")


def s3_write(url: str, source_fp: IO[bytes], *, options: Dict[str, Any], dry_run: bool) -> None:
def s3_write(
url: str,
source_fp: IO[bytes],
*,
options: Dict[str, Any],
context: ArtContext,
) -> None:
purl = urlparse(url)
s3_client = get_s3_client()
assert purl.scheme == "s3"
@@ -27,8 +33,12 @@ def s3_write(url: str, source_fp: IO[bytes], *, options: Dict[str, Any], dry_run
if acl:
kwargs["ACL"] = acl

if dry_run:
if context.dry_run:
log.info("Dry-run: would write to S3 (ACL %s): %s", acl, url)
return
s3_client.put_object(**kwargs)
log.info("Wrote to S3 (ACL %s): %s", acl, url)

cf_distribution_id = options.get("cf-distribution-id")
if cf_distribution_id:
context.add_cloudfront_invalidation(cf_distribution_id, purl.path)
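
A hedged example of the updated s3_write contract (bucket, key and distribution ID are placeholders): dry_run now comes from the context rather than a keyword argument, and when the options carry a cf-distribution-id, a real (non-dry-run) write also queues the object's path on the context for the post-run invalidation.

import io

from art.context import ArtContext
from art.s3 import s3_write

# With dry_run=True this only logs; with dry_run=False it would upload the object
# and record "/releases/latest/app.tar.gz" against distribution "E2EXAMPLE123".
s3_write(
    "s3://example-bucket/releases/latest/app.tar.gz",
    io.BytesIO(b"payload"),
    options={"acl": "public-read", "cf-distribution-id": "E2EXAMPLE123"},
    context=ArtContext(dry_run=True),
)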
30 changes: 21 additions & 9 deletions art/write.py
@@ -7,6 +7,7 @@
from urllib.parse import parse_qsl

from art.config import ArtConfig
from art.context import ArtContext
from art.manifest import Manifest
from art.s3 import s3_write

@@ -17,13 +18,13 @@ def _write_file(
dest: str,
source_fp: IO[bytes],
*,
context: ArtContext,
options: Optional[Dict[str, Any]] = None,
dry_run: bool = False,
) -> None:
if options is None:
options = {}
writer = _get_writer_for_dest(dest)
writer(dest, source_fp, options=options, dry_run=dry_run)
writer(dest, source_fp, options=options, context=context)


def _get_writer_for_dest(dest: str) -> Callable: # type: ignore[type-arg]
@@ -34,8 +35,14 @@ def _get_writer_for_dest(dest: str) -> Callable: # type: ignore[type-arg]
raise ValueError(f"Invalid destination: {dest}")


def local_write(dest: str, source_fp: IO[bytes], *, options: Dict[str, Any], dry_run: bool) -> None:
if dry_run:
def local_write(
dest: str,
source_fp: IO[bytes],
*,
context: ArtContext,
options: Dict[str, Any],
) -> None:
if context.dry_run:
log.info("Dry-run: Would have written local file %s", dest)
return
os.makedirs(os.path.dirname(dest), exist_ok=True)
@@ -45,12 +52,12 @@ def local_write(dest: str, source_fp: IO[bytes], *, options: Dict[str, Any], dry


def write(
config: ArtConfig,
*,
context: ArtContext,
config: ArtConfig,
dest: str,
path_suffix: str,
manifest: Manifest,
dry_run: bool,
wrap_filename: Optional[str] = None,
) -> None:
options = {}
@@ -63,20 +70,25 @@
dest_path = posixpath.join(dest, dest_filename)
local_path = os.path.join(config.work_dir, fileinfo["path"])
with open(local_path, "rb") as infp:
_write_file(dest_path, infp, options=options, dry_run=dry_run)
_write_file(
dest_path,
infp,
context=context,
options=options,
)

_write_file(
dest=posixpath.join(dest, ".manifest.json"),
source_fp=io.BytesIO(manifest.as_json_bytes()),
context=context,
options=options,
dry_run=dry_run,
)

if config.wrap and wrap_filename:
with open(wrap_filename, "rb") as infp:
_write_file(
dest=posixpath.join(dest, config.wrap),
source_fp=infp,
context=context,
options=options,
dry_run=dry_run,
)
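
Judging by the parse_qsl import above and the test_dest_options test further down, destination query parameters appear to be how per-destination options such as the new cf-distribution-id reach the writers; a small self-contained sketch of that parsing (the URL and IDs are illustrative, not taken from the diff):

from urllib.parse import parse_qsl, urlparse

dest = "s3://example-bucket/releases/latest?acl=public-read&cf-distribution-id=E2EXAMPLE123"
options = dict(parse_qsl(urlparse(dest).query))
assert options == {"acl": "public-read", "cf-distribution-id": "E2EXAMPLE123"}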
45 changes: 42 additions & 3 deletions art_tests/test_s3.py
@@ -1,13 +1,52 @@
import io
from unittest.mock import Mock

import pytest
from boto3 import _get_default_session

from art import cloudfront
from art.context import ArtContext
from art.s3 import get_s3_client
from art.write import _write_file


def test_s3_acl(mocker):
@pytest.fixture(autouse=True)
def aws_fake_credentials(monkeypatch):
# Makes sure we don't accidentally use real AWS credentials.
monkeypatch.setattr(_get_default_session()._session, "_credentials", Mock())


def test_s3_acl(monkeypatch):
cli = get_s3_client()
cli.put_object = cli.put_object # avoid magic
mocker.patch.object(cli, "put_object")
put_object = Mock()
monkeypatch.setattr(cli, "put_object", put_object)
body = io.BytesIO(b"test")
_write_file("s3://bukkit/key", body, options={"acl": "public-read"})
_write_file("s3://bukkit/key", body, options={"acl": "public-read"}, context=ArtContext())
cli.put_object.assert_called_with(Bucket="bukkit", Key="key", ACL="public-read", Body=body)


def test_s3_invalidate_cloudfront(monkeypatch):
cli = get_s3_client()
cli.put_object = cli.put_object # avoid magic
put_object = Mock()
monkeypatch.setattr(cli, "put_object", put_object)
body = io.BytesIO(b"test")
options = {"acl": "public-read", "cf-distribution-id": "UWUWU"}
context = ArtContext()
_write_file("s3://bukkit/key/foo/bar", body, options=options, context=context)
_write_file("s3://bukkit/key/baz/quux", body, options=options, context=context)
_write_file("s3://bukkit/key/baz/barple", body, options=options, context=context)
cf_client = Mock()
cf_client.create_invalidation.return_value = {"Invalidation": {"Id": "AAAAA"}}
monkeypatch.setattr(cloudfront, "get_cloudfront_client", Mock(return_value=cf_client))
context.execute_post_run_tasks()
# Assert the 3 files get a single invalidation
cf_client.create_invalidation.assert_called_once()
call_kwargs = cf_client.create_invalidation.call_args.kwargs
assert call_kwargs["DistributionId"] == "UWUWU"
assert set(call_kwargs["InvalidationBatch"]["Paths"]["Items"]) == {
"/key/baz/barple",
"/key/baz/quux",
"/key/foo/bar",
}
15 changes: 10 additions & 5 deletions art_tests/test_write.py
@@ -1,18 +1,23 @@
import unittest.mock

import art.write
from art.config import ArtConfig
from art.context import ArtContext
from art.manifest import Manifest


def test_dest_options(mocker, tmpdir):
def test_dest_options(monkeypatch, tmpdir):
cfg = ArtConfig(work_dir=str(tmpdir), dests=[str(tmpdir)], name="", repo_url=str(tmpdir))
mf = Manifest(files={})
wf = mocker.patch("art.write._write_file")
wf = unittest.mock.MagicMock()
monkeypatch.setattr(art.write, "_write_file", wf)
context = ArtContext(dry_run=False)
art.write.write(
cfg,
config=cfg,
context=context,
dest="derp://foo/bar/?acl=quux",
path_suffix="blag",
manifest=mf,
dry_run=False,
path_suffix="blag",
)
call_kwargs = wf.call_args[1]
assert call_kwargs["options"] == {"acl": "quux"}