Allow error_code unmarshling

This commit is contained in:
Maxim Lebedev 2022-02-08 00:10:16 +05:00
parent d2ff43d4a3
commit 8eaf349796
Signed by: toby3d
GPG Key ID: 1F14E25B7C119FC5
2 changed files with 47 additions and 0 deletions

View File

@ -3,6 +3,7 @@ package domain
import (
"fmt"
"strconv"
"strings"
http "github.com/valyala/fasthttp"
"golang.org/x/xerrors"
@ -185,11 +186,41 @@ var (
}
)
var ErrErrorCodeUnknown error = NewError(ErrorCodeInvalidRequest, "unknown error code", "")
//nolint: gochecknoglobals // maps cannot be constants
var uidsErrorCodes = map[string]ErrorCode{
ErrorCodeAccessDenied.uid: ErrorCodeAccessDenied,
ErrorCodeInsufficientScope.uid: ErrorCodeInsufficientScope,
ErrorCodeInvalidClient.uid: ErrorCodeInvalidClient,
ErrorCodeInvalidGrant.uid: ErrorCodeInvalidGrant,
ErrorCodeInvalidRequest.uid: ErrorCodeInvalidRequest,
ErrorCodeInvalidScope.uid: ErrorCodeInvalidScope,
ErrorCodeInvalidToken.uid: ErrorCodeInvalidToken,
ErrorCodeServerError.uid: ErrorCodeServerError,
ErrorCodeTemporarilyUnavailable.uid: ErrorCodeTemporarilyUnavailable,
ErrorCodeUnauthorizedClient.uid: ErrorCodeUnauthorizedClient,
ErrorCodeUnsupportedGrantType.uid: ErrorCodeUnsupportedGrantType,
ErrorCodeUnsupportedResponseType.uid: ErrorCodeUnsupportedResponseType,
}
// String returns a string representation of the error code.
func (ec ErrorCode) String() string {
return ec.uid
}
// UnmarshalForm implements custom unmarshler for form values.
func (ec *ErrorCode) UnmarshalForm(v []byte) error {
code, ok := uidsErrorCodes[strings.ToLower(string(v))]
if !ok {
return fmt.Errorf("UnmarshalForm: %w", ErrErrorCodeUnknown)
}
*ec = code
return nil
}
// MarshalJSON encodes the error code into its string representation in JSON.
func (ec ErrorCode) MarshalJSON() ([]byte, error) {
return []byte(strconv.QuoteToASCII(ec.uid)), nil

View File

@ -2,6 +2,7 @@ package domain_test
import (
"fmt"
"testing"
"source.toby3d.me/website/indieauth/internal/domain"
)
@ -10,3 +11,18 @@ func ExampleNewError() {
fmt.Printf("%v", domain.NewError(domain.ErrorCodeInvalidRequest, "client_id MUST be provided", ""))
// Output: invalid_request: client_id MUST be provided
}
func TestErrorCode_UnmarshalForm(t *testing.T) {
t.Parallel()
input := []byte("access_denied")
result := domain.ErrorCodeUndefined
if err := result.UnmarshalForm(input); err != nil {
t.Fatalf("%+v", err)
}
if result != domain.ErrorCodeAccessDenied {
t.Errorf("UnmarshalForm(%s) = %v, want %v", input, result, domain.ErrorCodeAccessDenied)
}
}