Skip to content

Commit

Permalink
Visitor for authorization logic AST
Browse files Browse the repository at this point in the history
  • Loading branch information
Andrew Ferraiuolo committed Aug 9, 2022
1 parent be6ef8e commit b748dc6
Show file tree
Hide file tree
Showing 6 changed files with 583 additions and 1 deletion.
27 changes: 27 additions & 0 deletions src/ir/auth_logic/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,21 @@ cc_library(
],
)

cc_library(
name = "auth_logic_ast_traversing_visitor",
hdrs = [
"auth_logic_ast_traversing_visitor.h",
"auth_logic_ast_visitor.h",
],
deps = [
":ast",
"//src/common/logging",
"//src/common/utils:fold",
"//src/common/utils:overloaded",
"//src/common/utils:types",
],
)

cc_library(
name = "lowering_ast_datalog",
srcs = ["lowering_ast_datalog.cc"],
Expand Down Expand Up @@ -94,6 +109,18 @@ cc_test(
],
)

cc_test(
name = "ast_visitor_test",
srcs = ["auth_logic_ast_traversing_visitor_test.cc"],
deps = [
":ast",
":auth_logic_ast_traversing_visitor",
"//src/common/testing:gtest",
"//src/ir/datalog:program",
"@absl//absl/container:flat_hash_set",
],
)

cc_library(
name = "ast_construction",
srcs = ["ast_construction.cc"],
Expand Down
101 changes: 101 additions & 0 deletions src/ir/auth_logic/ast.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include <vector>

#include "absl/hash/hash.h"
#include "src/ir/auth_logic/auth_logic_ast_visitor.h"
#include "src/ir/datalog/program.h"

namespace raksha::ir::auth_logic {
Expand All @@ -34,6 +35,16 @@ class Principal {
explicit Principal(std::string name) : name_(std::move(name)) {}
const std::string& name() const { return name_; }

template <typename Derived, typename Result>
Result Accept(AuthLogicAstVisitor<Derived, Result, false>& visitor) {
return visitor.Visit(*this);
}

template <typename Derived, typename Result>
Result Accept(AuthLogicAstVisitor<Derived, Result, true>& visitor) const {
return visitor.Visit(*this);
}

private:
std::string name_;
};
Expand All @@ -47,6 +58,16 @@ class Attribute {
const Principal& principal() const { return principal_; }
const datalog::Predicate& predicate() const { return predicate_; }

template <typename Derived, typename Result>
Result Accept(AuthLogicAstVisitor<Derived, Result, false>& visitor) {
return visitor.Visit(*this);
}

template <typename Derived, typename Result>
Result Accept(AuthLogicAstVisitor<Derived, Result, true>& visitor) const {
return visitor.Visit(*this);
}

private:
Principal principal_;
datalog::Predicate predicate_;
Expand All @@ -62,6 +83,16 @@ class CanActAs {
const Principal& left_principal() const { return left_principal_; }
const Principal& right_principal() const { return right_principal_; }

template <typename Derived, typename Result>
Result Accept(AuthLogicAstVisitor<Derived, Result, false>& visitor) {
return visitor.Visit(*this);
}

template <typename Derived, typename Result>
Result Accept(AuthLogicAstVisitor<Derived, Result, true>& visitor) const {
return visitor.Visit(*this);
}

private:
Principal left_principal_;
Principal right_principal_;
Expand All @@ -85,6 +116,16 @@ class BaseFact {
explicit BaseFact(BaseFactVariantType value) : value_(std::move(value)){};
const BaseFactVariantType& GetValue() const { return value_; }

template <typename Derived, typename Result>
Result Accept(AuthLogicAstVisitor<Derived, Result, false>& visitor) {
return visitor.Visit(*this);
}

template <typename Derived, typename Result>
Result Accept(AuthLogicAstVisitor<Derived, Result, true>& visitor) const {
return visitor.Visit(*this);
}

private:
BaseFactVariantType value_;
};
Expand All @@ -103,6 +144,16 @@ class Fact {

const BaseFact& base_fact() const { return base_fact_; }

template <typename Derived, typename Result>
Result Accept(AuthLogicAstVisitor<Derived, Result, false>& visitor) {
return visitor.Visit(*this);
}

template <typename Derived, typename Result>
Result Accept(AuthLogicAstVisitor<Derived, Result, true>& visitor) const {
return visitor.Visit(*this);
}

private:
std::forward_list<Principal> delegation_chain_;
BaseFact base_fact_;
Expand All @@ -118,6 +169,16 @@ class ConditionalAssertion {
const Fact& lhs() const { return lhs_; }
const std::vector<BaseFact>& rhs() const { return rhs_; }

template <typename Derived, typename Result>
Result Accept(AuthLogicAstVisitor<Derived, Result, false>& visitor) {
return visitor.Visit(*this);
}

template <typename Derived, typename Result>
Result Accept(AuthLogicAstVisitor<Derived, Result, true>& visitor) const {
return visitor.Visit(*this);
}

private:
Fact lhs_;
std::vector<BaseFact> rhs_;
Expand All @@ -135,6 +196,16 @@ class Assertion {
explicit Assertion(AssertionVariantType value) : value_(std::move(value)) {}
const AssertionVariantType& GetValue() const { return value_; }

template <typename Derived, typename Result>
Result Accept(AuthLogicAstVisitor<Derived, Result, false>& visitor) {
return visitor.Visit(*this);
}

template <typename Derived, typename Result>
Result Accept(AuthLogicAstVisitor<Derived, Result, true>& visitor) const {
return visitor.Visit(*this);
}

private:
AssertionVariantType value_;
};
Expand All @@ -147,6 +218,16 @@ class SaysAssertion {
const Principal& principal() const { return principal_; }
const std::vector<Assertion>& assertions() const { return assertions_; }

template <typename Derived, typename Result>
Result Accept(AuthLogicAstVisitor<Derived, Result, false>& visitor) {
return visitor.Visit(*this);
}

template <typename Derived, typename Result>
Result Accept(AuthLogicAstVisitor<Derived, Result, true>& visitor) const {
return visitor.Visit(*this);
}

private:
Principal principal_;
std::vector<Assertion> assertions_;
Expand All @@ -164,6 +245,16 @@ class Query {
const Principal& principal() const { return principal_; }
const Fact& fact() const { return fact_; }

template <typename Derived, typename Result>
Result Accept(AuthLogicAstVisitor<Derived, Result, false>& visitor) {
return visitor.Visit(*this);
}

template <typename Derived, typename Result>
Result Accept(AuthLogicAstVisitor<Derived, Result, true>& visitor) const {
return visitor.Visit(*this);
}

private:
std::string name_;
Principal principal_;
Expand Down Expand Up @@ -191,6 +282,16 @@ class Program {

const std::vector<Query>& queries() const { return queries_; }

template <typename Derived, typename Result>
Result Accept(AuthLogicAstVisitor<Derived, Result, false>& visitor) {
return visitor.Visit(*this);
}

template <typename Derived, typename Result>
Result Accept(AuthLogicAstVisitor<Derived, Result, true>& visitor) const {
return visitor.Visit(*this);
}

private:
std::vector<datalog::RelationDeclaration> relation_declarations_;
std::vector<SaysAssertion> says_assertions_;
Expand Down
Loading

0 comments on commit b748dc6

Please sign in to comment.