Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ast+topdown: implement early exit semantics for complete document and function rules with one (ground) value #3898

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 43 additions & 19 deletions ast/index.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,11 @@ type RuleIndex interface {

// IndexResult contains the result of an index lookup.
type IndexResult struct {
Kind DocKind
Rules []*Rule
Else map[*Rule][]*Rule
Default *Rule
Kind DocKind
Rules []*Rule
Else map[*Rule][]*Rule
Default *Rule
EarlyExit bool
}

// NewIndexResult returns a new IndexResult object.
Expand Down Expand Up @@ -114,13 +115,11 @@ func (i *baseDocEqIndex) Build(rules []*Rule) bool {
// Insert rule into trie with (insertion order, priority order)
// tuple. Retaining the insertion order allows us to return rules
// in the order they were passed to this function.
node.rules = append(node.rules, &ruleNode{[...]int{idx, prio}, rule})
node.append([...]int{idx, prio}, rule)
prio++
return false
})

}

return true
}

Expand All @@ -143,6 +142,7 @@ func (i *baseDocEqIndex) Lookup(resolver ValueResolver) (*IndexResult, error) {
})
nodes := tr.unordered[pos]
root := nodes[0].rule

result.Rules = append(result.Rules, root)
if len(nodes) > 1 {
result.Else[root] = make([]*Rule, len(nodes)-1)
Expand All @@ -152,6 +152,8 @@ func (i *baseDocEqIndex) Lookup(resolver ValueResolver) (*IndexResult, error) {
}
}

result.EarlyExit = tr.values.Len() == 1 && tr.values.Slice()[0].IsGround()

return result, nil
}

Expand Down Expand Up @@ -181,6 +183,8 @@ func (i *baseDocEqIndex) AllRules(resolver ValueResolver) (*IndexResult, error)
}
}

result.EarlyExit = tr.values.Len() == 1 && tr.values.Slice()[0].IsGround()

return result, nil
}

Expand All @@ -190,9 +194,7 @@ type ruleWalker struct {

func (r *ruleWalker) Do(x interface{}) trieWalker {
tn := x.(*trieNode)
for _, rn := range tn.rules {
r.result.Add(rn)
}
r.result.Add(tn)
return r
}

Expand Down Expand Up @@ -397,25 +399,33 @@ type trieWalker interface {
type trieTraversalResult struct {
unordered map[int][]*ruleNode
ordering []int
values Set
}

func newTrieTraversalResult() *trieTraversalResult {
return &trieTraversalResult{
unordered: map[int][]*ruleNode{},
values: NewSet(),
}
}

func (tr *trieTraversalResult) Add(node *ruleNode) {
root := node.prio[0]
nodes, ok := tr.unordered[root]
if !ok {
tr.ordering = append(tr.ordering, root)
func (tr *trieTraversalResult) Add(t *trieNode) {
for _, node := range t.rules {
root := node.prio[0]
nodes, ok := tr.unordered[root]
if !ok {
tr.ordering = append(tr.ordering, root)
}
tr.unordered[root] = append(nodes, node)
}
if t.values != nil {
t.values.Foreach(func(v *Term) { tr.values.Add(v) })
}
tr.unordered[root] = append(nodes, node)
}

type trieNode struct {
ref Ref
values Set
mappers []*valueMapper
next *trieNode
any *trieNode
Expand Down Expand Up @@ -457,9 +467,25 @@ func (node *trieNode) String() string {
if len(node.mappers) > 0 {
flags = append(flags, fmt.Sprintf("%d mapper(s)", len(node.mappers)))
}
if l := node.values.Len(); l > 0 {
flags = append(flags, fmt.Sprintf("%d value(s)", l))
}
return strings.Join(flags, " ")
}

func (node *trieNode) append(prio [2]int, rule *Rule) {
node.rules = append(node.rules, &ruleNode{prio, rule})

if node.values != nil {
node.values.Add(rule.Head.Value)
return
}

if node.values == nil && rule.Head.DocKind() == CompleteDoc {
node.values = NewSet(rule.Head.Value)
}
}

type ruleNode struct {
prio [2]int
rule *Rule
Expand Down Expand Up @@ -513,9 +539,7 @@ func (node *trieNode) Traverse(resolver ValueResolver, tr *trieTraversalResult)
return nil
}

for i := range node.rules {
tr.Add(node.rules[i])
}
tr.Add(node)

return node.next.traverse(resolver, tr)
}
Expand Down
Loading