diff --git a/internal/domain/error.go b/internal/domain/error.go index 3a8e92a..28ab7b5 100644 --- a/internal/domain/error.go +++ b/internal/domain/error.go @@ -3,6 +3,7 @@ package domain import ( "fmt" "strconv" + "strings" http "github.com/valyala/fasthttp" "golang.org/x/xerrors" @@ -185,11 +186,41 @@ var ( } ) +var ErrErrorCodeUnknown error = NewError(ErrorCodeInvalidRequest, "unknown error code", "") + +//nolint: gochecknoglobals // maps cannot be constants +var uidsErrorCodes = map[string]ErrorCode{ + ErrorCodeAccessDenied.uid: ErrorCodeAccessDenied, + ErrorCodeInsufficientScope.uid: ErrorCodeInsufficientScope, + ErrorCodeInvalidClient.uid: ErrorCodeInvalidClient, + ErrorCodeInvalidGrant.uid: ErrorCodeInvalidGrant, + ErrorCodeInvalidRequest.uid: ErrorCodeInvalidRequest, + ErrorCodeInvalidScope.uid: ErrorCodeInvalidScope, + ErrorCodeInvalidToken.uid: ErrorCodeInvalidToken, + ErrorCodeServerError.uid: ErrorCodeServerError, + ErrorCodeTemporarilyUnavailable.uid: ErrorCodeTemporarilyUnavailable, + ErrorCodeUnauthorizedClient.uid: ErrorCodeUnauthorizedClient, + ErrorCodeUnsupportedGrantType.uid: ErrorCodeUnsupportedGrantType, + ErrorCodeUnsupportedResponseType.uid: ErrorCodeUnsupportedResponseType, +} + // String returns a string representation of the error code. func (ec ErrorCode) String() string { return ec.uid } +// UnmarshalForm implements custom unmarshler for form values. +func (ec *ErrorCode) UnmarshalForm(v []byte) error { + code, ok := uidsErrorCodes[strings.ToLower(string(v))] + if !ok { + return fmt.Errorf("UnmarshalForm: %w", ErrErrorCodeUnknown) + } + + *ec = code + + return nil +} + // MarshalJSON encodes the error code into its string representation in JSON. func (ec ErrorCode) MarshalJSON() ([]byte, error) { return []byte(strconv.QuoteToASCII(ec.uid)), nil diff --git a/internal/domain/error_test.go b/internal/domain/error_test.go index 9d4282b..3214940 100644 --- a/internal/domain/error_test.go +++ b/internal/domain/error_test.go @@ -2,6 +2,7 @@ package domain_test import ( "fmt" + "testing" "source.toby3d.me/website/indieauth/internal/domain" ) @@ -10,3 +11,18 @@ func ExampleNewError() { fmt.Printf("%v", domain.NewError(domain.ErrorCodeInvalidRequest, "client_id MUST be provided", "")) // Output: invalid_request: client_id MUST be provided } + +func TestErrorCode_UnmarshalForm(t *testing.T) { + t.Parallel() + + input := []byte("access_denied") + result := domain.ErrorCodeUndefined + + if err := result.UnmarshalForm(input); err != nil { + t.Fatalf("%+v", err) + } + + if result != domain.ErrorCodeAccessDenied { + t.Errorf("UnmarshalForm(%s) = %v, want %v", input, result, domain.ErrorCodeAccessDenied) + } +}