home/internal/middleware/redirect.go

112 lines
2.2 KiB
Go

package middleware
import (
"net/http"
"source.toby3d.me/toby3d/home/internal/domain"
"source.toby3d.me/toby3d/home/internal/server"
"source.toby3d.me/toby3d/home/internal/site"
"source.toby3d.me/toby3d/home/internal/urlutil"
)
// RedirectConfig configures the Redirect middleware.
type RedirectConfig struct {
	// Skipper reports whether a request bypasses the middleware entirely.
	// Redirect substitutes DefaultSkipper when it is nil.
	Skipper Skipper

	// Siter resolves the site for the language detected in the request
	// path. It is required: Redirect panics when it is nil.
	Siter site.UseCase

	// Serverer resolves the server settings of a site, which carry the
	// redirect rules. It is required: Redirect panics when it is nil.
	Serverer server.UseCase
}

// redirectResponse wraps the http.ResponseWriter handed to the downstream
// handler so that Redirect can inspect the outcome of the request afterwards
// and decide whether a non-forced redirect rule applies.
type redirectResponse struct {
	http.ResponseWriter

	// error records the error returned by the wrapped writer's Write.
	error error
	// statusCode records the status code passed to WriteHeader.
	statusCode int
}
// Redirect returns an Interceptor that applies the redirect rules configured
// for the site serving the request.
//
// If the first path segment is recognized by domain.NewLanguage, it selects
// the site language and the rules are matched against the remaining path.
// Otherwise domain.LanguageUnd and the full request path are used.
//
// When a rule matches:
//   - a forced rule redirects immediately, without calling next;
//   - a non-forced rule calls next first and redirects only when the
//     response was unsuccessful, that is, the body write failed or the
//     status is outside the 2xx-3xx range.
//
// Requests for which config.Skipper returns true, and requests matching no
// rule, are passed to next unchanged. An error from config.Siter or
// config.Serverer is answered with 500 Internal Server Error and the error
// text as the body.
//
// A nil config.Skipper defaults to DefaultSkipper. Redirect panics if
// config.Siter or config.Serverer is nil.
func Redirect(config RedirectConfig) Interceptor {
	if config.Skipper == nil {
		config.Skipper = DefaultSkipper
	}

	if config.Siter == nil {
		panic("middleware: redirect: Siter is nil")
	}

	if config.Serverer == nil {
		panic("middleware: redirect: Serverer is nil")
	}

	return func(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
		if config.Skipper(r) {
			next(w, r)

			return
		}

		// Use a leading language segment, if any, to pick the site, and match
		// rules against the path without it. NOTE(review): assumes
		// urlutil.ShiftPath splits off the first path segment (head) from the
		// rest of the path (tail) — confirm.
		lang, path := domain.LanguageUnd, r.URL.Path
		if head, tail := urlutil.ShiftPath(r.URL.Path); head != "" {
			if lang = domain.NewLanguage(head); lang != domain.LanguageUnd {
				path = tail
			}
		}

		// Named so as not to shadow the imported site and server packages.
		siteConfig, err := config.Siter.Do(r.Context(), lang)
		if err != nil {
			http.Error(w, err.Error(), http.StatusInternalServerError)

			return
		}

		serverConfig, err := config.Serverer.Do(r.Context(), *siteConfig)
		if err != nil {
			http.Error(w, err.Error(), http.StatusInternalServerError)

			return
		}

		redirect, ok := serverConfig.Redirects.Match(path)
		if !ok {
			next(w, r)

			return
		}

		// NOTE(toby3d): always redirect no matter what exists on
		// requested URL.
		if redirect.Force {
			http.Redirect(w, r, redirect.To, redirect.Status)

			return
		}

		tx := &redirectResponse{
			error:          nil,
			statusCode:     http.StatusOK,
			ResponseWriter: w,
		}

		next(tx, r)

		// NOTE(toby3d): redirect only if something bad on requested
		// URL.
		//
		// Any 2xx or 3xx status counts as success. The lower bound must be
		// inclusive: a plain 200 OK (also the implicit status when next never
		// calls WriteHeader) is a successful response, not a reason to
		// redirect.
		if tx.error == nil && http.StatusOK <= tx.statusCode && tx.statusCode < http.StatusBadRequest {
			return
		}

		// The redirect only reaches the client if tx has not already sent the
		// unsuccessful response's headers through to w.
		http.Redirect(w, r, redirect.To, redirect.Status)
	}
}
// WriteHeader records status as the response status.
//
// Statuses of 400 and above are held back instead of being sent to the
// wrapped ResponseWriter. The Redirect middleware answers such responses
// with a redirect, which can only take effect while no status or headers
// have been committed yet. All other statuses are forwarded unchanged.
func (r *redirectResponse) WriteHeader(status int) {
	r.statusCode = status

	if status >= http.StatusBadRequest {
		return
	}

	r.ResponseWriter.WriteHeader(status)
}

// Write forwards src to the wrapped ResponseWriter and records the error it
// returns.
//
// While an unsuccessful status is being held back by WriteHeader, the body is
// discarded instead. That response is replaced by a redirect, and writing the
// body would commit an implicit 200 OK. Write then reports all of src as
// written, so the handler does not see a spurious failure.
func (r *redirectResponse) Write(src []byte) (int, error) {
	if r.statusCode >= http.StatusBadRequest {
		return len(src), nil
	}

	var length int

	length, r.error = r.ResponseWriter.Write(src)

	return length, r.error
}