diff --git a/internal/domain/login.go b/internal/domain/login.go index b98e619..960bc2c 100644 --- a/internal/domain/login.go +++ b/internal/domain/login.go @@ -4,13 +4,16 @@ import ( "testing" "time" + "github.com/brianvoe/gofakeit/v6" + "github.com/stretchr/testify/require" + "source.toby3d.me/website/oauth/internal/random" ) type Login struct { + PKCE CreatedAt time.Time CompletedAt time.Time - PKCE Scopes []string ClientID string RedirectURI string @@ -25,10 +28,11 @@ type Login struct { func TestLogin(tb testing.TB) *Login { tb.Helper() - now := time.Now().UTC() + code, err := random.String(8) + require.NoError(tb, err) return &Login{ - CreatedAt: now.Add(-1 * time.Minute), + CreatedAt: gofakeit.Date(), CompletedAt: time.Time{}, PKCE: PKCE{ Method: PKCEMethodS256, @@ -40,7 +44,7 @@ func TestLogin(tb testing.TB) *Login { RedirectURI: "https://app.example.com/redirect", MeEntered: "user.example.net", MeResolved: "https://user.example.net/", - Code: random.New().String(8), + Code: code, Provider: "mastodon", IsCompleted: false, } @@ -50,15 +54,19 @@ func TestLogin(tb testing.TB) *Login { func TestLoginInvalid(tb testing.TB) *Login { tb.Helper() - now := time.Now().UTC() + challenge, err := random.String(42) + require.NoError(tb, err) + + verifier, err := random.String(64) + require.NoError(tb, err) return &Login{ - CreatedAt: now.Add(-1 * time.Hour), + CreatedAt: time.Now().UTC().Add(-1 * time.Hour), CompletedAt: time.Time{}, PKCE: PKCE{ Method: "UNDEFINED", - Challenge: random.New().String(42), - Verifier: random.New().String(64), + Challenge: challenge, + Verifier: verifier, }, Scopes: []string{}, ClientID: "whoisit", @@ -66,7 +74,7 @@ func TestLoginInvalid(tb testing.TB) *Login { MeEntered: "whoami", MeResolved: "", Code: "", - Provider: "", + Provider: "undefined", IsCompleted: true, } } diff --git a/internal/domain/token.go b/internal/domain/token.go index 7ceaaad..f4eb20e 100644 --- a/internal/domain/token.go +++ b/internal/domain/token.go @@ -31,6 +31,11 @@ func NewToken() *Token { func TestToken(tb testing.TB) *Token { tb.Helper() + require := require.New(tb) + + nonce, err := random.String(50) + require.NoError(err) + client := TestClient(tb) profile := TestProfile(tb) now := time.Now().UTC().Round(time.Second) @@ -48,10 +53,10 @@ func TestToken(tb testing.TB) *Token { // optional t.Set("scope", strings.Join(scopes, " ")) - t.Set("nonce", random.New().String(32)) + t.Set("nonce", nonce) accessToken, err := jwt.Sign(t, jwa.HS256, []byte("hackme")) - require.NoError(tb, err) + require.NoError(err) return &Token{ AccessToken: string(accessToken), diff --git a/internal/middleware/csrf.go b/internal/middleware/csrf.go index 22dea55..c04e5e9 100644 --- a/internal/middleware/csrf.go +++ b/internal/middleware/csrf.go @@ -108,7 +108,12 @@ func CSRFWithConfig(config CSRFConfig) Interceptor { if k := ctx.Request.Header.Cookie(config.CookieName); k != nil { token = k } else { - token = []byte(random.New().String(config.TokenLength)) + var err error + if token, err = random.Bytes(config.TokenLength); err != nil { + ctx.Error(err.Error(), http.StatusInternalServerError) + + return + } } switch string(ctx.Method()) { diff --git a/internal/random/random.go b/internal/random/random.go index 3590ee5..31a6082 100644 --- a/internal/random/random.go +++ b/internal/random/random.go @@ -1,13 +1,11 @@ package random import ( - "math/rand" + "crypto/rand" + "math/big" "strings" - "time" ) -type Random struct{} - const ( Uppercase = "ABCDEFGHIJKLMNOPQRSTUVWXYZ" Lowercase = "abcdefghijklmnopqrstuvwxyz" @@ -18,13 +16,17 @@ const ( Hex = Numeric + "abcdef" ) -func New() *Random { - rand.Seed(time.Now().UnixNano()) +func Bytes(length int) ([]byte, error) { + b := make([]byte, length) - return new(Random) + if _, err := rand.Read(b); err != nil { + return nil, err + } + + return b, nil } -func (r *Random) String(length int, charsets ...string) string { +func String(length int, charsets ...string) (string, error) { charset := strings.Join(charsets, "") if charset == "" { @@ -34,9 +36,13 @@ func (r *Random) String(length int, charsets ...string) string { b := make([]byte, length) for i := range b { - //nolint: gosec - b[i] = charset[rand.Int()%len(charset)] + n, err := rand.Int(rand.Reader, big.NewInt(int64(len(charset)))) + if err != nil { + return "", err + } + + b[i] = charset[n.Int64()] } - return string(b) + return string(b), nil }