Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

msauth: improvement for token cache store handling #20

Merged
merged 3 commits into from
Jun 27, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ $ go generate ./gen
- [ ] Unit tests
- [x] CI
- [x] Persist OAuth2 tokens in file
- [ ] Persist OAuth2 tokens in object storage like Azure Blob
- [x] Persist OAuth2 tokens in object storage like Azure Blob
- [x] OAuth2 device auth grant
- [x] OAuth2 client credentials grant
- [x] Use string for EnumType (pointed out in #6)
Expand Down
5 changes: 3 additions & 2 deletions msauth/device_authorization_grant.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,10 @@ func (m *Manager) DeviceAuthorizationGrant(ctx context.Context, tenantID, client
Endpoint: endpoint,
Scopes: scopes,
}
if t, ok := m.TokenCache[generateKey(tenantID, clientID)]; ok {
if t, ok := m.GetToken(CacheKey(tenantID, clientID)); ok {
tt, err := config.TokenSource(ctx, t).Token()
if err == nil {
m.PutToken(CacheKey(tenantID, clientID), tt)
return config.TokenSource(ctx, tt), nil
}
if _, ok := err.(*oauth2.RetrieveError); !ok {
Expand Down Expand Up @@ -85,7 +86,7 @@ func (m *Manager) DeviceAuthorizationGrant(ctx context.Context, tenantID, client
time.Sleep(time.Second * time.Duration(interval))
token, err := m.requestToken(ctx, tenantID, clientID, values)
if err == nil {
m.Cache(tenantID, clientID, token)
m.PutToken(CacheKey(tenantID, clientID), token)
return config.TokenSource(ctx, token), nil
}
tokenError, ok := err.(*TokenError)
Expand Down
60 changes: 47 additions & 13 deletions msauth/msauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,6 @@ func (t *TokenError) Error() string {
return fmt.Sprintf("%s: %s", t.ErrorObject, t.ErrorDescription)
}

func generateKey(tenantID, clientID string) string {
return fmt.Sprintf("%s:%s", tenantID, clientID)
}

func deviceCodeURL(tenantID string) string {
return fmt.Sprintf(endpointURLFormat, tenantID, "devicecode")
}
Expand All @@ -65,6 +61,7 @@ func (e *tokenJSON) expiry() (t time.Time) {
// Manager is oauth2 token cache manager
type Manager struct {
mu sync.Mutex
Dirty bool
TokenCache map[string]*oauth2.Token
}

Expand All @@ -87,27 +84,64 @@ func (m *Manager) SaveBytes() ([]byte, error) {
return json.Marshal(m.TokenCache)
}

// LoadFile loads token cache from file
// LoadFile loads token cache from file with dirty state control
func (m *Manager) LoadFile(path string) error {
b, err := ioutil.ReadFile(path)
m.mu.Lock()
defer m.mu.Unlock()
b, err := ReadLocation(path)
if err != nil {
return err
}
err = json.Unmarshal(b, &m.TokenCache)
if err != nil {
return err
}
return m.LoadBytes(b)
m.Dirty = false
return nil
}

// SaveFile saves token cache to file
// SaveFile saves token cache to file with dirty state control
func (m *Manager) SaveFile(path string) error {
b, err := m.SaveBytes()
m.mu.Lock()
defer m.mu.Unlock()
if !m.Dirty {
return nil
}
b, err := json.Marshal(m.TokenCache)
if err != nil {
return err
}
err = WriteLocation(path, b, 0644)
if err != nil {
return err
}
return ioutil.WriteFile(path, b, 0644)
m.Dirty = false
return nil
}

// CacheKey generates a token cache key from tenantID/clientID
func CacheKey(tenantID, clientID string) string {
return fmt.Sprintf("%s:%s", tenantID, clientID)
}

// GetToken gets a token from token cache
func (m *Manager) GetToken(cacheKey string) (*oauth2.Token, bool) {
m.mu.Lock()
defer m.mu.Unlock()
token, ok := m.TokenCache[cacheKey]
return token, ok
}

// Cache stores a token into token cache
func (m *Manager) Cache(tenantID, clientID string, token *oauth2.Token) {
m.TokenCache[generateKey(tenantID, clientID)] = token
// PutToken puts a token into token cache
func (m *Manager) PutToken(cacheKey string, token *oauth2.Token) {
m.mu.Lock()
defer m.mu.Unlock()
oldToken, ok := m.TokenCache[cacheKey]
if ok && *oldToken == *token {
return
}
m.TokenCache[cacheKey] = token
m.Dirty = true
}

// requestToken requests a token from the token endpoint
Expand Down
70 changes: 70 additions & 0 deletions msauth/storage.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
package msauth

import (
"bytes"
"fmt"
"io/ioutil"
"net/http"
"net/url"
"os"
"strings"
)

// ReadLocation reads data from file with path or URL
func ReadLocation(loc string) ([]byte, error) {
u, err := url.Parse(loc)
if err != nil {
return nil, err
}
switch u.Scheme {
case "", "file":
return ioutil.ReadFile(u.Path)
case "http", "https":
res, err := http.Get(loc)
if err != nil {
return nil, err
}
defer res.Body.Close()
if res.StatusCode != http.StatusOK {
return nil, fmt.Errorf("%s", res.Status)
}
b, err := ioutil.ReadAll(res.Body)
if err != nil {
return nil, err
}
return b, nil
}
return nil, fmt.Errorf("Unsupported location to load: %s", loc)
}

// WriteLocation writes data to file with path or URL
func WriteLocation(loc string, b []byte, m os.FileMode) error {
u, err := url.Parse(loc)
if err != nil {
return err
}
switch u.Scheme {
case "", "file":
return ioutil.WriteFile(u.Path, b, m)
case "http", "https":
if strings.HasSuffix(u.Host, ".blob.core.windows.net") {
// Azure Blob Storage URL with SAS assumed here
cli := &http.Client{}
req, err := http.NewRequest(http.MethodPut, loc, bytes.NewBuffer(b))
if err != nil {
return err
}
req.Header.Set("x-ms-blob-type", "BlockBlob")
res, err := cli.Do(req)
if err != nil {
return err
}
defer res.Body.Close()
if res.StatusCode != http.StatusCreated {
return fmt.Errorf("%s", res.Status)
}
return nil
}
}
return fmt.Errorf("Unsupported location to save: %s", loc)
}