package httprc import ( "bytes" "context" "fmt" "io/ioutil" "net/http" "sync" "time" "github.com/lestrrat-go/httpcc" ) // ErrSink is an abstraction that allows users to consume errors // produced while the cache queue is running. type ErrSink interface { // Error accepts errors produced during the cache queue's execution. // The method should never block, otherwise the fetch loop may be // paused for a prolonged amount of time. Error(error) } type ErrSinkFunc func(err error) func (f ErrSinkFunc) Error(err error) { f(err) } // Transformer is responsible for converting an HTTP response // into an appropriate form of your choosing. type Transformer interface { // Transform receives an HTTP response object, and should // return an appropriate object that suits your needs. // // If you happen to use the response body, you are responsible // for closing the body Transform(string, *http.Response) (interface{}, error) } type TransformFunc func(string, *http.Response) (interface{}, error) func (f TransformFunc) Transform(u string, res *http.Response) (interface{}, error) { return f(u, res) } // BodyBytes is the default Transformer applied to all resources. // It takes an *http.Response object and extracts the body // of the response as `[]byte` type BodyBytes struct{} func (BodyBytes) Transform(_ string, res *http.Response) (interface{}, error) { buf, err := ioutil.ReadAll(res.Body) defer res.Body.Close() if err != nil { return nil, fmt.Errorf(`failed to read response body: %w`, err) } return buf, nil } type rqentry struct { fireAt time.Time url string } // entry represents a resource to be fetched over HTTP, // long with optional specifications such as the *http.Client // object to use. type entry struct { mu sync.RWMutex sem chan struct{} lastFetch time.Time // Interval between refreshes are calculated two ways. // 1) You can set an explicit refresh interval by using WithRefreshInterval(). // In this mode, it doesn't matter what the HTTP response says in its // Cache-Control or Expires headers // 2) You can let us calculate the time-to-refresh based on the key's // Cache-Control or Expires headers. // First, the user provides us the absolute minimum interval before // refreshes. We will never check for refreshes before this specified // amount of time. // // Next, max-age directive in the Cache-Control header is consulted. // If `max-age` is not present, we skip the following section, and // proceed to the next option. // If `max-age > user-supplied minimum interval`, then we use the max-age, // otherwise the user-supplied minimum interval is used. // // Next, the value specified in Expires header is consulted. // If the header is not present, we skip the following seciont and // proceed to the next option. // We take the time until expiration `expires - time.Now()`, and // if `time-until-expiration > user-supplied minimum interval`, then // we use the expires value, otherwise the user-supplied minimum interval is used. // // If all of the above fails, we used the user-supplied minimum interval refreshInterval time.Duration minRefreshInterval time.Duration request *fetchRequest transform Transformer data interface{} } func (e *entry) acquireSem() { e.sem <- struct{}{} } func (e *entry) releaseSem() { <-e.sem } func (e *entry) hasBeenFetched() bool { e.mu.RLock() defer e.mu.RUnlock() return !e.lastFetch.IsZero() } // queue is responsible for updating the contents of the storage type queue struct { mu sync.RWMutex registry map[string]*entry windowSize time.Duration fetch Fetcher fetchCond *sync.Cond fetchQueue []*rqentry // list is a sorted list of urls to their expected fire time // when we get a new tick in the RQ loop, we process everything // that can be fired up to the point the tick was called list []*rqentry } func newQueue(ctx context.Context, window time.Duration, fetch Fetcher, errSink ErrSink) *queue { fetchLocker := &sync.Mutex{} rq := &queue{ windowSize: window, fetch: fetch, fetchCond: sync.NewCond(fetchLocker), registry: make(map[string]*entry), } go rq.refreshLoop(ctx, errSink) return rq } func (q *queue) Register(u string, options ...RegisterOption) error { var refreshInterval time.Duration var client HTTPClient var wl Whitelist var transform Transformer = BodyBytes{} minRefreshInterval := 15 * time.Minute for _, option := range options { //nolint:forcetypeassert switch option.Ident() { case identHTTPClient{}: client = option.Value().(HTTPClient) case identRefreshInterval{}: refreshInterval = option.Value().(time.Duration) case identMinRefreshInterval{}: minRefreshInterval = option.Value().(time.Duration) case identTransformer{}: transform = option.Value().(Transformer) case identWhitelist{}: wl = option.Value().(Whitelist) } } q.mu.RLock() rWindow := q.windowSize q.mu.RUnlock() if refreshInterval > 0 && refreshInterval < rWindow { return fmt.Errorf(`refresh interval (%s) is smaller than refresh window (%s): this will not as expected`, refreshInterval, rWindow) } e := entry{ sem: make(chan struct{}, 1), minRefreshInterval: minRefreshInterval, transform: transform, refreshInterval: refreshInterval, request: &fetchRequest{ client: client, url: u, wl: wl, }, } q.mu.Lock() q.registry[u] = &e q.mu.Unlock() return nil } func (q *queue) Unregister(u string) error { q.mu.Lock() defer q.mu.Unlock() _, ok := q.registry[u] if !ok { return fmt.Errorf(`url %q has not been registered`, u) } delete(q.registry, u) return nil } func (q *queue) getRegistered(u string) (*entry, bool) { q.mu.RLock() e, ok := q.registry[u] q.mu.RUnlock() return e, ok } func (q *queue) IsRegistered(u string) bool { _, ok := q.getRegistered(u) return ok } func (q *queue) fetchLoop(ctx context.Context, errSink ErrSink) { for { q.fetchCond.L.Lock() for len(q.fetchQueue) <= 0 { select { case <-ctx.Done(): return default: q.fetchCond.Wait() } } list := make([]*rqentry, len(q.fetchQueue)) copy(list, q.fetchQueue) q.fetchQueue = q.fetchQueue[:0] q.fetchCond.L.Unlock() for _, rq := range list { select { case <-ctx.Done(): return default: } e, ok := q.getRegistered(rq.url) if !ok { continue } if err := q.fetchAndStore(ctx, e); err != nil { if errSink != nil { errSink.Error(&RefreshError{ URL: rq.url, Err: err, }) } } } } } // This loop is responsible for periodically updating the cached content func (q *queue) refreshLoop(ctx context.Context, errSink ErrSink) { // Tick every q.windowSize duration. ticker := time.NewTicker(q.windowSize) go q.fetchLoop(ctx, errSink) defer q.fetchCond.Signal() for { select { case <-ctx.Done(): return case t := <-ticker.C: t = t.Round(time.Second) // To avoid getting stuck here, we just copy the relevant // items, and release the lock within this critical section var list []*rqentry q.mu.Lock() var max int for i, r := range q.list { if r.fireAt.Before(t) || r.fireAt.Equal(t) { max = i list = append(list, r) continue } break } if len(list) > 0 { q.list = q.list[max+1:] } q.mu.Unlock() // release lock if len(list) > 0 { // Now we need to fetch these, but do this elsewhere so // that we don't block this main loop q.fetchCond.L.Lock() q.fetchQueue = append(q.fetchQueue, list...) q.fetchCond.L.Unlock() q.fetchCond.Signal() } } } } func (q *queue) fetchAndStore(ctx context.Context, e *entry) error { e.mu.Lock() defer e.mu.Unlock() // synchronously go fetch e.lastFetch = time.Now() res, err := q.fetch.fetch(ctx, e.request) if err != nil { // Even if the request failed, we need to queue the next fetch q.enqueueNextFetch(nil, e) return fmt.Errorf(`failed to fetch %q: %w`, e.request.url, err) } q.enqueueNextFetch(res, e) data, err := e.transform.Transform(e.request.url, res) if err != nil { return fmt.Errorf(`failed to transform HTTP response for %q: %w`, e.request.url, err) } e.data = data return nil } func (q *queue) Enqueue(u string, interval time.Duration) error { fireAt := time.Now().Add(interval).Round(time.Second) q.mu.Lock() defer q.mu.Unlock() list := q.list ll := len(list) if ll == 0 || list[ll-1].fireAt.Before(fireAt) { list = append(list, &rqentry{ fireAt: fireAt, url: u, }) } else { for i := 0; i < ll; i++ { if i == ll-1 || list[i].fireAt.After(fireAt) { // insert here list = append(append(list[:i], &rqentry{fireAt: fireAt, url: u}), list[i:]...) break } } } q.list = list return nil } func (q *queue) MarshalJSON() ([]byte, error) { var buf bytes.Buffer buf.WriteString(`{"list":[`) q.mu.RLock() for i, e := range q.list { if i > 0 { buf.WriteByte(',') } fmt.Fprintf(&buf, `{"fire_at":%q,"url":%q}`, e.fireAt.Format(time.RFC3339), e.url) } q.mu.RUnlock() buf.WriteString(`]}`) return buf.Bytes(), nil } func (q *queue) enqueueNextFetch(res *http.Response, e *entry) { dur := calculateRefreshDuration(res, e) // TODO send to error sink _ = q.Enqueue(e.request.url, dur) } func calculateRefreshDuration(res *http.Response, e *entry) time.Duration { if e.refreshInterval > 0 { return e.refreshInterval } if res != nil { if v := res.Header.Get(`Cache-Control`); v != "" { dir, err := httpcc.ParseResponse(v) if err == nil { maxAge, ok := dir.MaxAge() if ok { resDuration := time.Duration(maxAge) * time.Second if resDuration > e.minRefreshInterval { return resDuration } return e.minRefreshInterval } // fallthrough } // fallthrough } if v := res.Header.Get(`Expires`); v != "" { expires, err := http.ParseTime(v) if err == nil { resDuration := time.Until(expires) if resDuration > e.minRefreshInterval { return resDuration } return e.minRefreshInterval } // fallthrough } } // Previous fallthroughs are a little redandunt, but hey, it's all good. return e.minRefreshInterval } type SnapshotEntry struct { URL string `json:"url"` Data interface{} `json:"data"` LastFetched time.Time `json:"last_fetched"` } type Snapshot struct { Entries []SnapshotEntry `json:"entries"` } // Snapshot returns the contents of the cache at the given moment. func (q *queue) snapshot() *Snapshot { q.mu.RLock() list := make([]SnapshotEntry, 0, len(q.registry)) for url, e := range q.registry { list = append(list, SnapshotEntry{ URL: url, LastFetched: e.lastFetch, Data: e.data, }) } q.mu.RUnlock() return &Snapshot{ Entries: list, } }