diff --git a/README.md b/README.md index d389619c..b747694b 100644 --- a/README.md +++ b/README.md @@ -95,7 +95,7 @@ $ go generate ./gen - [ ] Unit tests - [x] CI - [x] Persist OAuth2 tokens in file -- [ ] Persist OAuth2 tokens in object storage like Azure Blob +- [x] Persist OAuth2 tokens in object storage like Azure Blob - [x] OAuth2 device auth grant - [x] OAuth2 client credentials grant - [x] Use string for EnumType (pointed out in #6) diff --git a/msauth/device_authorization_grant.go b/msauth/device_authorization_grant.go index 4baafd8d..9ff23b3c 100644 --- a/msauth/device_authorization_grant.go +++ b/msauth/device_authorization_grant.go @@ -39,9 +39,10 @@ func (m *Manager) DeviceAuthorizationGrant(ctx context.Context, tenantID, client Endpoint: endpoint, Scopes: scopes, } - if t, ok := m.TokenCache[generateKey(tenantID, clientID)]; ok { + if t, ok := m.GetToken(CacheKey(tenantID, clientID)); ok { tt, err := config.TokenSource(ctx, t).Token() if err == nil { + m.PutToken(CacheKey(tenantID, clientID), tt) return config.TokenSource(ctx, tt), nil } if _, ok := err.(*oauth2.RetrieveError); !ok { @@ -85,7 +86,7 @@ func (m *Manager) DeviceAuthorizationGrant(ctx context.Context, tenantID, client time.Sleep(time.Second * time.Duration(interval)) token, err := m.requestToken(ctx, tenantID, clientID, values) if err == nil { - m.Cache(tenantID, clientID, token) + m.PutToken(CacheKey(tenantID, clientID), token) return config.TokenSource(ctx, token), nil } tokenError, ok := err.(*TokenError) diff --git a/msauth/msauth.go b/msauth/msauth.go index c2a86985..014f2522 100644 --- a/msauth/msauth.go +++ b/msauth/msauth.go @@ -36,10 +36,6 @@ func (t *TokenError) Error() string { return fmt.Sprintf("%s: %s", t.ErrorObject, t.ErrorDescription) } -func generateKey(tenantID, clientID string) string { - return fmt.Sprintf("%s:%s", tenantID, clientID) -} - func deviceCodeURL(tenantID string) string { return fmt.Sprintf(endpointURLFormat, tenantID, "devicecode") } @@ -65,6 +61,7 @@ func (e *tokenJSON) expiry() (t time.Time) { // Manager is oauth2 token cache manager type Manager struct { mu sync.Mutex + Dirty bool TokenCache map[string]*oauth2.Token } @@ -87,27 +84,64 @@ func (m *Manager) SaveBytes() ([]byte, error) { return json.Marshal(m.TokenCache) } -// LoadFile loads token cache from file +// LoadFile loads token cache from file with dirty state control func (m *Manager) LoadFile(path string) error { - b, err := ioutil.ReadFile(path) + m.mu.Lock() + defer m.mu.Unlock() + b, err := ReadLocation(path) + if err != nil { + return err + } + err = json.Unmarshal(b, &m.TokenCache) if err != nil { return err } - return m.LoadBytes(b) + m.Dirty = false + return nil } -// SaveFile saves token cache to file +// SaveFile saves token cache to file with dirty state control func (m *Manager) SaveFile(path string) error { - b, err := m.SaveBytes() + m.mu.Lock() + defer m.mu.Unlock() + if !m.Dirty { + return nil + } + b, err := json.Marshal(m.TokenCache) + if err != nil { + return err + } + err = WriteLocation(path, b, 0644) if err != nil { return err } - return ioutil.WriteFile(path, b, 0644) + m.Dirty = false + return nil +} + +// CacheKey generates a token cache key from tenantID/clientID +func CacheKey(tenantID, clientID string) string { + return fmt.Sprintf("%s:%s", tenantID, clientID) +} + +// GetToken gets a token from token cache +func (m *Manager) GetToken(cacheKey string) (*oauth2.Token, bool) { + m.mu.Lock() + defer m.mu.Unlock() + token, ok := m.TokenCache[cacheKey] + return token, ok } -// Cache stores a token into token cache -func (m *Manager) Cache(tenantID, clientID string, token *oauth2.Token) { - m.TokenCache[generateKey(tenantID, clientID)] = token +// PutToken puts a token into token cache +func (m *Manager) PutToken(cacheKey string, token *oauth2.Token) { + m.mu.Lock() + defer m.mu.Unlock() + oldToken, ok := m.TokenCache[cacheKey] + if ok && *oldToken == *token { + return + } + m.TokenCache[cacheKey] = token + m.Dirty = true } // requestToken requests a token from the token endpoint diff --git a/msauth/storage.go b/msauth/storage.go new file mode 100644 index 00000000..7d8db8ae --- /dev/null +++ b/msauth/storage.go @@ -0,0 +1,70 @@ +package msauth + +import ( + "bytes" + "fmt" + "io/ioutil" + "net/http" + "net/url" + "os" + "strings" +) + +// ReadLocation reads data from file with path or URL +func ReadLocation(loc string) ([]byte, error) { + u, err := url.Parse(loc) + if err != nil { + return nil, err + } + switch u.Scheme { + case "", "file": + return ioutil.ReadFile(u.Path) + case "http", "https": + res, err := http.Get(loc) + if err != nil { + return nil, err + } + defer res.Body.Close() + if res.StatusCode != http.StatusOK { + return nil, fmt.Errorf("%s", res.Status) + } + b, err := ioutil.ReadAll(res.Body) + if err != nil { + return nil, err + } + return b, nil + } + return nil, fmt.Errorf("Unsupported location to load: %s", loc) +} + +// WriteLocation writes data to file with path or URL +func WriteLocation(loc string, b []byte, m os.FileMode) error { + u, err := url.Parse(loc) + if err != nil { + return err + } + switch u.Scheme { + case "", "file": + return ioutil.WriteFile(u.Path, b, m) + case "http", "https": + if strings.HasSuffix(u.Host, ".blob.core.windows.net") { + // Azure Blob Storage URL with SAS assumed here + cli := &http.Client{} + req, err := http.NewRequest(http.MethodPut, loc, bytes.NewBuffer(b)) + if err != nil { + return err + } + req.Header.Set("x-ms-blob-type", "BlockBlob") + res, err := cli.Do(req) + if err != nil { + return err + } + defer res.Body.Close() + if res.StatusCode != http.StatusCreated { + return fmt.Errorf("%s", res.Status) + } + return nil + } + } + return fmt.Errorf("Unsupported location to save: %s", loc) +}