From 38c3039ba7d2f448cb8a8a3d5a358c65c6eef496 Mon Sep 17 00:00:00 2001 From: Maxim Lebedev Date: Sat, 8 Jan 2022 15:27:37 +0500 Subject: [PATCH] :recycle: Refactored Code validation into CodeChallengeMethod --- internal/domain/code.go | 63 ------------- internal/domain/code_challenge_method.go | 13 ++- internal/domain/code_challenge_method_test.go | 79 +++++++++++++++- internal/domain/code_test.go | 93 ------------------- 4 files changed, 89 insertions(+), 159 deletions(-) delete mode 100644 internal/domain/code.go delete mode 100644 internal/domain/code_test.go diff --git a/internal/domain/code.go b/internal/domain/code.go deleted file mode 100644 index 5973112..0000000 --- a/internal/domain/code.go +++ /dev/null @@ -1,63 +0,0 @@ -//nolint: gosec -package domain - -import ( - "encoding/base64" - "testing" - - "github.com/brianvoe/gofakeit/v6" - "github.com/stretchr/testify/require" - - "source.toby3d.me/website/indieauth/internal/random" -) - -// Code describes the PKCE challenge to validate the security of the request. -type Code struct { - Method CodeChallengeMethod - Verifier string - Challenge string -} - -const ( - CodeLengthMin int = 43 - CodeLengthMax int = 128 -) - -// TestCode returns valid random generated PKCE code for tests. -func TestCode(tb testing.TB) *Code { - tb.Helper() - - verifier, err := random.String( - gofakeit.Number(CodeLengthMin, CodeLengthMax), random.Alphanumeric, "-", ".", "_", "~", - ) - require.NoError(tb, err) - - h := CodeChallengeMethodS256.hash - h.Reset() - - _, err = h.Write([]byte(verifier)) - require.NoError(tb, err) - - return &Code{ - Method: CodeChallengeMethodS256, - Verifier: verifier, - Challenge: base64.RawURLEncoding.EncodeToString(h.Sum(nil)), - } -} - -// IsValid returns true if code challenge is equal to the generated hash from -// the verifier. -func (c Code) IsValid() bool { - if c.Method == CodeChallengeMethodUndefined { - return false - } - - if c.Method == CodeChallengeMethodPLAIN { - return c.Challenge == c.Verifier - } - - h := c.Method.hash - h.Reset() - - return c.Challenge == base64.RawURLEncoding.EncodeToString(h.Sum([]byte(c.Verifier))) -} diff --git a/internal/domain/code_challenge_method.go b/internal/domain/code_challenge_method.go index 066a029..a862680 100644 --- a/internal/domain/code_challenge_method.go +++ b/internal/domain/code_challenge_method.go @@ -5,6 +5,7 @@ import ( "crypto/sha1" "crypto/sha256" "crypto/sha512" + "encoding/base64" "errors" "fmt" "hash" @@ -106,6 +107,14 @@ func (ccm CodeChallengeMethod) String() string { return ccm.slug } -func (ccm CodeChallengeMethod) Encoder() hash.Hash { - return ccm.hash +func (ccm CodeChallengeMethod) Validate(codeChallenge, verifier string) bool { + if ccm.slug == CodeChallengeMethodUndefined.slug { + return false + } + + if ccm.slug == CodeChallengeMethodPLAIN.slug { + return codeChallenge == verifier + } + + return codeChallenge == base64.RawURLEncoding.EncodeToString(ccm.hash.Sum([]byte(verifier))) } diff --git a/internal/domain/code_challenge_method_test.go b/internal/domain/code_challenge_method_test.go index 0d47370..999bbbc 100644 --- a/internal/domain/code_challenge_method_test.go +++ b/internal/domain/code_challenge_method_test.go @@ -1,14 +1,23 @@ package domain_test import ( + "crypto/md5" + "crypto/sha1" + "crypto/sha256" + "crypto/sha512" + "encoding/base64" + "hash" "testing" + "github.com/brianvoe/gofakeit/v6" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "source.toby3d.me/website/indieauth/internal/domain" + "source.toby3d.me/website/indieauth/internal/random" ) -func TestParseCodeChallengeMethod(t *testing.T) { +func TestCodeChallengeMethod_Parse(t *testing.T) { t.Parallel() for _, testCase := range []struct { @@ -58,3 +67,71 @@ func TestParseCodeChallengeMethod(t *testing.T) { }) } } + +//nolint: funlen +func TestCodeChallengeMethod_Validate(t *testing.T) { + t.Parallel() + + verifier, err := random.String(gofakeit.Number(42, 128)) + require.NoError(t, err) + + for _, testCase := range []struct { + hash hash.Hash + name string + method string + isValid bool + }{{ + name: "invalid", + method: domain.CodeChallengeMethodS256.String(), + hash: md5.New(), + isValid: false, + }, { + name: "MD5", + method: domain.CodeChallengeMethodMD5.String(), + hash: md5.New(), + isValid: true, + }, { + name: "plain", + method: domain.CodeChallengeMethodPLAIN.String(), + hash: nil, + isValid: true, + }, { + name: "S1", + method: domain.CodeChallengeMethodS1.String(), + hash: sha1.New(), + isValid: true, + }, { + name: "S256", + method: domain.CodeChallengeMethodS256.String(), + hash: sha256.New(), + isValid: true, + }, { + name: "S512", + method: domain.CodeChallengeMethodS512.String(), + hash: sha512.New(), + isValid: true, + }, { + name: "undefined", + method: "und", + hash: nil, + isValid: false, + }} { + testCase := testCase + + t.Run(testCase.name, func(t *testing.T) { + t.Parallel() + + method, _ := domain.ParseCodeChallengeMethod(testCase.method) + if method == domain.CodeChallengeMethodPLAIN || + method == domain.CodeChallengeMethodUndefined { + assert.Equal(t, testCase.isValid, method.Validate(verifier, verifier)) + + return + } + + assert.Equal(t, testCase.isValid, method.Validate(base64.RawURLEncoding.EncodeToString( + testCase.hash.Sum([]byte(verifier)), + ), verifier)) + }) + } +} diff --git a/internal/domain/code_test.go b/internal/domain/code_test.go deleted file mode 100644 index 42fd8e0..0000000 --- a/internal/domain/code_test.go +++ /dev/null @@ -1,93 +0,0 @@ -package domain_test - -import ( - "crypto/md5" - "crypto/sha1" - "crypto/sha256" - "crypto/sha512" - "encoding/base64" - "hash" - "testing" - - "github.com/brianvoe/gofakeit/v6" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "source.toby3d.me/website/indieauth/internal/domain" - "source.toby3d.me/website/indieauth/internal/random" -) - -//nolint: funlen -func TestCodeIsValid(t *testing.T) { - t.Parallel() - - verifier, err := random.String(gofakeit.Number(domain.CodeLengthMin, domain.CodeLengthMax)) - require.NoError(t, err) - - for _, testCase := range []struct { - hash hash.Hash - name string - method string - isValid bool - }{{ - name: "invalid", - method: domain.CodeChallengeMethodS256.String(), - hash: md5.New(), - isValid: false, - }, { - name: "MD5", - method: domain.CodeChallengeMethodMD5.String(), - hash: md5.New(), - isValid: true, - }, { - name: "plain", - method: domain.CodeChallengeMethodPLAIN.String(), - hash: nil, - isValid: true, - }, { - name: "S1", - method: domain.CodeChallengeMethodS1.String(), - hash: sha1.New(), - isValid: true, - }, { - name: "S256", - method: domain.CodeChallengeMethodS256.String(), - hash: sha256.New(), - isValid: true, - }, { - name: "S512", - method: domain.CodeChallengeMethodS512.String(), - hash: sha512.New(), - isValid: true, - }, { - name: "undefined", - method: "und", - hash: nil, - isValid: false, - }} { - testCase := testCase - - t.Run(testCase.name, func(t *testing.T) { - t.Parallel() - - method, _ := domain.ParseCodeChallengeMethod(testCase.method) - result := &domain.Code{ - Method: method, - Verifier: verifier, - Challenge: verifier, - } - - if method == domain.CodeChallengeMethodPLAIN || - method == domain.CodeChallengeMethodUndefined { - assert.Equal(t, testCase.isValid, result.IsValid()) - - return - } - - result.Challenge = base64.RawURLEncoding.EncodeToString( - testCase.hash.Sum([]byte(result.Verifier)), - ) - assert.Equal(t, testCase.isValid, result.IsValid()) - }) - } -}