diff --git a/internal/domain/code.go b/internal/domain/code.go new file mode 100644 index 0000000..4135550 --- /dev/null +++ b/internal/domain/code.go @@ -0,0 +1,63 @@ +//nolint: gosec +package domain + +import ( + "encoding/base64" + "testing" + + "github.com/brianvoe/gofakeit/v6" + "github.com/stretchr/testify/require" + + "source.toby3d.me/website/oauth/internal/random" +) + +// Code describes the PKCE challenge to validate the security of the request. +type Code struct { + Method CodeChallengeMethod + Verifier string + Challenge string +} + +const ( + CodeLengthMin int = 43 + CodeLengthMax int = 128 +) + +// TestCode returns valid random generated PKCE code for tests. +func TestCode(tb testing.TB) *Code { + tb.Helper() + + verifier, err := random.String( + gofakeit.Number(CodeLengthMin, CodeLengthMax), random.Alphanumeric, "-", ".", "_", "~", + ) + require.NoError(tb, err) + + h := CodeChallengeMethodS256.hash + h.Reset() + + _, err = h.Write([]byte(verifier)) + require.NoError(tb, err) + + return &Code{ + Method: CodeChallengeMethodS256, + Verifier: verifier, + Challenge: base64.RawURLEncoding.EncodeToString(h.Sum(nil)), + } +} + +// IsValid returns true if code challenge is equal to the generated hash from +// the verifier. +func (c Code) IsValid() bool { + if c.Method == CodeChallengeMethodUndefined { + return false + } + + if c.Method == CodeChallengeMethodPLAIN { + return c.Challenge == c.Verifier + } + + h := c.Method.hash + h.Reset() + + return c.Challenge == base64.RawURLEncoding.EncodeToString(h.Sum([]byte(c.Verifier))) +} diff --git a/internal/domain/code_test.go b/internal/domain/code_test.go new file mode 100644 index 0000000..6cc34c5 --- /dev/null +++ b/internal/domain/code_test.go @@ -0,0 +1,93 @@ +package domain_test + +import ( + "crypto/md5" + "crypto/sha1" + "crypto/sha256" + "crypto/sha512" + "encoding/base64" + "hash" + "testing" + + "github.com/brianvoe/gofakeit/v6" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "source.toby3d.me/website/oauth/internal/domain" + "source.toby3d.me/website/oauth/internal/random" +) + +//nolint: funlen +func TestCodeIsValid(t *testing.T) { + t.Parallel() + + verifier, err := random.String(gofakeit.Number(domain.CodeLengthMin, domain.CodeLengthMax)) + require.NoError(t, err) + + for _, testCase := range []struct { + hash hash.Hash + name string + method string + isValid bool + }{{ + name: "invalid", + method: domain.CodeChallengeMethodS256.String(), + hash: md5.New(), + isValid: false, + }, { + name: "MD5", + method: domain.CodeChallengeMethodMD5.String(), + hash: md5.New(), + isValid: true, + }, { + name: "plain", + method: domain.CodeChallengeMethodPLAIN.String(), + hash: nil, + isValid: true, + }, { + name: "S1", + method: domain.CodeChallengeMethodS1.String(), + hash: sha1.New(), + isValid: true, + }, { + name: "S256", + method: domain.CodeChallengeMethodS256.String(), + hash: sha256.New(), + isValid: true, + }, { + name: "S512", + method: domain.CodeChallengeMethodS512.String(), + hash: sha512.New(), + isValid: true, + }, { + name: "undefined", + method: "und", + hash: nil, + isValid: false, + }} { + testCase := testCase + + t.Run(testCase.name, func(t *testing.T) { + t.Parallel() + + method, _ := domain.ParseCodeChallengeMethod(testCase.method) + result := &domain.Code{ + Method: method, + Verifier: verifier, + Challenge: verifier, + } + + if method == domain.CodeChallengeMethodPLAIN || + method == domain.CodeChallengeMethodUndefined { + assert.Equal(t, testCase.isValid, result.IsValid()) + + return + } + + result.Challenge = base64.RawURLEncoding.EncodeToString( + testCase.hash.Sum([]byte(result.Verifier)), + ) + assert.Equal(t, testCase.isValid, result.IsValid()) + }) + } +}