👽 Resolve ClientID hostname, ignore localhost fetching
This commit is contained in:
parent
741f7000d8
commit
5ad227700b
|
@ -43,6 +43,18 @@ func (httpClientRepository) Create(_ context.Context, _ domain.Client) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (repo httpClientRepository) Get(ctx context.Context, cid domain.ClientID) (*domain.Client, error) {
|
func (repo httpClientRepository) Get(ctx context.Context, cid domain.ClientID) (*domain.Client, error) {
|
||||||
|
out := &domain.Client{
|
||||||
|
ID: cid,
|
||||||
|
RedirectURI: make([]*url.URL, 0),
|
||||||
|
Logo: make([]*url.URL, 0),
|
||||||
|
URL: make([]*url.URL, 0),
|
||||||
|
Name: make([]string, 0),
|
||||||
|
}
|
||||||
|
|
||||||
|
if cid.IsLocalhost() {
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
resp, err := repo.client.Get(cid.String())
|
resp, err := repo.client.Get(cid.String())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to make a request to the client: %w", err)
|
return nil, fmt.Errorf("failed to make a request to the client: %w", err)
|
||||||
|
@ -52,17 +64,9 @@ func (repo httpClientRepository) Get(ctx context.Context, cid domain.ClientID) (
|
||||||
return nil, fmt.Errorf("%w: status on client page is not 200", client.ErrNotExist)
|
return nil, fmt.Errorf("%w: status on client page is not 200", client.ErrNotExist)
|
||||||
}
|
}
|
||||||
|
|
||||||
client := &domain.Client{
|
extract(resp.Body, resp.Request.URL, out, resp.Header.Get(common.HeaderLink))
|
||||||
ID: cid,
|
|
||||||
RedirectURI: make([]*url.URL, 0),
|
|
||||||
Logo: make([]*url.URL, 0),
|
|
||||||
URL: make([]*url.URL, 0),
|
|
||||||
Name: make([]string, 0),
|
|
||||||
}
|
|
||||||
|
|
||||||
extract(resp.Body, resp.Request.URL, client, resp.Header.Get(common.HeaderLink))
|
return out, nil
|
||||||
|
|
||||||
return client, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//nolint:gocognit,cyclop
|
//nolint:gocognit,cyclop
|
||||||
|
|
|
@ -2,6 +2,7 @@ package domain
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
@ -14,7 +15,8 @@ import (
|
||||||
|
|
||||||
// ClientID is a URL client identifier.
|
// ClientID is a URL client identifier.
|
||||||
type ClientID struct {
|
type ClientID struct {
|
||||||
clientID *url.URL
|
clientID *url.URL
|
||||||
|
isLocalhost bool
|
||||||
}
|
}
|
||||||
|
|
||||||
//nolint:gochecknoglobals // slices cannot be constants
|
//nolint:gochecknoglobals // slices cannot be constants
|
||||||
|
@ -87,14 +89,27 @@ func ParseClientID(src string) (*ClientID, error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ipPort, err := netaddr.ParseIPPort(domain)
|
ipPort, err := netaddr.ParseIPPort(domain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
resolvedAddr, err := net.LookupIP(domain)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("cannot resolve client_id domain: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ip, _ = netaddr.FromStdIP(resolvedAddr[0])
|
||||||
|
isLocalhost := ip.Compare(localhostIPv4) == 0 || ip.Compare(localhostIPv6) == 0
|
||||||
|
|
||||||
//nolint:nilerr // ClientID does not contain an IP address, so it is valid
|
//nolint:nilerr // ClientID does not contain an IP address, so it is valid
|
||||||
return &ClientID{clientID: cid}, nil
|
return &ClientID{
|
||||||
|
clientID: cid,
|
||||||
|
isLocalhost: isLocalhost,
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
ip = ipPort.IP()
|
ip = ipPort.IP()
|
||||||
}
|
}
|
||||||
|
|
||||||
if !ip.IsLoopback() && ip.Compare(localhostIPv4) != 0 && ip.Compare(localhostIPv6) != 0 {
|
isLocalhost := ip.Compare(localhostIPv4) == 0 || ip.Compare(localhostIPv6) == 0
|
||||||
|
|
||||||
|
if !ip.IsLoopback() && !isLocalhost {
|
||||||
return nil, NewError(
|
return nil, NewError(
|
||||||
ErrorCodeInvalidRequest,
|
ErrorCodeInvalidRequest,
|
||||||
"client identifier URL MUST NOT be IPv4 or IPv6 addresses except for IPv4 "+
|
"client identifier URL MUST NOT be IPv4 or IPv6 addresses except for IPv4 "+
|
||||||
|
@ -104,7 +119,8 @@ func ParseClientID(src string) (*ClientID, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
return &ClientID{
|
return &ClientID{
|
||||||
clientID: cid,
|
clientID: cid,
|
||||||
|
isLocalhost: isLocalhost,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -112,17 +128,18 @@ func ParseClientID(src string) (*ClientID, error) {
|
||||||
func TestClientID(tb testing.TB, forceURL ...string) *ClientID {
|
func TestClientID(tb testing.TB, forceURL ...string) *ClientID {
|
||||||
tb.Helper()
|
tb.Helper()
|
||||||
|
|
||||||
in := "https://app.example.com/"
|
in := "https://127.0.0.1/"
|
||||||
|
|
||||||
if len(forceURL) > 0 {
|
if len(forceURL) > 0 {
|
||||||
in = forceURL[0]
|
in = forceURL[0]
|
||||||
}
|
}
|
||||||
|
|
||||||
clientID, err := ParseClientID(in)
|
u, _ := url.Parse(in)
|
||||||
if err != nil {
|
|
||||||
tb.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return clientID
|
return &ClientID{
|
||||||
|
clientID: u,
|
||||||
|
isLocalhost: len(forceURL) < 1,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// UnmarshalForm implements custom unmarshler for form values.
|
// UnmarshalForm implements custom unmarshler for form values.
|
||||||
|
@ -171,6 +188,10 @@ func (cid ClientID) URL() *url.URL {
|
||||||
return out
|
return out
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (cid ClientID) IsLocalhost() bool {
|
||||||
|
return cid.isLocalhost
|
||||||
|
}
|
||||||
|
|
||||||
// String returns string representation of client ID.
|
// String returns string representation of client ID.
|
||||||
func (cid ClientID) String() string {
|
func (cid ClientID) String() string {
|
||||||
if cid.clientID == nil {
|
if cid.clientID == nil {
|
||||||
|
|
Loading…
Reference in New Issue