diff --git a/src/ir/auth_logic/BUILD b/src/ir/auth_logic/BUILD index cb0117dd1..2a0590f4e 100644 --- a/src/ir/auth_logic/BUILD +++ b/src/ir/auth_logic/BUILD @@ -36,6 +36,21 @@ cc_library( ], ) +cc_library( + name = "auth_logic_ast_traversing_visitor", + hdrs = [ + "auth_logic_ast_visitor.h", + "auth_logic_ast_traversing_visitor.h" + ], + deps = [ + ":ast", + "//src/common/utils:fold", + "//src/common/utils:types", + "//src/common/utils:overloaded", + "//src/common/logging" + ], +) + cc_library( name = "lowering_ast_datalog", srcs = ["lowering_ast_datalog.cc"], @@ -94,6 +109,18 @@ cc_test( ], ) +cc_test( + name = "ast_visitor_test", + srcs = ["auth_logic_ast_traversing_visitor_test.cc"], + deps = [ + ":ast", + ":auth_logic_ast_traversing_visitor", + "//src/common/testing:gtest", + "//src/ir/datalog:program", + "@absl//absl/container:flat_hash_set", + ] +) + cc_library( name = "ast_construction", srcs = ["ast_construction.cc"], diff --git a/src/ir/auth_logic/ast.h b/src/ir/auth_logic/ast.h index 78163f179..bcd798c38 100644 --- a/src/ir/auth_logic/ast.h +++ b/src/ir/auth_logic/ast.h @@ -25,6 +25,7 @@ #include #include "absl/hash/hash.h" +#include "src/ir/auth_logic/auth_logic_ast_visitor.h" #include "src/ir/datalog/program.h" namespace raksha::ir::auth_logic { @@ -34,6 +35,16 @@ class Principal { explicit Principal(std::string name) : name_(std::move(name)) {} const std::string& name() const { return name_; } + template + Result Accept(AuthLogicAstVisitor& visitor) { + return visitor.Visit(*this); + } + + template + Result Accept(AuthLogicAstVisitor& visitor) const { + return visitor.Visit(*this); + } + private: std::string name_; }; @@ -47,6 +58,16 @@ class Attribute { const Principal& principal() const { return principal_; } const datalog::Predicate& predicate() const { return predicate_; } + template + Result Accept(AuthLogicAstVisitor& visitor) { + return visitor.Visit(*this); + } + + template + Result Accept(AuthLogicAstVisitor& visitor) const { + return visitor.Visit(*this); + } + private: Principal principal_; datalog::Predicate predicate_; @@ -62,6 +83,16 @@ class CanActAs { const Principal& left_principal() const { return left_principal_; } const Principal& right_principal() const { return right_principal_; } + template + Result Accept(AuthLogicAstVisitor& visitor) { + return visitor.Visit(*this); + } + + template + Result Accept(AuthLogicAstVisitor& visitor) const { + return visitor.Visit(*this); + } + private: Principal left_principal_; Principal right_principal_; @@ -85,6 +116,16 @@ class BaseFact { explicit BaseFact(BaseFactVariantType value) : value_(std::move(value)){}; const BaseFactVariantType& GetValue() const { return value_; } + template + Result Accept(AuthLogicAstVisitor& visitor) { + return visitor.Visit(*this); + } + + template + Result Accept(AuthLogicAstVisitor& visitor) const { + return visitor.Visit(*this); + } + private: BaseFactVariantType value_; }; @@ -103,6 +144,16 @@ class Fact { const BaseFact& base_fact() const { return base_fact_; } + template + Result Accept(AuthLogicAstVisitor& visitor) { + return visitor.Visit(*this); + } + + template + Result Accept(AuthLogicAstVisitor& visitor) const { + return visitor.Visit(*this); + } + private: std::forward_list delegation_chain_; BaseFact base_fact_; @@ -118,6 +169,16 @@ class ConditionalAssertion { const Fact& lhs() const { return lhs_; } const std::vector& rhs() const { return rhs_; } + template + Result Accept(AuthLogicAstVisitor& visitor) { + return visitor.Visit(*this); + } + + template + Result Accept(AuthLogicAstVisitor& visitor) const { + return visitor.Visit(*this); + } + private: Fact lhs_; std::vector rhs_; @@ -135,6 +196,16 @@ class Assertion { explicit Assertion(AssertionVariantType value) : value_(std::move(value)) {} const AssertionVariantType& GetValue() const { return value_; } + template + Result Accept(AuthLogicAstVisitor& visitor) { + return visitor.Visit(*this); + } + + template + Result Accept(AuthLogicAstVisitor& visitor) const { + return visitor.Visit(*this); + } + private: AssertionVariantType value_; }; @@ -147,6 +218,16 @@ class SaysAssertion { const Principal& principal() const { return principal_; } const std::vector& assertions() const { return assertions_; } + template + Result Accept(AuthLogicAstVisitor& visitor) { + return visitor.Visit(*this); + } + + template + Result Accept(AuthLogicAstVisitor& visitor) const { + return visitor.Visit(*this); + } + private: Principal principal_; std::vector assertions_; @@ -164,6 +245,16 @@ class Query { const Principal& principal() const { return principal_; } const Fact& fact() const { return fact_; } + template + Result Accept(AuthLogicAstVisitor& visitor) { + return visitor.Visit(*this); + } + + template + Result Accept(AuthLogicAstVisitor& visitor) const { + return visitor.Visit(*this); + } + private: std::string name_; Principal principal_; @@ -191,6 +282,16 @@ class Program { const std::vector& queries() const { return queries_; } + template + Result Accept(AuthLogicAstVisitor& visitor) { + return visitor.Visit(*this); + } + + template + Result Accept(AuthLogicAstVisitor& visitor) const { + return visitor.Visit(*this); + } + private: std::vector relation_declarations_; std::vector says_assertions_; diff --git a/src/ir/auth_logic/auth_logic_ast_traversing_visitor.h b/src/ir/auth_logic/auth_logic_ast_traversing_visitor.h new file mode 100644 index 000000000..80151047b --- /dev/null +++ b/src/ir/auth_logic/auth_logic_ast_traversing_visitor.h @@ -0,0 +1,336 @@ +//----------------------------------------------------------------------------- +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +//---------------------------------------------------------------------------- +#ifndef SRC_IR_AUTH_LOGIC_AST_TRAVERSING_VISITOR_H_ +#define SRC_IR_AUTH_LOGIC_AST_TRAVERSING_VISITOR_H_ + +#include "src/common/utils/fold.h" +#include "src/common/utils/types.h" +#include "src/common/utils/overloaded.h" +#include "src/ir/auth_logic/ast.h" +#include "src/ir/auth_logic/auth_logic_ast_visitor.h" +#include "src/ir/datalog/program.h" +#include + +// The implementation of this visitor over the AST nodes of authorizaiton logic +// directly follows the one for the IR in /src/ir/ir_traversing_visitor.h + +namespace raksha::ir::auth_logic { + +// A visitor that also traverses the children of a node and allows performing +// different actions before (PreVisit) and after (PostVisit) the children are +// visited. Override any of the `PreVisit` and `PostVisit` methods as needed. +template +class AuthLogicAstTraversingVisitor : public AuthLogicAstVisitor { + private: + template + struct DefaultValueGetter { + static ValueType Get() { LOG(FATAL) << "Override required for non-default-constructible type."; } + }; + + template + struct DefaultValueGetter>> { + static ValueType Get() { return ValueType(); } + }; + + public: + virtual ~AuthLogicAstTraversingVisitor() {} + + // Gives a default value for all 'PreVisit's to start with. + // Should be over-ridden if the Result is not default constructable. + virtual Result GetDefaultValue() { return DefaultValueGetter::Get(); } + + // Used to accumulate child results from the node's children. + // Should discard or merge `child_result` into the `accumulator`. + virtual Result FoldResult(Result accumulator, + Result child_result) { + return accumulator; + } + // Invoked before all the children of `principal` are visited. + virtual Result PreVisit(CopyConst& principal) { + return GetDefaultValue(); + } + // Invoked after all the children of `principal` are visited. + virtual Result PostVisit(CopyConst& principal, Result in_order_result) { + return in_order_result; + } + // Invoked before all the children of `attribute` are visited. + virtual Result PreVisit(CopyConst& attribute) { + return GetDefaultValue(); + } + // Invoked after all the children of `attribute` are visited. + virtual Result PostVisit(CopyConst& attribute, Result in_order_result) { + return in_order_result; + } + // Invoked before all the children of `canActAs` are visited. + virtual Result PreVisit(CopyConst& canActAs) { + return GetDefaultValue(); + } + // Invoked after all the children of `canActAs` are visited. + virtual Result PostVisit(CopyConst& canActAs, Result in_order_result) { + return in_order_result; + } + // Invoked before all the children of `baseFact` are visited. + virtual Result PreVisit(CopyConst& baseFact) { + return GetDefaultValue(); + } + // Invoked after all the children of `baseFact` are visited. + virtual Result PostVisit(CopyConst& baseFact, Result in_order_result) { + return in_order_result; + } + // Invoked before all the children of `fact` are visited. + virtual Result PreVisit(CopyConst& fact) { + return GetDefaultValue(); + } + // Invoked after all the children of `fact` are visited. + virtual Result PostVisit(CopyConst& fact, Result in_order_result) { + return in_order_result; + } + // Invoked before all the children of `conditionalAssertion` are visited. + virtual Result PreVisit(CopyConst& conditionalAssertion) { + return GetDefaultValue(); + } + // Invoked after all the children of `conditionalAssertion` are visited. + virtual Result PostVisit(CopyConst& conditionalAssertion, Result in_order_result) { + return in_order_result; + } + // Invoked before all the children of `assertion` are visited. + virtual Result PreVisit(CopyConst& assertion) { + return GetDefaultValue(); + } + // Invoked after all the children of `assertion` are visited. + virtual Result PostVisit(CopyConst& assertion, Result in_order_result) { + return in_order_result; + } + // Invoked before all the children of `saysAssertion` are visited. + virtual Result PreVisit(CopyConst& saysAssertion) { + return GetDefaultValue(); + } + // Invoked after all the children of `saysAssertion` are visited. + virtual Result PostVisit(CopyConst& saysAssertion, Result in_order_result) { + return in_order_result; + } + // Invoked before all the children of `query` are visited. + virtual Result PreVisit(CopyConst& query) { + return GetDefaultValue(); + } + // Invoked after all the children of `query` are visited. + virtual Result PostVisit(CopyConst& query, Result in_order_result) { + return in_order_result; + } + // Invoked before all the children of `program` are visited. + virtual Result PreVisit(CopyConst& program) { + return GetDefaultValue(); + } + // Invoked after all the children of `program` are visited. + virtual Result PostVisit(CopyConst& program, Result in_order_result) { + return in_order_result; + } + + // The Visits for the Datalog IR classes (RelationDeclaration, Predciate) + // are here temporarily until these AST classes are refactored out + // of the Datalog IR. + + virtual Result Visit(CopyConst& relationDeclaration) { + return GetDefaultValue(); + } + + virtual Result Visit(CopyConst& predicate) { + return GetDefaultValue(); + } + + // The remaining Visits are meant to follow the convention + + Result Visit(CopyConst& principal) final override { + Result pre_visit_result = PreVisit(principal); + return PostVisit(principal, std::move(pre_visit_result)); + } + + Result Visit(CopyConst& attribute) final override { + Result pre_visit_result = PreVisit(attribute); + Result fold_result = FoldResult( + FoldResult(std::move(pre_visit_result), + attribute.principal().Accept(*this)), + // TODO fix this to use predicate().Accept once predicate + // has been refactored into ast.h + Visit(attribute.predicate())); + return PostVisit(attribute, std::move(fold_result)); + } + + Result Visit(CopyConst& canActAs) final override { + Result pre_visit_result = PreVisit(canActAs); + Result fold_result = FoldResult( + FoldResult(std::move(pre_visit_result), + canActAs.left_principal().Accept(*this)), + canActAs.right_principal().Accept(*this)); + return PostVisit(canActAs, std::move(fold_result)); + } + + Result Visit(CopyConst& baseFact) final override { + Result pre_visit_result = PreVisit(baseFact); + Result variant_visit_result = std::visit( + raksha::utils::overloaded{ + [this](const datalog::Predicate& pred) { + return VariantVisit(pred); + }, + [this](const Attribute& attrib) { + return VariantVisit(attrib); + }, + [this](const CanActAs& canActAs) { + return VariantVisit(canActAs); + } + }, + baseFact.GetValue() + ); + Result fold_result = FoldResult(std::move(pre_visit_result), + std::move(variant_visit_result)); + return PostVisit(baseFact, std::move(fold_result)); + } + + Result Visit(CopyConst& fact) final override { + Result pre_visit_result = PreVisit(fact); + Result base_fact_result = FoldResult( + std::move(pre_visit_result), + fact.base_fact().Accept(*this)); + Result fold_result = common::utils::fold( + fact.delegation_chain(), std::move(base_fact_result), + [this](Result acc, + CopyConst principal) { + return FoldResult(std::move(acc), principal.Accept(*this)); + } + ); + return PostVisit(fact, std::move(fold_result)); + } + + Result Visit(CopyConst& conditionalAssertion) final override { + Result pre_visit_result = PreVisit(conditionalAssertion); + Result lhs_result = FoldResult( + std::move(pre_visit_result), + conditionalAssertion.lhs().Accept(*this) + ); + Result fold_result = common::utils::fold( + conditionalAssertion.rhs(), std::move(lhs_result), + [this](Result acc, + CopyConst baseFact) { + return FoldResult(std::move(acc), baseFact.Accept(*this)); + } + ); + return PostVisit(conditionalAssertion, std::move(fold_result)); + } + + Result Visit(CopyConst& assertion) final override { + Result pre_visit_result = PreVisit(assertion); + Result variant_visit_result = std::visit( + raksha::utils::overloaded{ + [this](const Fact& fact) { + return VariantVisit(fact); + }, + [this](const ConditionalAssertion& condAssertion) { + return VariantVisit(condAssertion); + } + }, + assertion.GetValue() + ); + Result fold_result = FoldResult(std::move(pre_visit_result), + std::move(variant_visit_result)); + return PostVisit(assertion, std::move(fold_result)); + } + + Result Visit(CopyConst& saysAssertion) final override { + Result pre_visit_result = PreVisit(saysAssertion); + Result principal_result = FoldResult( + std::move(pre_visit_result), + saysAssertion.principal().Accept(*this) + ); + Result fold_result = common::utils::fold( + saysAssertion.assertions(), std::move(principal_result), + [this](Result acc, + CopyConst assertion) { + return FoldResult(std::move(acc), assertion.Accept(*this)); + } + ); + return PostVisit(saysAssertion, fold_result); + } + + Result Visit(CopyConst& query) final override { + Result pre_visit_result = PreVisit(query); + Result fold_result = FoldResult( + std::move(pre_visit_result), + FoldResult( + query.principal().Accept(*this), + query.fact().Accept(*this) + )); + return PostVisit(query, fold_result); + } + + Result Visit(CopyConst& program) final override { + Result pre_visit_result = PreVisit(program); + Result declarations_result = common::utils::fold( + program.relation_declarations(), std::move(pre_visit_result), + [this](Result acc, + CopyConst relationDeclaration) { + // TODO Fix this to accept once once relationDeclaration + // has been refactored into ast.h + return FoldResult(std::move(acc), Visit(relationDeclaration)); + } + ); + Result says_assertions_result = common::utils::fold( + program.says_assertions(), std::move(declarations_result), + [this](Result acc, + CopyConst saysAssertion) { + return FoldResult(std::move(acc), saysAssertion.Accept(*this)); + } + ); + Result queries_result = common::utils::fold( + program.queries(), std::move(says_assertions_result), + [this](Result acc, + CopyConst query) { + return FoldResult(std::move(acc), query.Accept(*this)); + } + ); + return PostVisit(program, queries_result); + } + + // The VariantVisit methods use overloading to help visit + // the alternatives for the underlying std::variants in the AST + + // For BaseFactVariantType + Result VariantVisit(datalog::Predicate predicate) { + // TODO once a separate predicate has been added to ast.h + // this should use predicate.Accept(*this); + return Visit(predicate); + } + Result VariantVisit(Attribute attribute) { + return attribute.Accept(*this); + } + Result VariantVisit(CanActAs canActAs) { + return canActAs.Accept(*this); + } + + // For AssertionVariantType + Result VariantVisit(Fact fact) { + return fact.Accept(*this); + } + Result VariantVisit(ConditionalAssertion conditionalAssertion) { + return conditionalAssertion.Accept(*this); + } + + private: + +}; + +} // namespace raksha::ir::auth_logic + +#endif // SRC_IR_AUTH_LOGIC_AST_TRAVERSING_VISITOR_H_ diff --git a/src/ir/auth_logic/auth_logic_ast_traversing_visitor_test.cc b/src/ir/auth_logic/auth_logic_ast_traversing_visitor_test.cc new file mode 100644 index 000000000..031d360b7 --- /dev/null +++ b/src/ir/auth_logic/auth_logic_ast_traversing_visitor_test.cc @@ -0,0 +1,78 @@ +//----------------------------------------------------------------------------- +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +//---------------------------------------------------------------------------- + +#include "src/ir/auth_logic/auth_logic_ast_traversing_visitor.h" + +#include "src/common/testing/gtest.h" +#include "src/ir/auth_logic/ast.h" +#include "src/ir/datalog/program.h" +#include "absl/container/flat_hash_set.h" + +namespace raksha::ir::auth_logic { +namespace { + +// A visitor that makes a set of all the names of principals in the program +class PrincipalNameCollectorVisitor + : public AuthLogicAstTraversingVisitor> { + + public: + PrincipalNameCollectorVisitor() {} + + absl::flat_hash_set GetDefaultValue() override { return {}; } + + absl::flat_hash_set FoldResult( + absl::flat_hash_set acc, + absl::flat_hash_set child_result) { + acc.merge(std::move(child_result)); + return std::move(acc); + } + + absl::flat_hash_set PreVisit(const Principal& principal) override { + return { principal.name() }; + } + +}; + +Program BuildTestProgram1() { + SaysAssertion assertion1 = SaysAssertion(Principal("PrincipalA"), + { Assertion( + Fact({}, BaseFact( + datalog::Predicate("foo", {"bar", "baz"}, datalog::kPositive)))) }); + SaysAssertion assertion2 = SaysAssertion(Principal("PrincipalA"), + { Assertion( + Fact({}, BaseFact( + datalog::Predicate("foo", {"barbar", "bazbaz"}, datalog::kPositive)))) }); + SaysAssertion assertion3 = SaysAssertion(Principal("PrincipalB"), + { Assertion( + Fact({}, BaseFact( + CanActAs(Principal("PrincipalA"), Principal("PrincipalC")))))}); + std::vector assertion_list = { + std::move(assertion1), std::move(assertion2), std::move(assertion3)}; + return Program({}, std::move(assertion_list), {}); +} + +TEST(AuthLogicAstTraversingVisitorTest, + PrincipalNameCollectorTest) { + Program test_prog = BuildTestProgram1(); + PrincipalNameCollectorVisitor collector_visitor; + const absl::flat_hash_set result = test_prog.Accept(collector_visitor); + const absl::flat_hash_set expected = {"PrincipalA", "PrincipalB", "PrincipalC"}; + EXPECT_EQ(result, expected); +} + +} // namespace +} // namespace raksha::ir::auth_logic \ No newline at end of file diff --git a/src/ir/auth_logic/auth_logic_ast_visitor.h b/src/ir/auth_logic/auth_logic_ast_visitor.h new file mode 100644 index 000000000..b15eebbc8 --- /dev/null +++ b/src/ir/auth_logic/auth_logic_ast_visitor.h @@ -0,0 +1,53 @@ +//----------------------------------------------------------------------------- +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +//---------------------------------------------------------------------------- +#ifndef SRC_IR_AUTH_LOGIC_AST_VISITOR_H_ +#define SRC_IR_AUTH_LOGIC_AST_VISITOR_H_ + +#include "src/common/utils/types.h" + +namespace raksha::ir::auth_logic { + +class Principal; +class Attribute; +class CanActAs; +class BaseFact; +class Fact; +class ConditionalAssertion; +class Assertion; +class SaysAssertion; +class Query; +class Program; + +template +class AuthLogicAstVisitor { + public: + virtual ~AuthLogicAstVisitor() {} + virtual Result Visit(CopyConst& principal) = 0; + virtual Result Visit(CopyConst& attribute) = 0; + virtual Result Visit(CopyConst& canActAs) = 0; + virtual Result Visit(CopyConst& baseFact) = 0; + virtual Result Visit(CopyConst& fact) = 0; + virtual Result Visit( + CopyConst& conditionalAssertion) = 0; + virtual Result Visit(CopyConst& assertion) = 0; + virtual Result Visit(CopyConst& saysAssertion) = 0; + virtual Result Visit(CopyConst& query) = 0; + virtual Result Visit(CopyConst& program) = 0; +}; + +} // namespace raksha::ir::auth_logic + +#endif // SRC_IR_AUTH_LOGIC_AST_VISITOR_H_ \ No newline at end of file diff --git a/src/ir/ir_visitor.h b/src/ir/ir_visitor.h index a412a6743..4fdd88260 100644 --- a/src/ir/ir_visitor.h +++ b/src/ir/ir_visitor.h @@ -31,7 +31,7 @@ class IRVisitor { public: virtual ~IRVisitor() {} virtual Result Visit(CopyConst& module) = 0; - virtual Result Visit(CopyConst& operation) = 0; + virtual Result Visit(CopyConst& block) = 0; virtual Result Visit(CopyConst& operation) = 0; };