From 8e09285f9a6b1961575b85694c216b3966ef9f98 Mon Sep 17 00:00:00 2001 From: Tim DiLauro Date: Wed, 4 Dec 2024 16:23:16 -0500 Subject: [PATCH] Better handle OPDS server root content types. (PP-1959) --- opds.py | 19 +++++- registrar.py | 9 ++- tests/test_controller.py | 64 ++++++++++++++++--- tests/test_opds.py | 68 ++++++++++++++++++++ tests/test_util_http.py | 132 +++++++++++++++++++++++++++++++++++++++ util/http.py | 56 +++++++++++++++++ 6 files changed, 335 insertions(+), 13 deletions(-) diff --git a/opds.py b/opds.py index aa1ed034..22561a50 100644 --- a/opds.py +++ b/opds.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import json import flask @@ -7,6 +9,7 @@ from authentication_document import AuthenticationDocument from config import Configuration from model import ConfigurationSetting, Hyperlink, LibraryType, Validation +from util.http import NormalizedMediaType class Annotator: @@ -38,6 +41,9 @@ class OPDSCatalog: CACHE_TIME = 3600 * 12 + _NORMALIZED_OPDS2_TYPE = NormalizedMediaType(OPDS_TYPE) + _NORMALIZED_OPDS1_TYPE = NormalizedMediaType(OPDS_1_TYPE) + @classmethod def _strftime(cls, date): """ @@ -92,6 +98,18 @@ def __init__( ) annotator.annotate_catalog(self, live=live) + @classmethod + def is_opds1_type(cls, media_type: str | None) -> bool: + return cls._NORMALIZED_OPDS1_TYPE.min_match(media_type) + + @classmethod + def is_opds2_type(cls, media_type: str | None) -> bool: + return cls._NORMALIZED_OPDS2_TYPE.min_match(media_type) + + @classmethod + def is_opds_type(cls, media_type: str | None) -> bool: + return cls.is_opds1_type(media_type) or cls.is_opds2_type(media_type) + @classmethod def _feed_is_large(cls, _db, libraries): """Determine whether a prospective feed is 'large' per a sitewide setting. @@ -125,7 +143,6 @@ def library_catalog( web_client_uri_template=None, include_service_area=False, ): - """Create an OPDS catalog for a library. :param distance: The distance, in meters, from the client's diff --git a/registrar.py b/registrar.py index fa8b5cb3..a21ab5c1 100644 --- a/registrar.py +++ b/registrar.py @@ -180,10 +180,9 @@ def register(self, library: Library, library_stage): opds_url=opds_url, auth_url=auth_url, ) - elif content_type not in (OPDSCatalog.OPDS_TYPE, OPDSCatalog.OPDS_1_TYPE): + elif not OPDSCatalog.is_opds_type(content_type): failure_detail = _( - "Supposed root document at %(url)s is not an OPDS document", - url=opds_url, + f"Supposed root document at {opds_url} does not appear to be an OPDS document (content_type={content_type!r}).", ) elif not self.opds_response_links_to_auth_document(opds_response, auth_url): failure_detail = _( @@ -316,14 +315,14 @@ def opds_response_links(cls, response, rel): if link: links.append(link.get("url")) media_type = response.headers.get("Content-Type") - if media_type == OPDSCatalog.OPDS_TYPE: + if OPDSCatalog.is_opds2_type(media_type): # Parse as OPDS 2. catalog = json.loads(response.content) links = [] for k, v in catalog.get("links", {}).items(): if k == rel: links.append(v.get("href")) - elif media_type == OPDSCatalog.OPDS_1_TYPE: + elif OPDSCatalog.is_opds1_type(media_type): # Parse as OPDS 1. feed = feedparser.parse(response.content) for link in feed.get("feed", {}).get("links", []): diff --git a/tests/test_controller.py b/tests/test_controller.py index 2a24d5b5..d39d7968 100644 --- a/tests/test_controller.py +++ b/tests/test_controller.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import base64 import datetime import json @@ -1507,30 +1509,78 @@ def test_register_fails_on_start_link_error( == "Error retrieving OPDS root document at http://circmanager.org/feed/" ) + @pytest.mark.parametrize( + "media_type, expect_success", + [ + pytest.param( + "application/atom+xml;profile=opds-catalog;kind=acquisition", + True, + id="opds1-acquisition", + ), + pytest.param( + "application/atom+xml;kind=acquisition;profile=opds-catalog", + True, + id="opds1_acquisition-different-order", + ), + pytest.param( + "application/atom+xml;profile=opds-catalog;kind=acquisition;api-version=1", + True, + id="opds1-acquisition-apiv1", + ), + pytest.param( + "application/atom+xml;api-version=1;kind=acquisition;profile=opds-catalog", + True, + id="opds1_acquisition_apiv1-different-order", + ), + pytest.param( + "application/atom+xml;api-version=2;kind=acquisition;profile=opds-catalog", + True, + id="opds1-acquisition-apiv2", + ), + pytest.param( + "application/atom+xml;profile=opds-catalog;kind=navigation", + False, + id="opds1-navigation", + ), + pytest.param("application/opds+json;api-version=1", True, id="opds2-apiv1"), + pytest.param("application/opds+json;api-version=2", True, id="opds2-apiv2"), + pytest.param("application/epub+zip", False, id="epub+zip"), + pytest.param("application/json", False, id="application-json"), + pytest.param("", False, id="empty-string"), + pytest.param(None, False, id="none-value"), + ], + ) def test_register_fails_on_start_link_not_opds_feed( - self, registry_controller_fixture: LibraryRegistryControllerFixture - ): + self, + media_type: str | None, + expect_success: bool, + registry_controller_fixture: LibraryRegistryControllerFixture, + ) -> None: fixture = registry_controller_fixture + # An empty string media type results in a no content-type header. + content_type = None if media_type == "" else media_type """The request returns an authentication document but an attempt to retrieve the corresponding OPDS feed gives a server-side error. """ auth_document = self._auth_document() + # The start link returns a 200 response code but the media type might be wrong. fixture.http_client.queue_response( 200, content=json.dumps(auth_document), url=auth_document["id"] ) - # The start link returns a 200 response code but the wrong - # Content-Type. - fixture.http_client.queue_response(200, "text/html") + fixture.http_client.queue_response(200, media_type) with fixture.app.test_request_context("/", method="POST"): flask.request.form = fixture.form response = fixture.controller.register(do_get=fixture.http_client.do_get) + # We expect to get INVALID_INTEGRATION_DOCUMENT problem detail here, in any case, + # since our test is not fully configured. assert response.uri == INVALID_INTEGRATION_DOCUMENT.uri + # But we should see the `not OPDS` detail only in the case of an invalid media type. assert ( response.detail - == "Supposed root document at http://circmanager.org/feed/ is not an OPDS document" - ) + != f"Supposed root document at http://circmanager.org/feed/ does not appear to be an OPDS document (content_type={content_type!r})." + ) == expect_success def test_register_fails_if_start_link_does_not_link_back_to_auth_document( self, registry_controller_fixture: LibraryRegistryControllerFixture diff --git a/tests/test_opds.py b/tests/test_opds.py index 1ff0f9aa..f595649e 100644 --- a/tests/test_opds.py +++ b/tests/test_opds.py @@ -1,6 +1,10 @@ +from __future__ import annotations + import datetime import json +import pytest + from authentication_document import AuthenticationDocument from config import Configuration from model import ( @@ -351,3 +355,67 @@ def assert_reservation_status(expect): # _hyperlink_args stops working. hyperlink.resource = None assert m(hyperlink) is None + + +class TestOpdsMediaTypeChecks: + @pytest.mark.parametrize( + "media_type, expect_is_opds1, expect_is_opds2", + [ + pytest.param( + "application/atom+xml;profile=opds-catalog;kind=acquisition", + True, + False, + id="opds1-acquisition", + ), + pytest.param( + "application/atom+xml;kind=acquisition;profile=opds-catalog", + True, + False, + id="opds1_acquisition-different-order", + ), + pytest.param( + "application/atom+xml;profile=opds-catalog;kind=acquisition;api-version=1", + True, + False, + id="opds1-acquisition-apiv1", + ), + pytest.param( + "application/atom+xml;api-version=1;kind=acquisition;profile=opds-catalog", + True, + False, + id="opds1_acquisition_apiv1-different-order", + ), + pytest.param( + "application/atom+xml;api-version=2;kind=acquisition;profile=opds-catalog", + True, + False, + id="opds1-acquisition-apiv2", + ), + pytest.param( + "application/atom+xml;profile=opds-catalog;kind=navigation", + False, + False, + id="opds1-navigation", + ), + pytest.param( + "application/opds+json;api-version=1", False, True, id="opds2-apiv1" + ), + pytest.param( + "application/opds+json;api-version=2", False, True, id="opds2-apiv2" + ), + pytest.param("application/epub+zip", False, False, id="epub+zip"), + pytest.param("application/json", False, False, id="application-json"), + pytest.param("", False, False, id="empty-string"), + pytest.param(None, False, False, id="none-value"), + ], + ) + def test_opds_catalog_types( + self, media_type: str | None, expect_is_opds1: bool, expect_is_opds2: bool + ) -> None: + is_opds1 = OPDSCatalog.is_opds1_type(media_type) + is_opds2 = OPDSCatalog.is_opds2_type(media_type) + is_opds_catalog_type = OPDSCatalog.is_opds_type(media_type) + + assert is_opds1 == expect_is_opds1 + assert is_opds2 == expect_is_opds2 + assert is_opds_catalog_type == (expect_is_opds1 or expect_is_opds2) diff --git a/tests/test_util_http.py b/tests/test_util_http.py index 4edf1317..9be729a6 100644 --- a/tests/test_util_http.py +++ b/tests/test_util_http.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import json import pytest @@ -7,6 +9,7 @@ from util.http import ( HTTP, BadResponseException, + NormalizedMediaType, RemoteIntegrationException, RequestNetworkException, RequestTimedOut, @@ -296,3 +299,132 @@ def test_as_problem_detail_document(self): # The status code corresponding to an upstream timeout is 502. document, status_code, headers = standard_detail.response assert status_code == 502 + + +class TestNormalizedMediaType: + @pytest.mark.parametrize( + "media_type, expected_prefix, expected_directives", + [ + pytest.param("text/html", "text/html", {}, id="simple-type"), + pytest.param( + "application/json; charset=utf-8", + "application/json", + {"charset": "utf-8"}, + id="type-with-single-directive", + ), + pytest.param( + "image/png; quality=high; version=1.0", + "image/png", + {"quality": "high", "version": "1.0"}, + id="type-with-multiple-directives", + ), + pytest.param( + "image/png; version=1.0; quality=high", + "image/png", + {"quality": "high", "version": "1.0"}, + id="directives-in-different-order", + ), + pytest.param( + "application/atom+xml;profile=opds-catalog;kind=acquisition", + "application/atom+xml", + {"profile": "opds-catalog", "kind": "acquisition"}, + id="opds-without-api-version", + ), + pytest.param( + "application/atom+xml;profile=opds-catalog;kind=acquisition;api-version=1", + "application/atom+xml", + {"profile": "opds-catalog", "kind": "acquisition", "api-version": "1"}, + id="opds-with-api-version", + ), + pytest.param( + "text/html;", "text/html", {}, id="simple-type-trailing-semicolon" + ), + pytest.param( + "text/html;=", "text/html", {}, id="simple-type-trailing-equal" + ), + pytest.param( + "text/html;=value", "text/html", {}, id="simple-type-trailing-equal" + ), + ], + ) + def test_initialization( + self, + media_type: str, + expected_prefix: str, + expected_directives: dict[str, str], + ) -> None: + normalized = NormalizedMediaType(media_type) + + assert normalized.prefix == expected_prefix + assert normalized.directives == expected_directives + + @pytest.mark.parametrize( + "media_type, other, expected_match", + [ + pytest.param( + "application/atom+xml;profile=opds-catalog;kind=acquisition", + "application/atom+xml;profile=opds-catalog;kind=acquisition;api-version=1", + True, + id="other-with-added-directives", + ), + pytest.param( + "application/atom+xml;profile=opds-catalog;kind=acquisition;api-version=1", + "application/atom+xml;profile=opds-catalog;kind=acquisition", + False, + id="other-with-fewer-directives", + ), + pytest.param( + "application/atom+xml;profile=opds-catalog;kind=acquisition", + "application/atom+xml;api-version=1;kind=acquisition;profile=opds-catalog", + True, + id="added-directives-different-order", + ), + pytest.param( + "image/png; quality=high; version=1.0", + "image/png;quality=high;version=1.0", + True, + id="self-extra-spaces", + ), + pytest.param( + "image/png;quality=high;version=1.0", + "image/png; quality=high; version=1.0", + True, + id="other-extra-spaces", + ), + pytest.param("text/html", "text/html", True, id="exact-match"), + pytest.param("text/htmlxx", "text/html", False, id="self-extra-characters"), + pytest.param( + "text/html", "text/htmlxx", False, id="other-extra-characters" + ), + pytest.param( + "application/json; charset=utf-8", + "application/json; charset=utf-8", + True, + id="exact-match-with-directives", + ), + pytest.param( + "image/png; quality=high", + "image/png; quality=high; version=1.0", + True, + id="match-with-extra-directives", + ), + pytest.param( + "image/png; quality=high", + "image/png; quality=low", + False, + id="mismatch-directive-value", + ), + pytest.param("text/html", "application/json", False, id="different-prefix"), + ], + ) + def test_min_match(self, media_type: str, other: str, expected_match) -> None: + normalized = NormalizedMediaType(media_type) + is_match = normalized.min_match(other) + none_match = normalized.min_match(None) + + assert is_match == expected_match + assert none_match == False + + def test_invalid_cases(self): + with pytest.raises(AttributeError): + NormalizedMediaType._from_string(None) diff --git a/util/http.py b/util/http.py index ffdfa683..44b108ae 100644 --- a/util/http.py +++ b/util/http.py @@ -1,5 +1,8 @@ +from __future__ import annotations + import logging import urllib.parse +from dataclasses import dataclass import requests from flask_babel import lazy_gettext as _ @@ -425,3 +428,56 @@ def process_debuggable_response( response.content, ) ) + + +@dataclass +class NormalizedMediaType: + """Represents a normalized HTTP media type with its type prefix and directives. + + This class provides methods to normalize a media type string into a canonical form + and compare media types. + + TODO: Parsing is currently very lenient, and will accept media types with missing + or malformed directives. + """ + + string: str + + def __post_init__(self): + self.prefix, self.directives = self._from_string(self.string) + + @staticmethod + def _from_string(media_type: str): + elements = media_type.split(";") + prefix = elements.pop(0) + directives = { + k.strip(): v.strip() + for k, v in ( + directive.split("=", 1) for directive in elements if "=" in directive + ) + if k.strip() + } + return prefix, directives + + def min_match(self, other: NormalizedMediaType | str | None) -> bool: + """Determine whether another media type is a 'minimum match' for this one. + + :param other: The other media type to compare with this one. + :return: True if the other media type is a minimum match for this one, False otherwise. + + A minimum match here is a media type that matches the prefix and + directives of this one, but may or may not have additional directives. + """ + if other is None: + return False + if isinstance(other, str): + other = NormalizedMediaType(other) + + if self.prefix != other.prefix: + return False + for k, v in self.directives.items(): + if k not in other.directives: + return False + if other.directives[k] != v: + return False + return True