Skip to content

Commit

Permalink
Merge branch 'main' into json-demo
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewwhitehead authored Jun 23, 2021
2 parents edda48a + 05805c0 commit e76e137
Show file tree
Hide file tree
Showing 21 changed files with 1,097 additions and 159 deletions.
2 changes: 1 addition & 1 deletion aries_cloudagent/config/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from ..core.error import BaseError

InjectType = TypeVar("Inject")
InjectType = TypeVar("InjectType")


class ConfigError(BaseError):
Expand Down
12 changes: 6 additions & 6 deletions aries_cloudagent/core/tests/test_event_bus.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def event():
yield event


class TestProcessor:
class MockProcessor:
def __init__(self):
self.context = None
self.event = None
Expand All @@ -39,7 +39,7 @@ async def __call__(self, context, event):

@pytest.fixture
def processor():
yield TestProcessor()
yield MockProcessor()


def test_event(event):
Expand Down Expand Up @@ -77,7 +77,7 @@ def test_unsub_unsubbed_processor(event_bus: EventBus, processor):
"""Test unsubscribing an unsubscribed processor does not error."""
event_bus.unsubscribe(re.compile(".*"), processor)
event_bus.subscribe(re.compile(".*"), processor)
another_processor = TestProcessor()
another_processor = MockProcessor()
event_bus.unsubscribe(re.compile(".*"), another_processor)


