♻️ Refactored Code validation into CodeChallengeMethod
This commit is contained in:
parent
4cc934a48c
commit
38c3039ba7
|
@ -1,63 +0,0 @@
|
||||||
//nolint: gosec
|
|
||||||
package domain
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/base64"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/brianvoe/gofakeit/v6"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
|
|
||||||
"source.toby3d.me/website/indieauth/internal/random"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Code describes the PKCE challenge to validate the security of the request.
|
|
||||||
type Code struct {
|
|
||||||
Method CodeChallengeMethod
|
|
||||||
Verifier string
|
|
||||||
Challenge string
|
|
||||||
}
|
|
||||||
|
|
||||||
const (
|
|
||||||
CodeLengthMin int = 43
|
|
||||||
CodeLengthMax int = 128
|
|
||||||
)
|
|
||||||
|
|
||||||
// TestCode returns valid random generated PKCE code for tests.
|
|
||||||
func TestCode(tb testing.TB) *Code {
|
|
||||||
tb.Helper()
|
|
||||||
|
|
||||||
verifier, err := random.String(
|
|
||||||
gofakeit.Number(CodeLengthMin, CodeLengthMax), random.Alphanumeric, "-", ".", "_", "~",
|
|
||||||
)
|
|
||||||
require.NoError(tb, err)
|
|
||||||
|
|
||||||
h := CodeChallengeMethodS256.hash
|
|
||||||
h.Reset()
|
|
||||||
|
|
||||||
_, err = h.Write([]byte(verifier))
|
|
||||||
require.NoError(tb, err)
|
|
||||||
|
|
||||||
return &Code{
|
|
||||||
Method: CodeChallengeMethodS256,
|
|
||||||
Verifier: verifier,
|
|
||||||
Challenge: base64.RawURLEncoding.EncodeToString(h.Sum(nil)),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// IsValid returns true if code challenge is equal to the generated hash from
|
|
||||||
// the verifier.
|
|
||||||
func (c Code) IsValid() bool {
|
|
||||||
if c.Method == CodeChallengeMethodUndefined {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
if c.Method == CodeChallengeMethodPLAIN {
|
|
||||||
return c.Challenge == c.Verifier
|
|
||||||
}
|
|
||||||
|
|
||||||
h := c.Method.hash
|
|
||||||
h.Reset()
|
|
||||||
|
|
||||||
return c.Challenge == base64.RawURLEncoding.EncodeToString(h.Sum([]byte(c.Verifier)))
|
|
||||||
}
|
|
|
@ -5,6 +5,7 @@ import (
|
||||||
"crypto/sha1"
|
"crypto/sha1"
|
||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
"crypto/sha512"
|
"crypto/sha512"
|
||||||
|
"encoding/base64"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"hash"
|
"hash"
|
||||||
|
@ -106,6 +107,14 @@ func (ccm CodeChallengeMethod) String() string {
|
||||||
return ccm.slug
|
return ccm.slug
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ccm CodeChallengeMethod) Encoder() hash.Hash {
|
func (ccm CodeChallengeMethod) Validate(codeChallenge, verifier string) bool {
|
||||||
return ccm.hash
|
if ccm.slug == CodeChallengeMethodUndefined.slug {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if ccm.slug == CodeChallengeMethodPLAIN.slug {
|
||||||
|
return codeChallenge == verifier
|
||||||
|
}
|
||||||
|
|
||||||
|
return codeChallenge == base64.RawURLEncoding.EncodeToString(ccm.hash.Sum([]byte(verifier)))
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,14 +1,23 @@
|
||||||
package domain_test
|
package domain_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"crypto/md5"
|
||||||
|
"crypto/sha1"
|
||||||
|
"crypto/sha256"
|
||||||
|
"crypto/sha512"
|
||||||
|
"encoding/base64"
|
||||||
|
"hash"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/brianvoe/gofakeit/v6"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
"source.toby3d.me/website/indieauth/internal/domain"
|
"source.toby3d.me/website/indieauth/internal/domain"
|
||||||
|
"source.toby3d.me/website/indieauth/internal/random"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestParseCodeChallengeMethod(t *testing.T) {
|
func TestCodeChallengeMethod_Parse(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
for _, testCase := range []struct {
|
for _, testCase := range []struct {
|
||||||
|
@ -58,3 +67,71 @@ func TestParseCodeChallengeMethod(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//nolint: funlen
|
||||||
|
func TestCodeChallengeMethod_Validate(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
verifier, err := random.String(gofakeit.Number(42, 128))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
for _, testCase := range []struct {
|
||||||
|
hash hash.Hash
|
||||||
|
name string
|
||||||
|
method string
|
||||||
|
isValid bool
|
||||||
|
}{{
|
||||||
|
name: "invalid",
|
||||||
|
method: domain.CodeChallengeMethodS256.String(),
|
||||||
|
hash: md5.New(),
|
||||||
|
isValid: false,
|
||||||
|
}, {
|
||||||
|
name: "MD5",
|
||||||
|
method: domain.CodeChallengeMethodMD5.String(),
|
||||||
|
hash: md5.New(),
|
||||||
|
isValid: true,
|
||||||
|
}, {
|
||||||
|
name: "plain",
|
||||||
|
method: domain.CodeChallengeMethodPLAIN.String(),
|
||||||
|
hash: nil,
|
||||||
|
isValid: true,
|
||||||
|
}, {
|
||||||
|
name: "S1",
|
||||||
|
method: domain.CodeChallengeMethodS1.String(),
|
||||||
|
hash: sha1.New(),
|
||||||
|
isValid: true,
|
||||||
|
}, {
|
||||||
|
name: "S256",
|
||||||
|
method: domain.CodeChallengeMethodS256.String(),
|
||||||
|
hash: sha256.New(),
|
||||||
|
isValid: true,
|
||||||
|
}, {
|
||||||
|
name: "S512",
|
||||||
|
method: domain.CodeChallengeMethodS512.String(),
|
||||||
|
hash: sha512.New(),
|
||||||
|
isValid: true,
|
||||||
|
}, {
|
||||||
|
name: "undefined",
|
||||||
|
method: "und",
|
||||||
|
hash: nil,
|
||||||
|
isValid: false,
|
||||||
|
}} {
|
||||||
|
testCase := testCase
|
||||||
|
|
||||||
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
method, _ := domain.ParseCodeChallengeMethod(testCase.method)
|
||||||
|
if method == domain.CodeChallengeMethodPLAIN ||
|
||||||
|
method == domain.CodeChallengeMethodUndefined {
|
||||||
|
assert.Equal(t, testCase.isValid, method.Validate(verifier, verifier))
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, testCase.isValid, method.Validate(base64.RawURLEncoding.EncodeToString(
|
||||||
|
testCase.hash.Sum([]byte(verifier)),
|
||||||
|
), verifier))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -1,93 +0,0 @@
|
||||||
package domain_test
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/md5"
|
|
||||||
"crypto/sha1"
|
|
||||||
"crypto/sha256"
|
|
||||||
"crypto/sha512"
|
|
||||||
"encoding/base64"
|
|
||||||
"hash"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/brianvoe/gofakeit/v6"
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
|
|
||||||
"source.toby3d.me/website/indieauth/internal/domain"
|
|
||||||
"source.toby3d.me/website/indieauth/internal/random"
|
|
||||||
)
|
|
||||||
|
|
||||||
//nolint: funlen
|
|
||||||
func TestCodeIsValid(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
verifier, err := random.String(gofakeit.Number(domain.CodeLengthMin, domain.CodeLengthMax))
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
for _, testCase := range []struct {
|
|
||||||
hash hash.Hash
|
|
||||||
name string
|
|
||||||
method string
|
|
||||||
isValid bool
|
|
||||||
}{{
|
|
||||||
name: "invalid",
|
|
||||||
method: domain.CodeChallengeMethodS256.String(),
|
|
||||||
hash: md5.New(),
|
|
||||||
isValid: false,
|
|
||||||
}, {
|
|
||||||
name: "MD5",
|
|
||||||
method: domain.CodeChallengeMethodMD5.String(),
|
|
||||||
hash: md5.New(),
|
|
||||||
isValid: true,
|
|
||||||
}, {
|
|
||||||
name: "plain",
|
|
||||||
method: domain.CodeChallengeMethodPLAIN.String(),
|
|
||||||
hash: nil,
|
|
||||||
isValid: true,
|
|
||||||
}, {
|
|
||||||
name: "S1",
|
|
||||||
method: domain.CodeChallengeMethodS1.String(),
|
|
||||||
hash: sha1.New(),
|
|
||||||
isValid: true,
|
|
||||||
}, {
|
|
||||||
name: "S256",
|
|
||||||
method: domain.CodeChallengeMethodS256.String(),
|
|
||||||
hash: sha256.New(),
|
|
||||||
isValid: true,
|
|
||||||
}, {
|
|
||||||
name: "S512",
|
|
||||||
method: domain.CodeChallengeMethodS512.String(),
|
|
||||||
hash: sha512.New(),
|
|
||||||
isValid: true,
|
|
||||||
}, {
|
|
||||||
name: "undefined",
|
|
||||||
method: "und",
|
|
||||||
hash: nil,
|
|
||||||
isValid: false,
|
|
||||||
}} {
|
|
||||||
testCase := testCase
|
|
||||||
|
|
||||||
t.Run(testCase.name, func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
method, _ := domain.ParseCodeChallengeMethod(testCase.method)
|
|
||||||
result := &domain.Code{
|
|
||||||
Method: method,
|
|
||||||
Verifier: verifier,
|
|
||||||
Challenge: verifier,
|
|
||||||
}
|
|
||||||
|
|
||||||
if method == domain.CodeChallengeMethodPLAIN ||
|
|
||||||
method == domain.CodeChallengeMethodUndefined {
|
|
||||||
assert.Equal(t, testCase.isValid, result.IsValid())
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
result.Challenge = base64.RawURLEncoding.EncodeToString(
|
|
||||||
testCase.hash.Sum([]byte(result.Verifier)),
|
|
||||||
)
|
|
||||||
assert.Equal(t, testCase.isValid, result.IsValid())
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
Loading…
Reference in New Issue