diff --git a/internal/cmd/home/home.go b/internal/cmd/home/home.go index c25e8fc..98ba2c5 100644 --- a/internal/cmd/home/home.go +++ b/internal/cmd/home/home.go @@ -57,7 +57,7 @@ func NewApp(logger *log.Logger, config *domain.Config) (*App, error) { themer := themeucase.NewThemeUseCase(partialsDir, themes) pages := pagefsrepo.NewFileSystemPageRepository(contentDir) pager := pageucase.NewPageUseCase(pages, resources) - server := servercase.NewServerUseCase() + serverer := servercase.NewServerUseCase() handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // INFO(toby3d): any static file is public and unprotected by design, so it's safe to search it // first before deep down to any page or it's resource which might be secured by middleware or @@ -77,7 +77,7 @@ func NewApp(logger *log.Logger, config *domain.Config) (*App, error) { return } - siteServer, err := server.Do(r.Context(), *s) + siteServer, err := serverer.Do(r.Context(), *s) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) @@ -94,26 +94,8 @@ func NewApp(logger *log.Logger, config *domain.Config) (*App, error) { } } - var redirect *domain.Redirect - for i := range siteServer.Redirects { - if !siteServer.Redirects[i].IsMatch(r.URL.Path) { - continue - } - - if siteServer.Redirects[i].Force { - http.Redirect(w, r, siteServer.Redirects[i].To, siteServer.Redirects[i].Status) - - return - } - - redirect = &siteServer.Redirects[i] - - break - } - if s.IsMultiLingual() { head, tail := urlutil.ShiftPath(r.URL.Path) - if head == "" { supported := make([]language.Tag, len(s.Languages)) for i := range s.Languages { @@ -125,8 +107,8 @@ func NewApp(logger *log.Logger, config *domain.Config) (*App, error) { supported...) } - requested, _, err := language.ParseAcceptLanguage( - r.Header.Get(common.HeaderAcceptLanguage)) + requested, _, err := language.ParseAcceptLanguage(r.Header.Get( + common.HeaderAcceptLanguage)) if err != nil || len(requested) == 0 { requested = append(requested, language.English) } @@ -139,14 +121,9 @@ func NewApp(logger *log.Logger, config *domain.Config) (*App, error) { return } - lang = domain.NewLanguage(head) - r.URL.Path = tail - } - - if lang == domain.LanguageUnd && redirect != nil { - http.Redirect(w, r, redirect.To, redirect.Status) - - return + if lang = domain.NewLanguage(head); lang != domain.LanguageUnd { + r.URL.Path = tail + } } if s, err = siter.Do(r.Context(), lang); err != nil { @@ -155,7 +132,7 @@ func NewApp(logger *log.Logger, config *domain.Config) (*App, error) { return } - if siteServer, err = server.Do(r.Context(), *s); err != nil { + if siteServer, err = serverer.Do(r.Context(), *s); err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return @@ -171,22 +148,6 @@ func NewApp(logger *log.Logger, config *domain.Config) (*App, error) { } } - for i := range siteServer.Redirects { - if !siteServer.Redirects[i].IsMatch(r.URL.Path) { - continue - } - - if siteServer.Redirects[i].Force { - http.Redirect(w, r, siteServer.Redirects[i].To, siteServer.Redirects[i].Status) - - return - } - - redirect = &siteServer.Redirects[i] - - break - } - p, err := pager.Do(r.Context(), lang, r.URL.Path) if err != nil { if !errors.Is(err, page.ErrNotExist) { @@ -195,12 +156,6 @@ func NewApp(logger *log.Logger, config *domain.Config) (*App, error) { return } - if redirect != nil { - http.Redirect(w, r, redirect.To, redirect.Status) - - return - } - res, err := resourcer.Do(r.Context(), r.URL.Path) if err != nil { if errors.Is(err, fs.ErrNotExist) { @@ -255,7 +210,13 @@ func NewApp(logger *log.Logger, config *domain.Config) (*App, error) { http.Error(w, err.Error(), http.StatusInternalServerError) } }) - chain := middleware.Chain{middleware.LogFmt()} + chain := middleware.Chain{ + middleware.LogFmt(), + middleware.Redirect(middleware.RedirectConfig{ + Siter: siter, + Serverer: serverer, + }), + } return &App{server: &http.Server{ Addr: config.AddrPort().String(), diff --git a/internal/domain/redirects.go b/internal/domain/redirects.go new file mode 100644 index 0000000..bef891f --- /dev/null +++ b/internal/domain/redirects.go @@ -0,0 +1,15 @@ +package domain + +type Redirects []Redirect + +func (r Redirects) Match(p string) (*Redirect, bool) { + for i := range r { + if !r[i].IsMatch(p) { + continue + } + + return &r[i], true + } + + return nil, false +} diff --git a/internal/domain/server.go b/internal/domain/server.go index 60d0bf1..d10004d 100644 --- a/internal/domain/server.go +++ b/internal/domain/server.go @@ -2,12 +2,12 @@ package domain type Server struct { Headers []Header - Redirects []Redirect + Redirects Redirects } func NewServer() *Server { return &Server{ Headers: make([]Header, 0), - Redirects: make([]Redirect, 0), + Redirects: make(Redirects, 0), } } diff --git a/internal/middleware/redirect.go b/internal/middleware/redirect.go index 5db5791..79be6b4 100644 --- a/internal/middleware/redirect.go +++ b/internal/middleware/redirect.go @@ -2,25 +2,38 @@ package middleware import ( "net/http" - "net/url" + + "source.toby3d.me/toby3d/home/internal/domain" + "source.toby3d.me/toby3d/home/internal/server" + "source.toby3d.me/toby3d/home/internal/site" + "source.toby3d.me/toby3d/home/internal/urlutil" ) type ( RedirectConfig struct { - Skipper Skipper - Code int + Skipper Skipper + Siter site.UseCase + Serverer server.UseCase } - redirectLogic func(u *url.URL) (url string, ok bool) + redirectResponse struct { + http.ResponseWriter + error error + statusCode int + } ) -func Redirect(config RedirectConfig, redirect redirectLogic) Interceptor { +func Redirect(config RedirectConfig) Interceptor { if config.Skipper == nil { config.Skipper = DefaultSkipper } - if config.Code == 0 { - config.Code = http.StatusMovedPermanently + if config.Siter == nil { + panic("middleware: redirect: Siter is nil") + } + + if config.Serverer == nil { + panic("middleware: redirect: Serverer is nil") } return func(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) { @@ -30,20 +43,69 @@ func Redirect(config RedirectConfig, redirect redirectLogic) Interceptor { return } - u := &url.URL{ - Scheme: "http", - Host: r.Host, - Path: r.RequestURI, + lang, path := domain.LanguageUnd, r.URL.Path + if head, tail := urlutil.ShiftPath(r.URL.Path); head != "" { + if lang = domain.NewLanguage(head); lang != domain.LanguageUnd { + path = tail + } } - if r.TLS != nil { - u.Scheme += "s" + site, err := config.Siter.Do(r.Context(), lang) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + + return } - if target, ok := redirect(u); ok { - http.RedirectHandler(target, config.Code).ServeHTTP(w, r) - } else { + server, err := config.Serverer.Do(r.Context(), *site) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + + return + } + + redirect, ok := server.Redirects.Match(path) + if !ok { next(w, r) + + return } + + // NOTE(toby3d): always redirect no matter what exists on + // requested URL. + if redirect.Force { + http.Redirect(w, r, redirect.To, redirect.Status) + + return + } + + tx := &redirectResponse{ + error: nil, + statusCode: http.StatusOK, + ResponseWriter: w, + } + + next(tx, r) + + // NOTE(toby3d): redirect only if something bad on requested + // URL. + if tx.error == nil && http.StatusOK < tx.statusCode && tx.statusCode < http.StatusBadRequest { + return + } + + http.Redirect(w, r, redirect.To, redirect.Status) } } + +func (r *redirectResponse) WriteHeader(status int) { + r.statusCode = status + + r.ResponseWriter.WriteHeader(status) +} + +func (r *redirectResponse) Write(src []byte) (int, error) { + var length int + length, r.error = r.ResponseWriter.Write(src) + + return length, r.error +}