-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #12 from valohai/invalidate
Add support for CloudFront invalidations
- Loading branch information
Showing
9 changed files
with
175 additions
and
46 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
from __future__ import annotations | ||
|
||
import logging | ||
import time | ||
from typing import Any | ||
|
||
log = logging.getLogger(__name__) | ||
|
||
|
||
# Separated for testing purposes | ||
def get_cloudfront_client() -> Any: | ||
import boto3 | ||
|
||
return boto3.client("cloudfront") | ||
|
||
|
||
def execute_cloudfront_invalidations(invalidations: dict[str, set[str]]) -> None: | ||
cf_client = get_cloudfront_client() | ||
ts = int(time.time()) | ||
for dist_id, paths in invalidations.items(): | ||
log.info("Creating CloudFront invalidation for %s: %d paths", dist_id, len(paths)) | ||
caller_reference = f"art-{dist_id}-{ts}" | ||
inv = cf_client.create_invalidation( | ||
DistributionId=dist_id, | ||
InvalidationBatch={ | ||
"Paths": { | ||
"Quantity": len(paths), | ||
"Items": sorted(paths), | ||
}, | ||
"CallerReference": caller_reference, | ||
}, | ||
) | ||
log.info( | ||
"Created CloudFront invalidation with caller reference %s: %s", | ||
caller_reference, | ||
inv["Invalidation"]["Id"], | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
from __future__ import annotations | ||
|
||
import dataclasses | ||
|
||
from art.cloudfront import execute_cloudfront_invalidations | ||
|
||
|
||
@dataclasses.dataclass(frozen=True) | ||
class ArtContext: | ||
dry_run: bool = False | ||
_cloudfront_invalidations: dict[str, set[str]] = dataclasses.field(default_factory=dict) | ||
|
||
def add_cloudfront_invalidation(self, dist_id: str, path: str) -> None: | ||
self._cloudfront_invalidations.setdefault(dist_id, set()).add(path) | ||
|
||
def execute_post_run_tasks(self) -> None: | ||
if self._cloudfront_invalidations: | ||
execute_cloudfront_invalidations(self._cloudfront_invalidations) | ||
self._cloudfront_invalidations.clear() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,13 +1,52 @@ | ||
import io | ||
from unittest.mock import Mock | ||
|
||
import pytest | ||
from boto3 import _get_default_session | ||
|
||
from art import cloudfront | ||
from art.context import ArtContext | ||
from art.s3 import get_s3_client | ||
from art.write import _write_file | ||
|
||
|
||
def test_s3_acl(mocker): | ||
@pytest.fixture(autouse=True) | ||
def aws_fake_credentials(monkeypatch): | ||
# Makes sure we don't accidentally use real AWS credentials. | ||
monkeypatch.setattr(_get_default_session()._session, "_credentials", Mock()) | ||
|
||
|
||
def test_s3_acl(monkeypatch): | ||
cli = get_s3_client() | ||
cli.put_object = cli.put_object # avoid magic | ||
mocker.patch.object(cli, "put_object") | ||
put_object = Mock() | ||
monkeypatch.setattr(cli, "put_object", put_object) | ||
body = io.BytesIO(b"test") | ||
_write_file("s3://bukkit/key", body, options={"acl": "public-read"}) | ||
_write_file("s3://bukkit/key", body, options={"acl": "public-read"}, context=ArtContext()) | ||
cli.put_object.assert_called_with(Bucket="bukkit", Key="key", ACL="public-read", Body=body) | ||
|
||
|
||
def test_s3_invalidate_cloudfront(monkeypatch): | ||
cli = get_s3_client() | ||
cli.put_object = cli.put_object # avoid magic | ||
put_object = Mock() | ||
monkeypatch.setattr(cli, "put_object", put_object) | ||
body = io.BytesIO(b"test") | ||
options = {"acl": "public-read", "cf-distribution-id": "UWUWU"} | ||
context = ArtContext() | ||
_write_file("s3://bukkit/key/foo/bar", body, options=options, context=context) | ||
_write_file("s3://bukkit/key/baz/quux", body, options=options, context=context) | ||
_write_file("s3://bukkit/key/baz/barple", body, options=options, context=context) | ||
cf_client = Mock() | ||
cf_client.create_invalidation.return_value = {"Invalidation": {"Id": "AAAAA"}} | ||
monkeypatch.setattr(cloudfront, "get_cloudfront_client", Mock(return_value=cf_client)) | ||
context.execute_post_run_tasks() | ||
# Assert the 3 files get a single invalidation | ||
cf_client.create_invalidation.assert_called_once() | ||
call_kwargs = cf_client.create_invalidation.call_args.kwargs | ||
assert call_kwargs["DistributionId"] == "UWUWU" | ||
assert set(call_kwargs["InvalidationBatch"]["Paths"]["Items"]) == { | ||
"/key/baz/barple", | ||
"/key/baz/quux", | ||
"/key/foo/bar", | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.