From 951da251e3e862f2d0a1e5076c028a481f0235dd Mon Sep 17 00:00:00 2001
From: arekkas <aeneas@ory.am>
Date: Fri, 15 Jun 2018 16:25:31 +0200
Subject: [PATCH] rules: Resolves an issue with cached matchers

This patch resolves an issue where updates would not properly propagate. This caused deleted rules to still be available in the proxy.

Closes #73
---
 rule/matcher_cached.go      |  9 +++++++
 rule/matcher_cached_http.go |  8 +++++++
 rule/matcher_test.go        | 48 +++++++++++++++++++++++++++----------
 3 files changed, 52 insertions(+), 13 deletions(-)

diff --git a/rule/matcher_cached.go b/rule/matcher_cached.go
index b8609be4e1..aa05d03e79 100644
--- a/rule/matcher_cached.go
+++ b/rule/matcher_cached.go
@@ -71,8 +71,17 @@ func (m *CachedMatcher) Refresh() error {
 		return errors.WithStack(err)
 	}
 
+	inserted := map[string]bool{}
 	for _, rule := range rules {
+		inserted[rule.ID] = true
 		m.Rules[rule.ID] = rule
 	}
+
+	for _, rule := range m.Rules {
+		if _, ok := inserted[rule.ID]; !ok {
+			delete(m.Rules, rule.ID)
+		}
+	}
+
 	return nil
 }
diff --git a/rule/matcher_cached_http.go b/rule/matcher_cached_http.go
index dc651c41c4..5313e149ce 100644
--- a/rule/matcher_cached_http.go
+++ b/rule/matcher_cached_http.go
@@ -51,6 +51,7 @@ func (m *HTTPMatcher) Refresh() error {
 		return errors.Errorf("Unable to fetch rules from backend, got status code %d but expected %s", response.StatusCode, http.StatusOK)
 	}
 
+	inserted := map[string]bool{}
 	for _, r := range rules {
 		if len(r.Match.Methods) == 0 {
 			r.Match.Methods = []string{}
@@ -64,6 +65,7 @@ func (m *HTTPMatcher) Refresh() error {
 			}
 		}
 
+		inserted[r.Id] = true
 		m.Rules[r.Id] = Rule{
 			ID:          r.Id,
 			Description: r.Description,
@@ -85,5 +87,11 @@ func (m *HTTPMatcher) Refresh() error {
 		}
 	}
 
+	for _, rule := range m.Rules {
+		if _, ok := inserted[rule.ID]; !ok {
+			delete(m.Rules, rule.ID)
+		}
+	}
+
 	return nil
 }
diff --git a/rule/matcher_test.go b/rule/matcher_test.go
index cee760b5bf..82bde2cc6a 100644
--- a/rule/matcher_test.go
+++ b/rule/matcher_test.go
@@ -100,29 +100,51 @@ func TestMatcher(t *testing.T) {
 	handler.SetRoutes(router)
 	server := httptest.NewServer(router)
 
-	for _, tr := range testRules {
-		require.NoError(t, manager.CreateRule(&tr))
-	}
-
 	matchers := map[string]Matcher{
 		"memory": NewCachedMatcher(manager),
 		"http":   NewHTTPMatcher(oathkeeper.NewSDK(server.URL)),
 	}
 
+	var testMatcher = func(t *testing.T, matcher Matcher, method string, url string, expectErr bool, expect *Rule) {
+		r, err := matcher.MatchRule(method, mustParseURL(t, url))
+		if expectErr {
+			require.Error(t, err)
+		} else {
+			require.NoError(t, err)
+			assert.EqualValues(t, *expect, *r)
+		}
+	}
+
 	for name, matcher := range matchers {
-		t.Run("matcher="+name, func(t *testing.T) {
+		t.Run("matcher="+name+"/case=empty", func(t *testing.T) {
 			require.NoError(t, matcher.Refresh())
+			testMatcher(t, matcher, "GET", "https://localhost:34/baz", true, nil)
+			testMatcher(t, matcher, "POST", "https://localhost:1234/foo", true, nil)
+			testMatcher(t, matcher, "DELETE", "https://localhost:1234/foo", true, nil)
+		})
+	}
 
-			r, err := matcher.MatchRule("GET", mustParseURL(t, "https://localhost:34/baz"))
-			require.NoError(t, err)
-			assert.EqualValues(t, testRules[1], *r)
+	for _, tr := range testRules {
+		require.NoError(t, manager.CreateRule(&tr))
+	}
 
-			r, err = matcher.MatchRule("POST", mustParseURL(t, "https://localhost:1234/foo"))
-			require.NoError(t, err)
-			assert.EqualValues(t, testRules[0], *r)
+	for name, matcher := range matchers {
+		t.Run("matcher="+name+"/case=created", func(t *testing.T) {
+			require.NoError(t, matcher.Refresh())
+			testMatcher(t, matcher, "GET", "https://localhost:34/baz", false, &testRules[1])
+			testMatcher(t, matcher, "POST", "https://localhost:1234/foo", false, &testRules[0])
+			testMatcher(t, matcher, "DELETE", "https://localhost:1234/foo", true, nil)
+		})
+	}
 
-			r, err = matcher.MatchRule("DELETE", mustParseURL(t, "https://localhost:1234/foo"))
-			require.Error(t, err)
+	require.NoError(t, manager.DeleteRule(testRules[0].ID))
+
+	for name, matcher := range matchers {
+		t.Run("matcher="+name+"/case=updated", func(t *testing.T) {
+			require.NoError(t, matcher.Refresh())
+			testMatcher(t, matcher, "GET", "https://localhost:34/baz", false, &testRules[1])
+			testMatcher(t, matcher, "POST", "https://localhost:1234/foo", true, nil)
+			testMatcher(t, matcher, "DELETE", "https://localhost:1234/foo", true, nil)
 		})
 	}
 }