♻️ Refactored ClientID domain contents, translated fasthttp.URI to url.URL

This commit is contained in:
Maxim Lebedev 2023-01-02 06:32:13 +06:00
parent 834af5d3cf
commit 2af2a432b0
Signed by: toby3d
GPG Key ID: 1F14E25B7C119FC5
2 changed files with 14 additions and 36 deletions

View File

@ -65,12 +65,12 @@ func (c *Client) ValidateRedirectURI(redirectURI *URL) bool {
rHost = string(redirectURI.Host()) rHost = string(redirectURI.Host())
} }
cHost, cPort, err := net.SplitHostPort(string(c.ID.clientID.Host())) cHost, cPort, err := net.SplitHostPort(c.ID.clientID.Host)
if err != nil { if err != nil {
cHost = string(c.ID.clientID.Host()) cHost = c.ID.clientID.Hostname()
} }
if bytes.EqualFold(redirectURI.Scheme(), c.ID.clientID.Scheme()) && if bytes.EqualFold(redirectURI.Scheme(), []byte(c.ID.clientID.Scheme)) &&
strings.EqualFold(rHost, cHost) && strings.EqualFold(rHost, cHost) &&
strings.EqualFold(rPort, cPort) { strings.EqualFold(rPort, cPort) {
return true return true

View File

@ -7,13 +7,12 @@ import (
"strings" "strings"
"testing" "testing"
http "github.com/valyala/fasthttp"
"inet.af/netaddr" "inet.af/netaddr"
) )
// ClientID is a URL client identifier. // ClientID is a URL client identifier.
type ClientID struct { type ClientID struct {
clientID *http.URI clientID *url.URL
} }
//nolint:gochecknoglobals // slices cannot be constants //nolint:gochecknoglobals // slices cannot be constants
@ -26,8 +25,8 @@ var (
// //
//nolint:funlen,cyclop //nolint:funlen,cyclop
func ParseClientID(src string) (*ClientID, error) { func ParseClientID(src string) (*ClientID, error) {
cid := http.AcquireURI() cid, err := url.Parse(src)
if err := cid.Parse(nil, []byte(src)); err != nil { if err != nil {
return nil, NewError( return nil, NewError(
ErrorCodeInvalidRequest, ErrorCodeInvalidRequest,
err.Error(), err.Error(),
@ -35,8 +34,7 @@ func ParseClientID(src string) (*ClientID, error) {
) )
} }
scheme := string(cid.Scheme()) if cid.Scheme != "http" && cid.Scheme != "https" {
if scheme != "http" && scheme != "https" {
return nil, NewError( return nil, NewError(
ErrorCodeInvalidRequest, ErrorCodeInvalidRequest,
"client identifier URL MUST have either an https or http scheme", "client identifier URL MUST have either an https or http scheme",
@ -44,8 +42,7 @@ func ParseClientID(src string) (*ClientID, error) {
) )
} }
path := string(cid.PathOriginal()) if cid.Path == "" || strings.Contains(cid.Path, "/.") || strings.Contains(cid.Path, "/..") {
if path == "" || strings.Contains(path, "/.") || strings.Contains(path, "/..") {
return nil, NewError( return nil, NewError(
ErrorCodeInvalidRequest, ErrorCodeInvalidRequest,
"client identifier URL MUST contain a path component and MUST NOT contain "+ "client identifier URL MUST contain a path component and MUST NOT contain "+
@ -54,7 +51,7 @@ func ParseClientID(src string) (*ClientID, error) {
) )
} }
if cid.Hash() != nil { if cid.Fragment != "" {
return nil, NewError( return nil, NewError(
ErrorCodeInvalidRequest, ErrorCodeInvalidRequest,
"client identifier URL MUST NOT contain a fragment component", "client identifier URL MUST NOT contain a fragment component",
@ -62,7 +59,7 @@ func ParseClientID(src string) (*ClientID, error) {
) )
} }
if cid.Username() != nil || cid.Password() != nil { if cid.User != nil {
return nil, NewError( return nil, NewError(
ErrorCodeInvalidRequest, ErrorCodeInvalidRequest,
"client identifier URL MUST NOT contain a username or password component", "client identifier URL MUST NOT contain a username or password component",
@ -70,7 +67,7 @@ func ParseClientID(src string) (*ClientID, error) {
) )
} }
domain := string(cid.Host()) domain := cid.Hostname()
if domain == "" { if domain == "" {
return nil, NewError( return nil, NewError(
ErrorCodeInvalidRequest, ErrorCodeInvalidRequest,
@ -150,30 +147,11 @@ func (cid ClientID) MarshalJSON() ([]byte, error) {
return []byte(strconv.Quote(cid.String())), nil return []byte(strconv.Quote(cid.String())), nil
} }
// URI returns copy of parsed *fasthttp.URI.
//
// WARN(toby3d): This copy MUST be released via fasthttp.ReleaseURI.
func (cid ClientID) URI() *http.URI {
uri := http.AcquireURI()
cid.clientID.CopyTo(uri)
return uri
}
// URL returns url.URL representation of client ID. // URL returns url.URL representation of client ID.
func (cid ClientID) URL() *url.URL { func (cid ClientID) URL() *url.URL {
return &url.URL{ out, _ := url.Parse(cid.clientID.String())
ForceQuery: false,
Fragment: string(cid.clientID.Hash()), return out
Host: string(cid.clientID.Host()),
Opaque: "",
Path: string(cid.clientID.Path()),
RawFragment: "",
RawPath: string(cid.clientID.PathOriginal()),
RawQuery: string(cid.clientID.QueryString()),
Scheme: string(cid.clientID.Scheme()),
User: nil,
}
} }
// String returns string representation of client ID. // String returns string representation of client ID.