package jwk

import (
	"context"
	"net/http"
	"reflect"
	"sync"
	"time"

	"github.com/lestrrat-go/backoff/v2"
	"github.com/lestrrat-go/httpcc"
	"github.com/pkg/errors"
)

// AutoRefresh is a container that keeps track of jwk.Set objects by their source URLs.
// The jwk.Set objects are refreshed automatically behind the scenes.
//
// Before retrieving the jwk.Set objects, the user must pre-register the
// URLs they intend to use by calling `Configure()`
//
//  ar := jwk.NewAutoRefresh(ctx)
//  ar.Configure(url, options...)
//
// Once registered, you can call `Fetch()` to retrieve the jwk.Set object.
//
// All JWKS objects that are retrieved via the auto-fetch mechanism should be
// treated as read-only, as they are shared among all consumers and this object.
type AutoRefresh struct {
	errSink      chan AutoRefreshError
	cache        map[string]Set
	configureCh  chan struct{}
	removeCh     chan removeReq
	fetching     map[string]chan struct{}
	muErrSink    sync.Mutex
	muCache      sync.RWMutex
	muFetching   sync.Mutex
	muRegistry   sync.RWMutex
	registry     map[string]*target
	resetTimerCh chan *resetTimerReq
}

type target struct {
	// The backoff policy to use when fetching the JWKS fails
	backoff backoff.Policy

	// The HTTP client to use. The user may opt to use a client which is
	// aware of HTTP caching, or one that goes through a proxy
	httpcl HTTPClient

	// The interval between refreshes is calculated in one of two ways:
	// 1) You can set an explicit refresh interval by using WithRefreshInterval().
	//    In this mode, it doesn't matter what the HTTP response says in its
	//    Cache-Control or Expires headers
	// 2) You can let us calculate the time-to-refresh based on the key's
	//    Cache-Control or Expires headers.
	//    First, the user provides us the absolute minimum interval before
	//    refreshes. We will never check for refreshes before this specified
	//    amount of time.
	//
	//    Next, the max-age directive in the Cache-Control header is consulted.
	//    If `max-age` is not present, we skip the following section and
	//    proceed to the next option.
	//    If `max-age > user-supplied minimum interval`, then we use the max-age,
	//    otherwise the user-supplied minimum interval is used.
	//
	//    Next, the value specified in the Expires header is consulted.
	//    If the header is not present, we skip the following section and
	//    proceed to the next option.
	//    We take the time until expiration `expires - time.Now()`, and
	//    if `time-until-expiration > user-supplied minimum interval`, then
	//    we use the expires value, otherwise the user-supplied minimum interval is used.
	//
	// If all of the above fail, we use the user-supplied minimum interval.
	refreshInterval    *time.Duration
	minRefreshInterval time.Duration

	url string

	// The timer for refreshing the keyset. Should not be set by anyone
	// other than the refreshing goroutine
	timer *time.Timer

	// Semaphore to limit the number of concurrent refreshes in the background
	sem chan struct{}

	// for debugging, snapshotting
	lastRefresh time.Time
	nextRefresh time.Time

	wl           Whitelist
	parseOptions []ParseOption
}

type resetTimerReq struct {
	t *target
	d time.Duration
}

// NewAutoRefresh creates a container that keeps track of JWKS objects which
// are automatically refreshed.
//
// The context object in the argument controls the life-span of the
// auto-refresh worker. If you are using this in a long running process, this
// should mostly be set to a context that ends when the main loop/part of your
// program exits:
//
//  func MainLoop() {
//    ctx, cancel := context.WithCancel(context.Background())
//    defer cancel()
//    ar := jwk.NewAutoRefresh(ctx)
//    for ... {
//      ...
//    }
//  }
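//
// Within such a loop, the key set registered for a URL would typically be
// retrieved with Fetch. A minimal sketch (error handling is elided, and
// `url` is assumed to have been registered via Configure() beforehand):
//
//  keyset, _ := ar.Fetch(ctx, url)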
func NewAutoRefresh(ctx context.Context) *AutoRefresh {
	af := &AutoRefresh{
		cache:        make(map[string]Set),
		configureCh:  make(chan struct{}),
		removeCh:     make(chan removeReq),
		fetching:     make(map[string]chan struct{}),
		registry:     make(map[string]*target),
		resetTimerCh: make(chan *resetTimerReq),
	}
	go af.refreshLoop(ctx)
	return af
}

func (af *AutoRefresh) getCached(url string) (Set, bool) {
	af.muCache.RLock()
	ks, ok := af.cache[url]
	af.muCache.RUnlock()
	if ok {
		return ks, true
	}
	return nil, false
}

type removeReq struct {
	replyCh chan error
	url     string
}

// Remove removes `url` from the list of urls being watched by jwk.AutoRefresh.
// If the url is not already registered, an error is returned.
func (af *AutoRefresh) Remove(url string) error {
	ch := make(chan error)
	af.removeCh <- removeReq{replyCh: ch, url: url}
	return <-ch
}

// Configure registers the url to be controlled by AutoRefresh, and also
// sets any options associated with it.
//
// Note that options are treated as a whole -- you can't just update
// one value. For example, if you did:
//
//   ar.Configure(url, jwk.WithHTTPClient(...))
//   ar.Configure(url, jwk.WithRefreshInterval(...))
//
// the end result is that `url` is ONLY associated with the options
// given in the second call to `Configure()`, i.e. `jwk.WithRefreshInterval`.
// The other unspecified options, including the HTTP client, are reset to
// their default values.
//
// Configuration must propagate between goroutines, and therefore is
// not atomic (but changes should be felt "soon enough" for practical
// purposes).
func (af *AutoRefresh) Configure(url string, options ...AutoRefreshOption) {
	var httpcl HTTPClient = http.DefaultClient
	var hasRefreshInterval bool
	var refreshInterval time.Duration
	var wl Whitelist
	var parseOptions []ParseOption
	minRefreshInterval := time.Hour
	bo := backoff.Null()
	for _, option := range options {
		if v, ok := option.(ParseOption); ok {
			parseOptions = append(parseOptions, v)
			continue
		}

		//nolint:forcetypeassert
		switch option.Ident() {
		case identFetchBackoff{}:
			bo = option.Value().(backoff.Policy)
		case identRefreshInterval{}:
			refreshInterval = option.Value().(time.Duration)
			hasRefreshInterval = true
		case identMinRefreshInterval{}:
			minRefreshInterval = option.Value().(time.Duration)
		case identHTTPClient{}:
			httpcl = option.Value().(HTTPClient)
		case identFetchWhitelist{}:
			wl = option.Value().(Whitelist)
		}
	}

	af.muRegistry.Lock()
	t, ok := af.registry[url]
	if ok {
		if t.httpcl != httpcl {
			t.httpcl = httpcl
		}

		if t.minRefreshInterval != minRefreshInterval {
			t.minRefreshInterval = minRefreshInterval
		}

		if t.refreshInterval != nil {
			if !hasRefreshInterval {
				t.refreshInterval = nil
			} else if *t.refreshInterval != refreshInterval {
				*t.refreshInterval = refreshInterval
			}
		} else {
			if hasRefreshInterval {
				t.refreshInterval = &refreshInterval
			}
		}

		if t.wl != wl {
			t.wl = wl
		}

		t.parseOptions = parseOptions
	} else {
		t = &target{
			backoff:            bo,
			httpcl:             httpcl,
			minRefreshInterval: minRefreshInterval,
			url:                url,
			sem:                make(chan struct{}, 1),
			// This is a placeholder timer so we can call Reset() on it later.
			// Make it fire sufficiently far in the future so that we don't have
			// bogus events firing
			timer:        time.NewTimer(24 * time.Hour),
			wl:           wl,
			parseOptions: parseOptions,
		}
		if hasRefreshInterval {
			t.refreshInterval = &refreshInterval
		}

		// Record this in the registry
		af.registry[url] = t
	}
	af.muRegistry.Unlock()

	// Tell the backend to reconfigure itself
	af.configureCh <- struct{}{}
}
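
// For illustration, a URL that is no longer needed can later be dropped from
// the refresh loop via Remove (sketch only; `ar` and `url` are placeholders):
//
//   if err := ar.Remove(url); err != nil {
//       // url was never registered with Configure()
//   }
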
func (af *AutoRefresh) releaseFetching(url string) {
	// first delete the entry from the map, then close the channel, or
	// otherwise we may end up getting multiple goroutines doing the fetch
	af.muFetching.Lock()
	fetchingCh, ok := af.fetching[url]
	if !ok {
		// Juuuuuuust in case. But shouldn't happen
		af.muFetching.Unlock()
		return
	}
	delete(af.fetching, url)
	close(fetchingCh)
	af.muFetching.Unlock()
}

// IsRegistered checks if `url` is registered already.
func (af *AutoRefresh) IsRegistered(url string) bool {
	_, ok := af.getRegistered(url)
	return ok
}

// getRegistered returns the target registered for the given url, if any.
func (af *AutoRefresh) getRegistered(url string) (*target, bool) {
	af.muRegistry.RLock()
	t, ok := af.registry[url]
	af.muRegistry.RUnlock()
	return t, ok
}

// Fetch returns a jwk.Set from the given url.
//
// If it has previously been fetched, then a cached value is returned.
//
// If this is the first time `url` was requested, an HTTP request will be
// sent, synchronously.
//
// When accessed by multiple goroutines concurrently, and the cache
// has not been populated yet, only the first goroutine is
// allowed to perform the initialization (HTTP fetch and cache population).
// All other goroutines will be blocked until the operation is completed.
//
// DO NOT modify the jwk.Set object returned by this method, as the
// objects are shared among all consumers and the backend goroutine.
func (af *AutoRefresh) Fetch(ctx context.Context, url string) (Set, error) {
	if _, ok := af.getRegistered(url); !ok {
		return nil, errors.Errorf(`url %s must be configured using "Configure()" first`, url)
	}

	ks, found := af.getCached(url)
	if found {
		return ks, nil
	}

	return af.refresh(ctx, url)
}

// Refresh is the same as Fetch(), except that the HTTP fetch is always
// performed, even if a cached value exists.
//
// This is useful when you want to force an HTTP fetch instead of waiting
// for the background goroutine to do it, for example when you want to
// make sure the AutoRefresh cache is warmed up before starting your main loop.
func (af *AutoRefresh) Refresh(ctx context.Context, url string) (Set, error) {
	if _, ok := af.getRegistered(url); !ok {
		return nil, errors.Errorf(`url %s must be configured using "Configure()" first`, url)
	}

	return af.refresh(ctx, url)
}

func (af *AutoRefresh) refresh(ctx context.Context, url string) (Set, error) {
	// To avoid a thundering herd, only one goroutine per url may enter into this
	// initial fetch phase.
	af.muFetching.Lock()
	fetchingCh, fetching := af.fetching[url]
	// unlock happens in each of the if/else clauses because we need to perform
	// the channel initialization when there is no channel present
	if fetching {
		af.muFetching.Unlock()
		select {
		case <-ctx.Done():
			return nil, ctx.Err()
		case <-fetchingCh:
		}
	} else {
		fetchingCh = make(chan struct{})
		af.fetching[url] = fetchingCh
		af.muFetching.Unlock()

		// Register a cleanup handler, to make sure we always release the
		// fetching channel once we are done
		defer af.releaseFetching(url)

		// The first time around, we need to fetch the keyset
		if err := af.doRefreshRequest(ctx, url, false); err != nil {
			return nil, errors.Wrapf(err, `failed to fetch resource pointed by %s`, url)
		}
	}

	// the cache should now be populated
	ks, ok := af.getCached(url)
	if !ok {
		return nil, errors.New("cache was not populated after explicit refresh")
	}

	return ks, nil
}

// Keeps looping, while refreshing the KeySet.
func (af *AutoRefresh) refreshLoop(ctx context.Context) {
	// reflect.Select() is slow IF we are executing it over and over
	// in a very fast iteration, but we assume here that refreshes happen
	// seldom enough that being able to call one `select{}` with multiple
	// targets / channels outweighs the speed penalty of using reflect.
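	//
	// For intuition only, the dynamic reflect.Select below is the moral
	// equivalent of a static select over a variable number of cases
	// (the urls are placeholders):
	//
	//   select {
	//   case <-ctx.Done():
	//   case <-af.configureCh:
	//   case req := <-af.resetTimerCh:
	//   case req := <-af.removeCh:
	//   case <-af.registry["https://url-1"].timer.C:
	//   case <-af.registry["https://url-2"].timer.C:
	//   // ...one case per registered target
	//   }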
	const (
		ctxDoneIdx = iota
		configureIdx
		resetTimerIdx
		removeIdx
		baseSelcasesLen
	)

	baseSelcases := make([]reflect.SelectCase, baseSelcasesLen)
	baseSelcases[ctxDoneIdx] = reflect.SelectCase{
		Dir:  reflect.SelectRecv,
		Chan: reflect.ValueOf(ctx.Done()),
	}
	baseSelcases[configureIdx] = reflect.SelectCase{
		Dir:  reflect.SelectRecv,
		Chan: reflect.ValueOf(af.configureCh),
	}
	baseSelcases[resetTimerIdx] = reflect.SelectCase{
		Dir:  reflect.SelectRecv,
		Chan: reflect.ValueOf(af.resetTimerCh),
	}
	baseSelcases[removeIdx] = reflect.SelectCase{
		Dir:  reflect.SelectRecv,
		Chan: reflect.ValueOf(af.removeCh),
	}

	var targets []*target
	var selcases []reflect.SelectCase
	for {
		// It seems silly, but it's much easier to keep track of things
		// if we re-build the select cases every iteration
		af.muRegistry.RLock()
		if cap(targets) < len(af.registry) {
			targets = make([]*target, 0, len(af.registry))
		} else {
			targets = targets[:0]
		}

		if cap(selcases) < len(af.registry) {
			selcases = make([]reflect.SelectCase, 0, len(af.registry)+baseSelcasesLen)
		} else {
			selcases = selcases[:0]
		}
		selcases = append(selcases, baseSelcases...)

		for _, data := range af.registry {
			targets = append(targets, data)
			selcases = append(selcases, reflect.SelectCase{
				Dir:  reflect.SelectRecv,
				Chan: reflect.ValueOf(data.timer.C),
			})
		}
		af.muRegistry.RUnlock()

		chosen, recv, recvOK := reflect.Select(selcases)
		switch chosen {
		case ctxDoneIdx:
			// <-ctx.Done(). Just bail out of this loop
			return
		case configureIdx:
			// <-configureCh. rebuild the select list from the registry.
			// since we're rebuilding everything for each iteration,
			// we just need to start the loop all over again
			continue
		case resetTimerIdx:
			// <-resetTimerCh. interrupt polling, and reset the timer on
			// a single target. this needs to be handled inside this select
			if !recvOK {
				continue
			}

			req := recv.Interface().(*resetTimerReq) //nolint:forcetypeassert
			t := req.t
			d := req.d
			if !t.timer.Stop() {
				select {
				case <-t.timer.C:
				default:
				}
			}
			t.timer.Reset(d)
		case removeIdx:
			// <-removeCh. remove the URL from future fetching
			//nolint:forcetypeassert
			req := recv.Interface().(removeReq)
			replyCh := req.replyCh
			url := req.url
			af.muRegistry.Lock()
			if _, ok := af.registry[url]; !ok {
				replyCh <- errors.Errorf(`invalid url %q (not registered)`, url)
			} else {
				delete(af.registry, url)
				replyCh <- nil
			}
			af.muRegistry.Unlock()
		default:
			// Do not fire a refresh in case the channel was closed.
			if !recvOK {
				continue
			}

			// Time to refresh a target
			t := targets[chosen-baseSelcasesLen]

			// Check if there are other goroutines still doing the refresh asynchronously.
			// This could happen if the refreshing goroutine is stuck on a backoff
			// waiting for the HTTP request to complete.
			select {
			case t.sem <- struct{}{}:
				// There can only be one refreshing goroutine
			default:
				continue
			}

			go func() {
				//nolint:errcheck
				af.doRefreshRequest(ctx, t.url, true)
				<-t.sem
			}()
		}
	}
}

func (af *AutoRefresh) doRefreshRequest(ctx context.Context, url string, enableBackoff bool) error {
	af.muRegistry.RLock()
	t, ok := af.registry[url]
	if !ok {
		af.muRegistry.RUnlock()
		return errors.Errorf(`url "%s" is not registered`, url)
	}

	// In case the refresh fails due to errors in fetching/parsing the JWKS,
	// we want to retry, so the target's backoff policy is passed on to the
	// fetcher when backoff is enabled.
	parseOptions := t.parseOptions
	fetchOptions := []FetchOption{WithHTTPClient(t.httpcl)}
	if enableBackoff {
		fetchOptions = append(fetchOptions, WithFetchBackoff(t.backoff))
	}
	if t.wl != nil {
		fetchOptions = append(fetchOptions, WithFetchWhitelist(t.wl))
	}
	af.muRegistry.RUnlock()

	res, err := fetch(ctx, url, fetchOptions...)
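
	// fetch() performs the HTTP GET (using the target's backoff policy when
	// enabled) but does not parse the body; status-code and parse errors are
	// handled below so that a failed refresh never replaces the previously
	// cached key set.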
	if err == nil {
		if res.StatusCode != http.StatusOK {
			// now, can there be a remote resource that responds with a status code
			// other than 200 and still be valid...? naaaaaaahhhhhh....
			err = errors.Errorf(`bad response status code (%d)`, res.StatusCode)
		} else {
			defer res.Body.Close()
			keyset, parseErr := ParseReader(res.Body, parseOptions...)
			if parseErr == nil {
				// Got a new key set. replace the keyset in the target
				af.muCache.Lock()
				af.cache[url] = keyset
				af.muCache.Unlock()
				af.muRegistry.RLock()
				nextInterval := calculateRefreshDuration(res, t.refreshInterval, t.minRefreshInterval)
				af.muRegistry.RUnlock()
				rtr := &resetTimerReq{
					t: t,
					d: nextInterval,
				}
				select {
				case <-ctx.Done():
					return ctx.Err()
				case af.resetTimerCh <- rtr:
				}

				now := time.Now()
				af.muRegistry.Lock()
				t.lastRefresh = now.Local()
				t.nextRefresh = now.Add(nextInterval).Local()
				af.muRegistry.Unlock()

				return nil
			}
			err = parseErr
		}
	}

	// At this point if err != nil, we know that there was something wrong
	// in either the fetching or the parsing. Send this error to be processed,
	// but go the extra mile to not block regular processing by
	// discarding the error if we fail to send it through the channel
	if err != nil {
		select {
		case af.errSink <- AutoRefreshError{Error: err, URL: url}:
		default:
		}
	}

	// We either failed to perform the HTTP GET, or we failed to parse the
	// JWK set. Even in case of errors, we don't delete the old key set:
	// we keep it around, even if it may be stale, so that the user has
	// something to work with.
	// TODO: maybe this behavior should be customizable?

	// Even after a failed attempt, queue another fetch in the future.
	rtr := &resetTimerReq{
		t: t,
		d: calculateRefreshDuration(res, t.refreshInterval, t.minRefreshInterval),
	}
	select {
	case <-ctx.Done():
		return ctx.Err()
	case af.resetTimerCh <- rtr:
	}

	return err
}

// ErrorSink sets a channel to receive JWK fetch errors, if any.
// Only the errors that occurred *after* the channel was set will be sent.
//
// The user is responsible for properly draining the channel. If the channel
// is not drained properly, errors will be discarded.
//
// To disable, set a nil channel.
func (af *AutoRefresh) ErrorSink(ch chan AutoRefreshError) {
	af.muErrSink.Lock()
	af.errSink = ch
	af.muErrSink.Unlock()
}

func calculateRefreshDuration(res *http.Response, refreshInterval *time.Duration, minRefreshInterval time.Duration) time.Duration {
	// This always has precedence
	if refreshInterval != nil {
		return *refreshInterval
	}

	if res != nil {
		if v := res.Header.Get(`Cache-Control`); v != "" {
			dir, err := httpcc.ParseResponse(v)
			if err == nil {
				maxAge, ok := dir.MaxAge()
				if ok {
					resDuration := time.Duration(maxAge) * time.Second
					if resDuration > minRefreshInterval {
						return resDuration
					}
					return minRefreshInterval
				}
				// fallthrough
			}
			// fallthrough
		}

		if v := res.Header.Get(`Expires`); v != "" {
			expires, err := http.ParseTime(v)
			if err == nil {
				resDuration := time.Until(expires)
				if resDuration > minRefreshInterval {
					return resDuration
				}
				return minRefreshInterval
			}
			// fallthrough
		}
	}

	// The preceding fallthroughs are a little redundant, but hey, it's all good.
	return minRefreshInterval
}
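
// A worked example of the precedence implemented by calculateRefreshDuration
// above (all values are hypothetical; the minimum interval is assumed to be
// the one hour default):
//
//   WithRefreshInterval(15*time.Minute) set   -> 15m, headers are ignored
//   Cache-Control: max-age=7200               -> 2h (max-age > minimum)
//   Cache-Control: max-age=600                -> 1h (minimum wins)
//   Expires: <2 hours from now>, no max-age   -> ~2h
//   no usable headers                         -> 1h (minimum interval)
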
// TargetSnapshot is the structure returned by the Snapshot method.
// It contains information about a url that has been configured
// in AutoRefresh.
type TargetSnapshot struct {
	URL         string
	NextRefresh time.Time
	LastRefresh time.Time
}

// Snapshot returns a channel over which a TargetSnapshot is sent for each
// currently registered url, mainly for debugging purposes. The channel is
// closed once all entries have been written, so it can simply be ranged over.
func (af *AutoRefresh) Snapshot() <-chan TargetSnapshot {
	af.muRegistry.Lock()
	ch := make(chan TargetSnapshot, len(af.registry))
	for url, t := range af.registry {
		ch <- TargetSnapshot{
			URL:         url,
			NextRefresh: t.nextRefresh,
			LastRefresh: t.lastRefresh,
		}
	}
	af.muRegistry.Unlock()
	close(ch)
	return ch
}
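
// For illustration, Snapshot output can be drained with a simple range loop
// (sketch only; `ar` is a placeholder and the logging call is illustrative):
//
//   for target := range ar.Snapshot() {
//       fmt.Printf("%s: last %s, next %s\n", target.URL, target.LastRefresh, target.NextRefresh)
//   }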