Skip to content

Commit

Permalink
Merge pull request #20 from yaegashi/msauth-token-cache
Browse files Browse the repository at this point in the history
msauth: improvement for token cache store handling
  • Loading branch information
yaegashi authored Jun 27, 2020
2 parents 7f5a8a4 + aeb63b9 commit 93fc163
Show file tree
Hide file tree
Showing 4 changed files with 121 additions and 16 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ $ go generate ./gen
- [ ] Unit tests
- [x] CI
- [x] Persist OAuth2 tokens in file
- [ ] Persist OAuth2 tokens in object storage like Azure Blob
- [x] Persist OAuth2 tokens in object storage like Azure Blob
- [x] OAuth2 device auth grant
- [x] OAuth2 client credentials grant
- [x] Use string for EnumType (pointed out in #6)
Expand Down
5 changes: 3 additions & 2 deletions msauth/device_authorization_grant.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,10 @@ func (m *Manager) DeviceAuthorizationGrant(ctx context.Context, tenantID, client
Endpoint: endpoint,
Scopes: scopes,
}
if t, ok := m.TokenCache[generateKey(tenantID, clientID)]; ok {
if t, ok := m.GetToken(CacheKey(tenantID, clientID)); ok {
tt, err := config.TokenSource(ctx, t).Token()
if err == nil {
m.PutToken(CacheKey(tenantID, clientID), tt)
return config.TokenSource(ctx, tt), nil
}
if _, ok := err.(*oauth2.RetrieveError); !ok {
Expand Down Expand Up @@ -85,7 +86,7 @@ func (m *Manager) DeviceAuthorizationGrant(ctx context.Context, tenantID, client
time.Sleep(time.Second * time.Duration(interval))
token, err := m.requestToken(ctx, tenantID, clientID, values)
if err == nil {
m.Cache(tenantID, clientID, token)
m.PutToken(CacheKey(tenantID, clientID), token)
return config.TokenSource(ctx, token), nil
}
tokenError, ok := err.(*TokenError)
Expand Down
60 changes: 47 additions & 13 deletions msauth/msauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,6 @@ func (t *TokenError) Error() string {
return fmt.Sprintf("%s: %s", t.ErrorObject, t.ErrorDescription)
}

func generateKey(tenantID, clientID string) string {
return fmt.Sprintf("%s:%s", tenantID, clientID)
}

func deviceCodeURL(tenantID string) string {
return fmt.Sprintf(endpointURLFormat, tenantID, "devicecode")
}
Expand All @@ -65,6 +61,7 @@ func (e *tokenJSON) expiry() (t time.Time) {
// Manager is oauth2 token cache manager
type Manager struct {
mu sync.Mutex
Dirty bool
TokenCache map[string]*oauth2.Token
}

Expand All @@ -87,27 +84,64 @@ func (m *Manager) SaveBytes() ([]byte, error) {
return json.Marshal(m.TokenCache)
}

// LoadFile loads token cache from file
// LoadFile loads token cache from file with dirty state control
func (m *Manager) LoadFile(path string) error {
b, err := ioutil.ReadFile(path)
m.mu.Lock()
defer m.mu.Unlock()
b, err := ReadLocation(path)
if err != nil {
return err
}
err = json.Unmarshal(b, &m.TokenCache)
if err != nil {
return err
}
return m.LoadBytes(b)
m.Dirty = false
return nil
}

// SaveFile saves token cache to file
// SaveFile saves token cache to file with dirty state control
func (m *Manager) SaveFile(path string) error {
b, err := m.SaveBytes()
m.mu.Lock()
defer m.mu.Unlock()
if !m.Dirty {
return nil
}
b, err := json.Marshal(m.TokenCache)
if err != nil {
return err
}
err = WriteLocation(path, b, 0644)
if err != nil {
return err
}
return ioutil.WriteFile(path, b, 0644)
m.Dirty = false
return nil
}

// CacheKey generates a token cache key from tenantID/clientID
func CacheKey(tenantID, clientID string) string {
return fmt.Sprintf("%s:%s", tenantID, clientID)
}

// GetToken gets a token from token cache
func (m *Manager) GetToken(cacheKey string) (*oauth2.Token, bool) {
m.mu.Lock()
defer m.mu.Unlock()
token, ok := m.TokenCache[cacheKey]
return token, ok
}

// Cache stores a token into token cache
func (m *Manager) Cache(tenantID, clientID string, token *oauth2.Token) {
m.TokenCache[generateKey(tenantID, clientID)] = token
// PutToken puts a token into token cache
func (m *Manager) PutToken(cacheKey string, token *oauth2.Token) {
m.mu.Lock()
defer m.mu.Unlock()
oldToken, ok := m.TokenCache[cacheKey]
if ok && *oldToken == *token {
return
}
m.TokenCache[cacheKey] = token
m.Dirty = true
}

// requestToken requests a token from the token endpoint
Expand Down
70 changes: 70 additions & 0 deletions msauth/storage.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
package msauth

import (
"bytes"
"fmt"
"io/ioutil"
"net/http"
"net/url"
"os"
"strings"
)

// ReadLocation reads data from file with path or URL
func ReadLocation(loc string) ([]byte, error) {
u, err := url.Parse(loc)
if err != nil {
return nil, err
}
switch u.Scheme {
case "", "file":
return ioutil.ReadFile(u.Path)
case "http", "https":
res, err := http.Get(loc)
if err != nil {
return nil, err
}
defer res.Body.Close()
if res.StatusCode != http.StatusOK {
return nil, fmt.Errorf("%s", res.Status)
}
b, err := ioutil.ReadAll(res.Body)
if err != nil {
return nil, err
}
return b, nil
}
return nil, fmt.Errorf("Unsupported location to load: %s", loc)
}

// WriteLocation writes data to file with path or URL
func WriteLocation(loc string, b []byte, m os.FileMode) error {
u, err := url.Parse(loc)
if err != nil {
return err
}
switch u.Scheme {
case "", "file":
return ioutil.WriteFile(u.Path, b, m)
case "http", "https":
if strings.HasSuffix(u.Host, ".blob.core.windows.net") {
// Azure Blob Storage URL with SAS assumed here
cli := &http.Client{}
req, err := http.NewRequest(http.MethodPut, loc, bytes.NewBuffer(b))
if err != nil {
return err
}
req.Header.Set("x-ms-blob-type", "BlockBlob")
res, err := cli.Do(req)
if err != nil {
return err
}
defer res.Body.Close()
if res.StatusCode != http.StatusCreated {
return fmt.Errorf("%s", res.Status)
}
return nil
}
}
return fmt.Errorf("Unsupported location to save: %s", loc)
}

0 comments on commit 93fc163

Please sign in to comment.