Expand All @@ -101,7 +101,7 @@ async def test_sub_notify_error_logged_and_exec_continues(
def _raise_exception(context, event):
raise Exception()

processor = TestProcessor()
processor = MockProcessor()
bad_processor = _raise_exception
event_bus.subscribe(re.compile(".*"), bad_processor)
event_bus.subscribe(re.compile(".*"), processor)
Expand Down Expand Up @@ -147,7 +147,7 @@ async def test_sub_notify_no_match(event_bus: EventBus, context, event, processo
@pytest.mark.asyncio
async def test_sub_notify_only_one(event_bus: EventBus, context, event, processor):
"""Test only one subscriber is called when pattern matches only one."""
processor1 = TestProcessor()
processor1 = MockProcessor()
event_bus.subscribe(re.compile(".*"), processor)
event_bus.subscribe(re.compile("^$"), processor1)
await event_bus.notify(context, event)
Expand All @@ -160,7 +160,7 @@ async def test_sub_notify_only_one(event_bus: EventBus, context, event, processo
@pytest.mark.asyncio
async def test_sub_notify_both(event_bus: EventBus, context, event, processor):
"""Test both subscribers are called when pattern matches both."""
processor1 = TestProcessor()
processor1 = MockProcessor()
event_bus.subscribe(re.compile(".*"), processor)
event_bus.subscribe(re.compile("anything"), processor1)
await event_bus.notify(context, event)
Expand Down
184 changes: 152 additions & 32 deletions aries_cloudagent/protocols/present_proof/dif/pres_exch_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,6 +419,8 @@ def create_vcrecord(self, cred_dict: dict) -> VCRecord:
del cred_dict["@graph"]
break
given_id = cred_dict.get("id")
if given_id and self.check_if_cred_id_derived(given_id):
given_id = str(uuid4())
# issuer
issuer = cred_dict.get("issuer")
if type(issuer) is dict:
Expand Down Expand Up @@ -1010,7 +1012,7 @@ async def apply_requirements(
return {}

nested_result = []
given_id_descriptors = {}
cred_uid_descriptors = {}
# recursion logic for nested requirements
for requirement in req.nested_req:
# recursive call
Expand All @@ -1019,7 +1021,7 @@ async def apply_requirements(
)
if result == {}:
continue
# given_id_descriptors maps applicable credentials
# cred_uid_descriptors maps applicable credentials
# to their respective descriptor.
# Structure: {cred.given_id: {
# desc_id_1: {}
Expand All @@ -1030,22 +1032,22 @@ async def apply_requirements(
for descriptor_id in result.keys():
credential_list = result.get(descriptor_id)
for credential in credential_list:
if credential.given_id not in given_id_descriptors:
given_id_descriptors[credential.given_id] = {}
given_id_descriptors[credential.given_id][descriptor_id] = {}
cred_id = credential.given_id or credential.record_id
if cred_id:
cred_uid_descriptors.setdefault(cred_id, {})[descriptor_id] = {}

if len(result.keys()) != 0:
nested_result.append(result)

exclude = {}
for given_id in given_id_descriptors.keys():
for uid in cred_uid_descriptors.keys():
# Check if number of applicable credentials
# does not meet requirement specification
if not self.is_len_applicable(req, len(given_id_descriptors[given_id])):
for descriptor_id in given_id_descriptors[given_id]:
if not self.is_len_applicable(req, len(cred_uid_descriptors[uid])):
for descriptor_id in cred_uid_descriptors[uid]:
# Add to exclude dict
# with cred.given_id + descriptor_id as key
exclude[descriptor_id + given_id] = {}
# with cred_uid + descriptor_id as key
exclude[descriptor_id + uid] = {}
# merging credentials and excluding credentials that don't satisfy the requirement
return await self.merge_nested_results(
nested_result=nested_result, exclude=exclude
Expand Down Expand Up @@ -1082,20 +1084,22 @@ async def merge_nested_results(
for res in nested_result:
for key in res.keys():
credentials = res[key]
given_id_dict = {}
uid_dict = {}
merged_credentials = []

if key in result:
for credential in result[key]:
if credential.given_id not in given_id_dict:
cred_id = credential.given_id or credential.record_id
if cred_id and cred_id not in uid_dict:
merged_credentials.append(credential)
given_id_dict[credential.given_id] = {}
uid_dict[cred_id] = {}

for credential in credentials:
if credential.given_id not in given_id_dict:
if (key + (credential.given_id)) not in exclude:
cred_id = credential.given_id or credential.record_id
if cred_id and cred_id not in uid_dict:
if (key + cred_id) not in exclude:
merged_credentials.append(credential)
given_id_dict[credential.given_id] = {}
uid_dict[cred_id] = {}
result[key] = merged_credentials
return result

Expand Down Expand Up @@ -1187,12 +1191,18 @@ async def create_vp(
def check_sign_pres(self, creds: Sequence[VCRecord]) -> bool:
"""Check if applicable creds have CredentialSubject.id set."""
for cred in creds:
if len(cred.subject_ids) > 0 and not next(
iter(cred.subject_ids)
).startswith("urn:"):
if len(cred.subject_ids) > 0 and not self.check_if_cred_id_derived(
next(iter(cred.subject_ids))
):
return True
return False

def check_if_cred_id_derived(self, id: str) -> bool:
"""Check if credential or credentialSubjet id is derived."""
if id.startswith("urn:bnid:_:c14n"):
return True
return False

async def merge(
self,
dict_descriptor_creds: dict,
Expand All @@ -1218,19 +1228,129 @@ async def merge(
for desc_id in sorted_desc_keys:
credentials = dict_descriptor_creds.get(desc_id)
for cred in credentials:
if cred.given_id not in dict_of_creds:
result.append(cred)
dict_of_creds[cred.given_id] = len(descriptors)

if f"{cred.given_id}-{cred.given_id}" not in dict_of_descriptors:
descriptor_map = InputDescriptorMapping(
id=desc_id,
fmt="ldp_vp",
path=(
f"$.verifiableCredential[{dict_of_creds[cred.given_id]}]"
),
)
descriptors.append(descriptor_map)
cred_id = cred.given_id or cred.record_id
if cred_id:
if cred_id not in dict_of_creds:
result.append(cred)
dict_of_creds[cred_id] = len(descriptors)
if f"{cred_id}-{cred_id}" not in dict_of_descriptors:
descriptor_map = InputDescriptorMapping(
id=desc_id,
fmt="ldp_vp",
path=(f"$.verifiableCredential[{dict_of_creds[cred_id]}]"),
)
descriptors.append(descriptor_map)

descriptors = sorted(descriptors, key=lambda i: i.id)
return (result, descriptors)

async def verify_received_pres(
self,
pd: PresentationDefinition,
pres: dict,
):
"""
Verify credentials received in presentation.
Args:
pres: received VerifiablePresentation
pd: PresentationDefinition
"""
descriptor_map_list = pres.get("presentation_submission").get("descriptor_map")
input_descriptors = pd.input_descriptors
inp_desc_id_contraint_map = {}
for input_descriptor in input_descriptors:
inp_desc_id_contraint_map[input_descriptor.id] = input_descriptor.constraint
for desc_map_item in descriptor_map_list:
desc_map_item_id = desc_map_item.get("id")
constraint = inp_desc_id_contraint_map.get(desc_map_item_id)
desc_map_item_path = desc_map_item.get("path")
jsonpath = parse(desc_map_item_path)
match = jsonpath.find(pres)
if len(match) == 0:
raise DIFPresExchError(
f"{desc_map_item_path} path in descriptor_map not applicable"
)
for match_item in match:
if not await self.apply_constraint_received_cred(
constraint, match_item.value
):
raise DIFPresExchError(
f"Constraint specified for {desc_map_item_id} does not "
f"apply to the enclosed credential in {desc_map_item_path}"
)

async def apply_constraint_received_cred(
self, constraint: Constraints, cred_dict: dict
) -> bool:
"""Evaluate constraint from the request against received credential."""
fields = constraint._fields
field_paths = []
credential = self.create_vcrecord(cred_dict)
for field in fields:
field_paths = field_paths + field.paths
if not await self.filter_by_field(field, credential):
return False
# Selective Disclosure check
if constraint.limit_disclosure == "required":
field_paths = set([path.replace("$.", "") for path in field_paths])
mandatory_paths = {
"@context",
"type",
"issuanceDate",
"issuer",
"proof",
"credentialSubject",
"id",
}
to_remove_from_field_paths = set()
nested_field_paths = {"credentialSubject": {"id", "type"}}
for field_path in field_paths:
if field_path.count(".") >= 1:
split_field_path = field_path.split(".")
key = ".".join(split_field_path[:-1])
value = split_field_path[-1]
nested_field_paths = self.build_nested_paths_dict(
key, value, nested_field_paths
)
to_remove_from_field_paths.add(field_path)
for to_remove_path in to_remove_from_field_paths:
field_paths.remove(to_remove_path)

field_paths = set.union(mandatory_paths, field_paths)

for attrs in cred_dict.keys():
if attrs not in field_paths:
return False
for nested_attr_key in nested_field_paths:
nested_attr_values = nested_field_paths[nested_attr_key]
split_nested_attr_key = nested_attr_key.split(".")
extracted_dict = self.nested_get(cred_dict, split_nested_attr_key)
for attrs in extracted_dict.keys():
if attrs not in nested_attr_values:
return False
return True

def nested_get(self, input_dict: dict, nested_key: Sequence[str]) -> dict:
"""Return internal dict from nested input_dict given list of nested_key."""
internal_dict_value = input_dict
for k in nested_key:
internal_dict_value = internal_dict_value.get(k, None)
return internal_dict_value

def build_nested_paths_dict(
self, key: str, value: str, nested_field_paths: dict
) -> dict:
"""Build and return nested_field_paths dict."""
if key in nested_field_paths.keys():
nested_field_paths[key].add(value)
else:
nested_field_paths[key] = {value}
split_key = key.split(".")
if len(split_key) > 1:
nested_field_paths.update(
self.build_nested_paths_dict(
".".join(split_key[:-1]), split_key[-1], nested_field_paths
)
)
return nested_field_paths
Loading

0 comments on commit e76e137

Please sign in to comment.