🔒 Improved random package
This commit is contained in:
parent
eda0f7095c
commit
f9ec91c246
|
@ -4,13 +4,16 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/brianvoe/gofakeit/v6"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
"source.toby3d.me/website/oauth/internal/random"
|
"source.toby3d.me/website/oauth/internal/random"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Login struct {
|
type Login struct {
|
||||||
|
PKCE
|
||||||
CreatedAt time.Time
|
CreatedAt time.Time
|
||||||
CompletedAt time.Time
|
CompletedAt time.Time
|
||||||
PKCE
|
|
||||||
Scopes []string
|
Scopes []string
|
||||||
ClientID string
|
ClientID string
|
||||||
RedirectURI string
|
RedirectURI string
|
||||||
|
@ -25,10 +28,11 @@ type Login struct {
|
||||||
func TestLogin(tb testing.TB) *Login {
|
func TestLogin(tb testing.TB) *Login {
|
||||||
tb.Helper()
|
tb.Helper()
|
||||||
|
|
||||||
now := time.Now().UTC()
|
code, err := random.String(8)
|
||||||
|
require.NoError(tb, err)
|
||||||
|
|
||||||
return &Login{
|
return &Login{
|
||||||
CreatedAt: now.Add(-1 * time.Minute),
|
CreatedAt: gofakeit.Date(),
|
||||||
CompletedAt: time.Time{},
|
CompletedAt: time.Time{},
|
||||||
PKCE: PKCE{
|
PKCE: PKCE{
|
||||||
Method: PKCEMethodS256,
|
Method: PKCEMethodS256,
|
||||||
|
@ -40,7 +44,7 @@ func TestLogin(tb testing.TB) *Login {
|
||||||
RedirectURI: "https://app.example.com/redirect",
|
RedirectURI: "https://app.example.com/redirect",
|
||||||
MeEntered: "user.example.net",
|
MeEntered: "user.example.net",
|
||||||
MeResolved: "https://user.example.net/",
|
MeResolved: "https://user.example.net/",
|
||||||
Code: random.New().String(8),
|
Code: code,
|
||||||
Provider: "mastodon",
|
Provider: "mastodon",
|
||||||
IsCompleted: false,
|
IsCompleted: false,
|
||||||
}
|
}
|
||||||
|
@ -50,15 +54,19 @@ func TestLogin(tb testing.TB) *Login {
|
||||||
func TestLoginInvalid(tb testing.TB) *Login {
|
func TestLoginInvalid(tb testing.TB) *Login {
|
||||||
tb.Helper()
|
tb.Helper()
|
||||||
|
|
||||||
now := time.Now().UTC()
|
challenge, err := random.String(42)
|
||||||
|
require.NoError(tb, err)
|
||||||
|
|
||||||
|
verifier, err := random.String(64)
|
||||||
|
require.NoError(tb, err)
|
||||||
|
|
||||||
return &Login{
|
return &Login{
|
||||||
CreatedAt: now.Add(-1 * time.Hour),
|
CreatedAt: time.Now().UTC().Add(-1 * time.Hour),
|
||||||
CompletedAt: time.Time{},
|
CompletedAt: time.Time{},
|
||||||
PKCE: PKCE{
|
PKCE: PKCE{
|
||||||
Method: "UNDEFINED",
|
Method: "UNDEFINED",
|
||||||
Challenge: random.New().String(42),
|
Challenge: challenge,
|
||||||
Verifier: random.New().String(64),
|
Verifier: verifier,
|
||||||
},
|
},
|
||||||
Scopes: []string{},
|
Scopes: []string{},
|
||||||
ClientID: "whoisit",
|
ClientID: "whoisit",
|
||||||
|
@ -66,7 +74,7 @@ func TestLoginInvalid(tb testing.TB) *Login {
|
||||||
MeEntered: "whoami",
|
MeEntered: "whoami",
|
||||||
MeResolved: "",
|
MeResolved: "",
|
||||||
Code: "",
|
Code: "",
|
||||||
Provider: "",
|
Provider: "undefined",
|
||||||
IsCompleted: true,
|
IsCompleted: true,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -31,6 +31,11 @@ func NewToken() *Token {
|
||||||
func TestToken(tb testing.TB) *Token {
|
func TestToken(tb testing.TB) *Token {
|
||||||
tb.Helper()
|
tb.Helper()
|
||||||
|
|
||||||
|
require := require.New(tb)
|
||||||
|
|
||||||
|
nonce, err := random.String(50)
|
||||||
|
require.NoError(err)
|
||||||
|
|
||||||
client := TestClient(tb)
|
client := TestClient(tb)
|
||||||
profile := TestProfile(tb)
|
profile := TestProfile(tb)
|
||||||
now := time.Now().UTC().Round(time.Second)
|
now := time.Now().UTC().Round(time.Second)
|
||||||
|
@ -48,10 +53,10 @@ func TestToken(tb testing.TB) *Token {
|
||||||
|
|
||||||
// optional
|
// optional
|
||||||
t.Set("scope", strings.Join(scopes, " "))
|
t.Set("scope", strings.Join(scopes, " "))
|
||||||
t.Set("nonce", random.New().String(32))
|
t.Set("nonce", nonce)
|
||||||
|
|
||||||
accessToken, err := jwt.Sign(t, jwa.HS256, []byte("hackme"))
|
accessToken, err := jwt.Sign(t, jwa.HS256, []byte("hackme"))
|
||||||
require.NoError(tb, err)
|
require.NoError(err)
|
||||||
|
|
||||||
return &Token{
|
return &Token{
|
||||||
AccessToken: string(accessToken),
|
AccessToken: string(accessToken),
|
||||||
|
|
|
@ -108,7 +108,12 @@ func CSRFWithConfig(config CSRFConfig) Interceptor {
|
||||||
if k := ctx.Request.Header.Cookie(config.CookieName); k != nil {
|
if k := ctx.Request.Header.Cookie(config.CookieName); k != nil {
|
||||||
token = k
|
token = k
|
||||||
} else {
|
} else {
|
||||||
token = []byte(random.New().String(config.TokenLength))
|
var err error
|
||||||
|
if token, err = random.Bytes(config.TokenLength); err != nil {
|
||||||
|
ctx.Error(err.Error(), http.StatusInternalServerError)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
switch string(ctx.Method()) {
|
switch string(ctx.Method()) {
|
||||||
|
|
|
@ -1,13 +1,11 @@
|
||||||
package random
|
package random
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"math/rand"
|
"crypto/rand"
|
||||||
|
"math/big"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type Random struct{}
|
|
||||||
|
|
||||||
const (
|
const (
|
||||||
Uppercase = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
|
Uppercase = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
|
||||||
Lowercase = "abcdefghijklmnopqrstuvwxyz"
|
Lowercase = "abcdefghijklmnopqrstuvwxyz"
|
||||||
|
@ -18,13 +16,17 @@ const (
|
||||||
Hex = Numeric + "abcdef"
|
Hex = Numeric + "abcdef"
|
||||||
)
|
)
|
||||||
|
|
||||||
func New() *Random {
|
func Bytes(length int) ([]byte, error) {
|
||||||
rand.Seed(time.Now().UnixNano())
|
b := make([]byte, length)
|
||||||
|
|
||||||
return new(Random)
|
if _, err := rand.Read(b); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return b, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *Random) String(length int, charsets ...string) string {
|
func String(length int, charsets ...string) (string, error) {
|
||||||
charset := strings.Join(charsets, "")
|
charset := strings.Join(charsets, "")
|
||||||
|
|
||||||
if charset == "" {
|
if charset == "" {
|
||||||
|
@ -34,9 +36,13 @@ func (r *Random) String(length int, charsets ...string) string {
|
||||||
b := make([]byte, length)
|
b := make([]byte, length)
|
||||||
|
|
||||||
for i := range b {
|
for i := range b {
|
||||||
//nolint: gosec
|
n, err := rand.Int(rand.Reader, big.NewInt(int64(len(charset))))
|
||||||
b[i] = charset[rand.Int()%len(charset)]
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
b[i] = charset[n.Int64()]
|
||||||
}
|
}
|
||||||
|
|
||||||
return string(b)
|
return string(b), nil
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue