From ebbb4f7ff5f8e973c5fd57f120ce17a182aa7b1b Mon Sep 17 00:00:00 2001 From: YAEGASHI Takeshi Date: Sat, 27 Jun 2020 15:26:18 +0900 Subject: [PATCH 1/3] msauth: token cache dirty state control - CacheKey(): new function - Manager - Dirty: cache dirty state - GetToken/PutToken(): new methods (Cache method is removed) - LoadFile()/SaveFile(): load/save with dirty state control - DeviceAuthorizationGrant(): save updated token correctly --- msauth/device_authorization_grant.go | 5 ++- msauth/msauth.go | 58 ++++++++++++++++++++++------ 2 files changed, 49 insertions(+), 14 deletions(-) diff --git a/msauth/device_authorization_grant.go b/msauth/device_authorization_grant.go index 4baafd8d..9ff23b3c 100644 --- a/msauth/device_authorization_grant.go +++ b/msauth/device_authorization_grant.go @@ -39,9 +39,10 @@ func (m *Manager) DeviceAuthorizationGrant(ctx context.Context, tenantID, client Endpoint: endpoint, Scopes: scopes, } - if t, ok := m.TokenCache[generateKey(tenantID, clientID)]; ok { + if t, ok := m.GetToken(CacheKey(tenantID, clientID)); ok { tt, err := config.TokenSource(ctx, t).Token() if err == nil { + m.PutToken(CacheKey(tenantID, clientID), tt) return config.TokenSource(ctx, tt), nil } if _, ok := err.(*oauth2.RetrieveError); !ok { @@ -85,7 +86,7 @@ func (m *Manager) DeviceAuthorizationGrant(ctx context.Context, tenantID, client time.Sleep(time.Second * time.Duration(interval)) token, err := m.requestToken(ctx, tenantID, clientID, values) if err == nil { - m.Cache(tenantID, clientID, token) + m.PutToken(CacheKey(tenantID, clientID), token) return config.TokenSource(ctx, token), nil } tokenError, ok := err.(*TokenError) diff --git a/msauth/msauth.go b/msauth/msauth.go index c2a86985..3aedf453 100644 --- a/msauth/msauth.go +++ b/msauth/msauth.go @@ -36,10 +36,6 @@ func (t *TokenError) Error() string { return fmt.Sprintf("%s: %s", t.ErrorObject, t.ErrorDescription) } -func generateKey(tenantID, clientID string) string { - return fmt.Sprintf("%s:%s", tenantID, clientID) -} - func deviceCodeURL(tenantID string) string { return fmt.Sprintf(endpointURLFormat, tenantID, "devicecode") } @@ -65,6 +61,7 @@ func (e *tokenJSON) expiry() (t time.Time) { // Manager is oauth2 token cache manager type Manager struct { mu sync.Mutex + Dirty bool TokenCache map[string]*oauth2.Token } @@ -87,27 +84,64 @@ func (m *Manager) SaveBytes() ([]byte, error) { return json.Marshal(m.TokenCache) } -// LoadFile loads token cache from file +// LoadFile loads token cache from file with dirty state control func (m *Manager) LoadFile(path string) error { + m.mu.Lock() + defer m.mu.Unlock() b, err := ioutil.ReadFile(path) if err != nil { return err } - return m.LoadBytes(b) + err = json.Unmarshal(b, &m.TokenCache) + if err != nil { + return err + } + m.Dirty = false + return nil } -// SaveFile saves token cache to file +// SaveFile saves token cache to file with dirty state control func (m *Manager) SaveFile(path string) error { - b, err := m.SaveBytes() + m.mu.Lock() + defer m.mu.Unlock() + if !m.Dirty { + return nil + } + b, err := json.Marshal(m.TokenCache) + if err != nil { + return err + } + err = ioutil.WriteFile(path, b, 0644) if err != nil { return err } - return ioutil.WriteFile(path, b, 0644) + m.Dirty = false + return nil +} + +// CacheKey generates a token cache key from tenantID/clientID +func CacheKey(tenantID, clientID string) string { + return fmt.Sprintf("%s:%s", tenantID, clientID) +} + +// GetToken gets a token from token cache +func (m *Manager) GetToken(cacheKey string) (*oauth2.Token, bool) { + m.mu.Lock() + defer m.mu.Unlock() + token, ok := m.TokenCache[cacheKey] + return token, ok } -// Cache stores a token into token cache -func (m *Manager) Cache(tenantID, clientID string, token *oauth2.Token) { - m.TokenCache[generateKey(tenantID, clientID)] = token +// PutToken puts a token into token cache +func (m *Manager) PutToken(cacheKey string, token *oauth2.Token) { + m.mu.Lock() + defer m.mu.Unlock() + oldToken, ok := m.TokenCache[cacheKey] + if ok && *oldToken == *token { + return + } + m.TokenCache[cacheKey] = token + m.Dirty = true } // requestToken requests a token from the token endpoint From b10d073427f3e2c0b97e4b5b6fedd34f6166ac34 Mon Sep 17 00:00:00 2001 From: YAEGASHI Takeshi Date: Sat, 27 Jun 2020 16:00:34 +0900 Subject: [PATCH 2/3] msauth: support network locations for token cache store - ReadLocation()/WriteLocation(): new functions - Manager - LoadFile()/SaveFile()/SaveFileAlways(): updated --- msauth/msauth.go | 4 +-- msauth/storage.go | 70 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 72 insertions(+), 2 deletions(-) create mode 100644 msauth/storage.go diff --git a/msauth/msauth.go b/msauth/msauth.go index 3aedf453..014f2522 100644 --- a/msauth/msauth.go +++ b/msauth/msauth.go @@ -88,7 +88,7 @@ func (m *Manager) SaveBytes() ([]byte, error) { func (m *Manager) LoadFile(path string) error { m.mu.Lock() defer m.mu.Unlock() - b, err := ioutil.ReadFile(path) + b, err := ReadLocation(path) if err != nil { return err } @@ -111,7 +111,7 @@ func (m *Manager) SaveFile(path string) error { if err != nil { return err } - err = ioutil.WriteFile(path, b, 0644) + err = WriteLocation(path, b, 0644) if err != nil { return err } diff --git a/msauth/storage.go b/msauth/storage.go new file mode 100644 index 00000000..7d8db8ae --- /dev/null +++ b/msauth/storage.go @@ -0,0 +1,70 @@ +package msauth + +import ( + "bytes" + "fmt" + "io/ioutil" + "net/http" + "net/url" + "os" + "strings" +) + +// ReadLocation reads data from file with path or URL +func ReadLocation(loc string) ([]byte, error) { + u, err := url.Parse(loc) + if err != nil { + return nil, err + } + switch u.Scheme { + case "", "file": + return ioutil.ReadFile(u.Path) + case "http", "https": + res, err := http.Get(loc) + if err != nil { + return nil, err + } + defer res.Body.Close() + if res.StatusCode != http.StatusOK { + return nil, fmt.Errorf("%s", res.Status) + } + b, err := ioutil.ReadAll(res.Body) + if err != nil { + return nil, err + } + return b, nil + } + return nil, fmt.Errorf("Unsupported location to load: %s", loc) +} + +// WriteLocation writes data to file with path or URL +func WriteLocation(loc string, b []byte, m os.FileMode) error { + u, err := url.Parse(loc) + if err != nil { + return err + } + switch u.Scheme { + case "", "file": + return ioutil.WriteFile(u.Path, b, m) + case "http", "https": + if strings.HasSuffix(u.Host, ".blob.core.windows.net") { + // Azure Blob Storage URL with SAS assumed here + cli := &http.Client{} + req, err := http.NewRequest(http.MethodPut, loc, bytes.NewBuffer(b)) + if err != nil { + return err + } + req.Header.Set("x-ms-blob-type", "BlockBlob") + res, err := cli.Do(req) + if err != nil { + return err + } + defer res.Body.Close() + if res.StatusCode != http.StatusCreated { + return fmt.Errorf("%s", res.Status) + } + return nil + } + } + return fmt.Errorf("Unsupported location to save: %s", loc) +} From aeb63b9996dc19b6a4a2bd7c7cfbfd110762dd74 Mon Sep 17 00:00:00 2001 From: YAEGASHI Takeshi Date: Sat, 27 Jun 2020 16:06:27 +0900 Subject: [PATCH 3/3] Update README.md --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index d389619c..b747694b 100644 --- a/README.md +++ b/README.md @@ -95,7 +95,7 @@ $ go generate ./gen - [ ] Unit tests - [x] CI - [x] Persist OAuth2 tokens in file -- [ ] Persist OAuth2 tokens in object storage like Azure Blob +- [x] Persist OAuth2 tokens in object storage like Azure Blob - [x] OAuth2 device auth grant - [x] OAuth2 client credentials grant - [x] Use string for EnumType (pointed out in #6)