Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/provide list of sending domains #2018

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
6 changes: 5 additions & 1 deletion app/main/forms.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,7 +553,11 @@ def __init__(self, *args, **kwargs):


class SendingDomainForm(StripWhitespaceForm):
sending_domain = StringField(_l("Sending domain"), validators=[])
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.sending_domain.choices = kwargs["sending_domain_choices"]

sending_domain = SelectField(_l("Sending domain"), validators=[])


class RenameOrganisationForm(StripWhitespaceForm):
Expand Down
3 changes: 2 additions & 1 deletion app/main/views/service_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@
email_safe,
get_logo_cdn_domain,
get_new_default_reply_to_address,
get_verified_ses_domains,
user_has_permissions,
user_is_gov_user,
user_is_platform_admin,
Expand Down Expand Up @@ -549,7 +550,7 @@ def service_set_reply_to_email(service_id):
@main.route("/services/<service_id>/service-settings/sending-domain", methods=["GET", "POST"])
@user_is_platform_admin
def service_sending_domain(service_id):
form = SendingDomainForm()
form = SendingDomainForm(sending_domain_choices=get_verified_ses_domains())

if request.method == "GET":
form.sending_domain.data = current_service.sending_domain
Expand Down
21 changes: 21 additions & 0 deletions app/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -905,3 +905,24 @@ def get_limit_reset_time_et() -> dict[str, str]:

limit_reset_time_et = {"en": next_midnight_utc_in_et.strftime("%-I%p"), "fr": next_midnight_utc_in_et.strftime("%H")}
return limit_reset_time_et


@cache.memoize(timeout=5 * 60)
def get_verified_ses_domains():
"""Query AWS SES for verified domain identities"""
ses = boto3.client("ses", region_name="ca-central-1")

# Get all identities
identities = ses.list_identities(
IdentityType="Domain" # Only get domains, not email addresses
)

# Get verification status for each domain
verification_attrs = ses.get_identity_verification_attributes(Identities=identities["Identities"])["VerificationAttributes"]

# Format results
domains = [
domain for domain in identities["Identities"] if verification_attrs.get(domain, {}).get("VerificationStatus") == "Success"
]

return domains
66 changes: 66 additions & 0 deletions tests/app/main/views/test_service_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -4377,3 +4377,69 @@ def test_should_suspend_service_callback_api(self, client_request, platform_admi
_data={"updated_by_id": platform_admin_user["id"], "suspend_unsuspend": False},
_expected_status=200,
)


class TestSendingDomain:
def test_sending_domain_page_shows_dropdown_of_verified_domains(
self, client_request, platform_admin_user, mock_get_service_settings_page_common, mocker
):
client_request.login(platform_admin_user)
mock_get_domains = mocker.patch(
"app.main.views.service_settings.get_verified_ses_domains", return_value=["domain1.com", "domain2.com"]
)

page = client_request.get("main.service_sending_domain", service_id=SERVICE_ONE_ID, _test_page_title=False)

assert [option["value"] for option in page.select("select[name=sending_domain] option")] == ["domain1.com", "domain2.com"]

mock_get_domains.assert_called_once()

def test_sending_domain_page_populates_with_current_domain(
self, client_request, platform_admin_user, mock_get_service_settings_page_common, mocker, service_one
):
service_one["sending_domain"] = "domain1.com"
mocker.patch("app.main.views.service_settings.get_verified_ses_domains", return_value=["domain1.com", "domain2.com"])

client_request.login(platform_admin_user)
page = client_request.get("main.service_sending_domain", service_id=SERVICE_ONE_ID, _test_page_title=False)

assert page.select_one("select[name=sending_domain] option[selected]")["value"] == "domain1.com"

def test_sending_domain_page_updates_domain_and_redirects_when_posted(
self, client_request, platform_admin_user, mock_get_service_settings_page_common, mocker, service_one, mock_update_service
):
mocker.patch("app.main.views.service_settings.get_verified_ses_domains", return_value=["domain1.com", "domain2.com"])

client_request.login(platform_admin_user)
client_request.post(
"main.service_sending_domain",
service_id=SERVICE_ONE_ID,
_data={"sending_domain": "domain2.com"},
_expected_redirect=url_for(
"main.service_settings",
service_id=SERVICE_ONE_ID,
),
_test_page_title=False,
)

mock_update_service.assert_called_once_with(service_one["id"], sending_domain="domain2.com")

def test_sending_domain_page_doesnt_update_if_domain_not_in_allowed_list(
self, client_request, platform_admin_user, mock_get_service_settings_page_common, mocker, service_one, mock_update_service
):
mocker.patch("app.main.views.service_settings.get_verified_ses_domains", return_value=["domain1.com", "domain2.com"])

client_request.login(platform_admin_user)
page = client_request.post(
"main.service_sending_domain",
service_id=SERVICE_ONE_ID,
_data={"sending_domain": "domain3.com"},
_expected_status=200,
_test_page_title=False,
)

assert mock_update_service.called is False
assert "Not a valid choice" in page.text

def test_sending_domain_page_404s_for_non_platform_admin(self, client_request, mock_get_service_settings_page_common, mocker):
client_request.get("main.service_sending_domain", service_id=SERVICE_ONE_ID, _expected_status=403)
51 changes: 51 additions & 0 deletions tests/app/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from csv import DictReader
from io import StringIO
from pathlib import Path
from unittest.mock import Mock, patch
from urllib.parse import unquote

import pytest
Expand All @@ -23,6 +24,7 @@
get_new_default_reply_to_address,
get_remote_addr,
get_template,
get_verified_ses_domains,
printing_today_or_tomorrow,
report_security_finding,
)
Expand Down Expand Up @@ -759,3 +761,52 @@ def test_get_new_default_reply_to_address_returns_none_if_one_reply_to(mocker: M

new_default = get_new_default_reply_to_address(email_reply_tos, reply_to_1) # type: ignore
assert new_default is None


class TestGetSESDomains:
@pytest.fixture
def mock_ses_client(self):
with patch("boto3.client") as mock_client:
mock_ses = Mock()
mock_client.return_value = mock_ses
yield mock_ses

@pytest.fixture(autouse=True)
def mock_cache_decorator(self):
with patch("app.utils.cache.memoize", lambda *args, **kwargs: lambda f: f):
yield

def test_get_verified_ses_domains(self, app_, mock_ses_client):
# Setup mock return values
mock_ses_client.list_identities.return_value = {"Identities": ["domain1.com", "domain2.com", "domain3.com"]}

mock_ses_client.get_identity_verification_attributes.return_value = {
"VerificationAttributes": {
"domain1.com": {"VerificationStatus": "Success"},
"domain2.com": {"VerificationStatus": "Failed"},
"domain3.com": {"VerificationStatus": "Pending"},
}
}

# Execute within app context
with app_.test_request_context():
result = get_verified_ses_domains()

# Assert
assert result == ["domain1.com"]
mock_ses_client.list_identities.assert_called_once_with(IdentityType="Domain")
mock_ses_client.get_identity_verification_attributes.assert_called_once_with(
Identities=["domain1.com", "domain2.com", "domain3.com"]
)

def test_get_verified_ses_domains_no_domains(self, app_, mock_ses_client):
# Setup empty response
mock_ses_client.list_identities.return_value = {"Identities": []}
mock_ses_client.get_identity_verification_attributes.return_value = {"VerificationAttributes": {}}

# Execute within app context
with app_.test_request_context():
result = get_verified_ses_domains()

# Assert
assert result == []
Loading