diff --git a/internal/domain/action_test.go b/internal/domain/action_test.go new file mode 100644 index 0000000..35a0f01 --- /dev/null +++ b/internal/domain/action_test.go @@ -0,0 +1,96 @@ +package domain_test + +import ( + "testing" + + "source.toby3d.me/website/indieauth/internal/domain" +) + +func TestParseAction(t *testing.T) { + t.Parallel() + + for _, tc := range []struct { + in string + out domain.Action + }{{ + in: "revoke", + out: domain.ActionRevoke, + }, { + in: "ticket", + out: domain.ActionTicket, + }} { + tc := tc + + t.Run(tc.in, func(t *testing.T) { + t.Parallel() + + result, err := domain.ParseAction(tc.in) + if err != nil { + t.Fatalf("%+v", err) + } + + if result != tc.out { + t.Errorf("ParseAction(%s) = %v, want %v", tc.in, result, tc.out) + } + }) + } +} + +func TestAction_UnmarshalForm(t *testing.T) { + t.Parallel() + + input := []byte("revoke") + result := domain.ActionUndefined + + if err := result.UnmarshalForm(input); err != nil { + t.Fatalf("%+v", err) + } + + if result != domain.ActionRevoke { + t.Errorf("UnmarshalForm(%s) = %v, want %v", input, result, domain.ActionRevoke) + } +} + +func TestAction_UnmarshalJSON(t *testing.T) { + t.Parallel() + + input := []byte(`"revoke"`) + result := domain.ActionUndefined + + if err := result.UnmarshalJSON(input); err != nil { + t.Fatalf("%+v", err) + } + + if result != domain.ActionRevoke { + t.Errorf("UnmarshalJSON(%s) = %v, want %v", input, result, domain.ActionRevoke) + } +} + +func TestAction_String(t *testing.T) { + t.Parallel() + + for _, tc := range []struct { + name string + in domain.Action + out string + }{{ + name: "revoke", + in: domain.ActionRevoke, + out: "revoke", + }, { + name: "ticket", + in: domain.ActionTicket, + out: "ticket", + }} { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + result := tc.in.String() + if result != tc.out { + t.Errorf("String() = %v, want %v", result, tc.out) + } + }) + } +} diff --git a/internal/domain/client.go b/internal/domain/client.go index 701e13d..ba4a629 100644 --- a/internal/domain/client.go +++ b/internal/domain/client.go @@ -33,8 +33,8 @@ func TestClient(tb testing.TB) *Client { redirects := make([]*URL, 0) for _, redirect := range []string{ - "https://app.example.net/redirect", "https://app.example.com/redirect", + "https://app.example.net/redirect", } { redirects = append(redirects, TestURL(tb, redirect)) } @@ -89,7 +89,7 @@ func (c *Client) ValidateRedirectURI(redirectURI *URL) bool { // GetName safe returns first name, if any. func (c Client) GetName() string { - if len(c.Name) < 1 { + if len(c.Name) == 0 { return "" } @@ -98,7 +98,7 @@ func (c Client) GetName() string { // GetURL safe returns first uRL, if any. func (c Client) GetURL() *URL { - if len(c.URL) < 1 { + if len(c.URL) == 0 { return nil } @@ -107,7 +107,7 @@ func (c Client) GetURL() *URL { // GetLogo safe returns first logo, if any. func (c Client) GetLogo() *URL { - if len(c.Logo) < 1 { + if len(c.Logo) == 0 { return nil } diff --git a/internal/domain/client_id.go b/internal/domain/client_id.go index bb4b1ef..9e8a327 100644 --- a/internal/domain/client_id.go +++ b/internal/domain/client_id.go @@ -7,7 +7,6 @@ import ( "strings" "testing" - "github.com/stretchr/testify/require" http "github.com/valyala/fasthttp" "golang.org/x/xerrors" "inet.af/netaddr" @@ -126,7 +125,9 @@ func TestClientID(tb testing.TB) *ClientID { tb.Helper() clientID, err := ParseClientID("https://app.example.com/") - require.NoError(tb, err) + if err != nil { + tb.Fatalf("%+v", err) + } return clientID } diff --git a/internal/domain/client_id_test.go b/internal/domain/client_id_test.go index 7fe9e33..918c36c 100644 --- a/internal/domain/client_id_test.go +++ b/internal/domain/client_id_test.go @@ -1,11 +1,9 @@ package domain_test import ( + "fmt" "testing" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "source.toby3d.me/website/indieauth/internal/domain" ) @@ -13,67 +11,128 @@ import ( func TestParseClientID(t *testing.T) { t.Parallel() - for _, testCase := range []struct { - name string - input string - isValid bool + for _, tc := range []struct { + name string + in string + expError bool }{{ - name: "valid", - input: "https://example.com/", - isValid: true, + name: "valid", + in: "https://example.com/", + expError: false, }, { - name: "valid with path", - input: "https://example.com/username", - isValid: true, + name: "valid path", + in: "https://example.com/username", + expError: false, }, { - name: "valid with query", - input: "https://example.com/users?id=100", - isValid: true, + name: "valid query", + in: "https://example.com/users?id=100", + expError: false, }, { - name: "valid with port", - input: "https://example.com:8443/", - isValid: true, + name: "valid port", + in: "https://example.com:8443/", + expError: false, }, { - name: "valid loopback", - input: "https://127.0.0.1:8443/", - isValid: true, + name: "valid loopback", + in: "https://127.0.0.1:8443/", + expError: false, }, { - name: "missing scheme", - input: "example.com", - isValid: false, + name: "missing scheme", + in: "example.com", + expError: true, }, { - name: "invalid scheme", - input: "mailto:user@example.com", - isValid: false, + name: "invalid scheme", + in: "mailto:user@example.com", + expError: true, }, { - name: "contains a double-dot path segment", - input: "https://example.com/foo/../bar", - isValid: false, + name: "invalid double-dot path", + in: "https://example.com/foo/../bar", + expError: true, }, { - name: "contains a fragment", - input: "https://example.com/#me", - isValid: false, + name: "invalid fragment", + in: "https://example.com/#me", + expError: true, }, { - name: "contains a username and password", - input: "https://user:pass@example.com/", - isValid: false, + name: "invalid user", + in: "https://user:pass@example.com/", + expError: true, }, { - name: "host is an IP address", - input: "https://172.28.92.51/", - isValid: false, + name: "host is an IP address", + in: "https://172.28.92.51/", + expError: true, }} { - testCase := testCase + tc := tc - t.Run(testCase.name, func(t *testing.T) { + t.Run(tc.name, func(t *testing.T) { t.Parallel() - result, err := domain.ParseClientID(testCase.input) - if testCase.isValid { - require.NoError(t, err) - assert.Equal(t, testCase.input, result.String()) - } else { - assert.Error(t, err) + _, err := domain.ParseClientID(tc.in) + + switch { + case err != nil && !tc.expError: + t.Errorf("ParseClientID(%s) = %+v, want nil", tc.in, err) + case err == nil && tc.expError: + t.Errorf("ParseClientID(%s) = %+v, want error", tc.in, err) } }) } } + +func TestClientID_UnmarshalForm(t *testing.T) { + t.Parallel() + + cid := domain.TestClientID(t) + input := []byte(fmt.Sprint(cid)) + result := new(domain.ClientID) + + if err := result.UnmarshalForm(input); err != nil { + t.Fatalf("%+v", err) + } + + if fmt.Sprint(result) != fmt.Sprint(cid) { + t.Errorf("UnmarshalForm(%s) = %v, want %v", input, result, cid) + } +} + +func TestClientID_UnmarshalJSON(t *testing.T) { + t.Parallel() + + cid := domain.TestClientID(t) + input := []byte(fmt.Sprintf(`"%s"`, cid)) + result := new(domain.ClientID) + + if err := result.UnmarshalJSON(input); err != nil { + t.Fatalf("%+v", err) + } + + if fmt.Sprint(result) != fmt.Sprint(cid) { + t.Errorf("UnmarshalJSON(%s) = %v, want %v", input, result, cid) + } +} + +func TestClientID_MarshalJSON(t *testing.T) { + t.Parallel() + + cid := domain.TestClientID(t) + + result, err := cid.MarshalJSON() + if err != nil { + t.Fatalf("%+v", err) + } + + if string(result) != fmt.Sprintf(`"%s"`, cid) { + t.Errorf("MarshalJSON() = %s, want %s", result, fmt.Sprintf(`"%s"`, cid)) + } +} + +// TODO(toby3d): TestClientID_URI + +// TODO(toby3d): TestClientID_URL + +func TestClientID_String(t *testing.T) { + t.Parallel() + + cid := domain.TestClientID(t) + if result := cid.String(); result != fmt.Sprint(cid) { + t.Errorf("Strig() = %s, want %s", result, fmt.Sprint(cid)) + } +} diff --git a/internal/domain/client_test.go b/internal/domain/client_test.go index f16f46d..8d433d6 100644 --- a/internal/domain/client_test.go +++ b/internal/domain/client_test.go @@ -1,11 +1,9 @@ package domain_test import ( + "fmt" "testing" - "github.com/stretchr/testify/assert" - http "github.com/valyala/fasthttp" - "source.toby3d.me/website/indieauth/internal/domain" ) @@ -14,35 +12,51 @@ func TestClient_ValidateRedirectURI(t *testing.T) { client := domain.TestClient(t) - for _, testCase := range []struct { - name string - input func() *domain.URL - expResult bool + for _, tc := range []struct { + name string + in *domain.URL }{{ name: "client_id prefix", - input: func() *domain.URL { - u := &domain.URL{ - URI: http.AcquireURI(), - } - client.ID.URI().CopyTo(u.URI) - u.SetPath("/callback") - - return u - }, - expResult: true, + in: domain.TestURL(t, fmt.Sprint(client.ID, "/callback")), }, { name: "registered redirect_uri", - input: func() *domain.URL { - return client.RedirectURI[len(client.RedirectURI)-1] - }, - expResult: true, + in: client.RedirectURI[len(client.RedirectURI)-1], }} { - testCase := testCase + tc := tc - t.Run(testCase.name, func(t *testing.T) { + t.Run(tc.name, func(t *testing.T) { t.Parallel() - assert.Equal(t, testCase.expResult, client.ValidateRedirectURI(testCase.input())) + if result := client.ValidateRedirectURI(tc.in); !result { + t.Errorf("ValidateRedirectURI(%v) = %t, want %t", tc.in, result, true) + } }) } } + +func TestClient_GetName(t *testing.T) { + t.Parallel() + + client := domain.TestClient(t) + if result := client.GetName(); result != client.Name[0] { + t.Errorf("GetName() = %v, want %v", result, client.Name[0]) + } +} + +func TestClient_GetURL(t *testing.T) { + t.Parallel() + + client := domain.TestClient(t) + if result := client.GetURL(); result != client.URL[0] { + t.Errorf("GetURL() = %v, want %v", result, client.URL[0]) + } +} + +func TestClient_GetLogo(t *testing.T) { + t.Parallel() + + client := domain.TestClient(t) + if result := client.GetLogo(); result != client.Logo[0] { + t.Errorf("GetLogo() = %v, want %v", result, client.Logo[0]) + } +} diff --git a/internal/domain/code_challenge_method_test.go b/internal/domain/code_challenge_method_test.go index fbcaa17..182383f 100644 --- a/internal/domain/code_challenge_method_test.go +++ b/internal/domain/code_challenge_method_test.go @@ -10,8 +10,6 @@ import ( "testing" "github.com/brianvoe/gofakeit/v6" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" "source.toby3d.me/website/indieauth/internal/domain" "source.toby3d.me/website/indieauth/internal/random" @@ -20,49 +18,124 @@ import ( func TestParseCodeChallengeMethod(t *testing.T) { t.Parallel() - for _, testCase := range []struct { - output domain.CodeChallengeMethod + for _, tc := range []struct { name string - input string + in string + out domain.CodeChallengeMethod expError bool }{{ expError: true, name: "invalid", - input: "und", - output: domain.CodeChallengeMethodUndefined, + in: "und", + out: domain.CodeChallengeMethodUndefined, }, { - name: "PLAIN", - input: "plain", - output: domain.CodeChallengeMethodPLAIN, + name: "PLAIN", + in: "plain", + out: domain.CodeChallengeMethodPLAIN, }, { - name: "MD5", - input: "Md5", - output: domain.CodeChallengeMethodMD5, + name: "MD5", + in: "Md5", + out: domain.CodeChallengeMethodMD5, }, { - name: "S1", - input: "S1", - output: domain.CodeChallengeMethodS1, + name: "S1", + in: "S1", + out: domain.CodeChallengeMethodS1, }, { - name: "S256", - input: "S256", - output: domain.CodeChallengeMethodS256, + name: "S256", + in: "S256", + out: domain.CodeChallengeMethodS256, }, { - name: "S512", - input: "S512", - output: domain.CodeChallengeMethodS512, + name: "S512", + in: "S512", + out: domain.CodeChallengeMethodS512, }} { - testCase := testCase + tc := tc - t.Run(testCase.name, func(t *testing.T) { + t.Run(tc.name, func(t *testing.T) { t.Parallel() - result, err := domain.ParseCodeChallengeMethod(testCase.input) - if testCase.expError { - assert.Error(t, err) - assert.Equal(t, domain.CodeChallengeMethodUndefined, result) - } else { - assert.NoError(t, err) - assert.Equal(t, testCase.output, result) + result, err := domain.ParseCodeChallengeMethod(tc.in) + + switch { + case err != nil && !tc.expError: + t.Errorf("ParseCodeChallengeMethod(%s) = %+v, want nil", tc.in, err) + case err == nil && tc.expError: + t.Errorf("ParseCodeChallengeMethod(%s) = %+v, want error", tc.in, err) + } + + if result != tc.out { + t.Errorf("ParseCodeChallengeMethod(%s) = %v, want %v", tc.in, result, tc.out) + } + }) + } +} + +func TestCodeChallengeMethod_UnmarshalForm(t *testing.T) { + t.Parallel() + + input := []byte("S256") + result := domain.CodeChallengeMethodUndefined + + if err := result.UnmarshalForm(input); err != nil { + t.Fatalf("%+v", err) + } + + if result != domain.CodeChallengeMethodS256 { + t.Errorf("UnmarshalForm(%s) = %v, want %v", input, result, domain.CodeChallengeMethodS256) + } +} + +func TestCodeChallengeMethod_UnmarshalJSON(t *testing.T) { + t.Parallel() + + input := []byte(`"S256"`) + result := domain.CodeChallengeMethodUndefined + + if err := result.UnmarshalJSON(input); err != nil { + t.Fatalf("%+v", err) + } + + if result != domain.CodeChallengeMethodS256 { + t.Errorf("UnmarshalJSON(%s) = %v, want %v", input, result, domain.CodeChallengeMethodS256) + } +} + +func TestCodeChallengeMethod_String(t *testing.T) { + t.Parallel() + + for _, tc := range []struct { + name string + in domain.CodeChallengeMethod + out string + }{{ + name: "plain", + in: domain.CodeChallengeMethodPLAIN, + out: "PLAIN", + }, { + name: "md5", + in: domain.CodeChallengeMethodMD5, + out: "MD5", + }, { + name: "s1", + in: domain.CodeChallengeMethodS1, + out: "S1", + }, { + name: "s256", + in: domain.CodeChallengeMethodS256, + out: "S256", + }, { + name: "s512", + in: domain.CodeChallengeMethodS512, + out: "S512", + }} { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + result := tc.in.String() + if result != tc.out { + t.Errorf("String() = %v, want %v", result, tc.out) } }) } @@ -73,63 +146,68 @@ func TestCodeChallengeMethod_Validate(t *testing.T) { t.Parallel() verifier, err := random.String(gofakeit.Number(43, 128)) - require.NoError(t, err) + if err != nil { + t.Fatalf("%+v", err) + } - for _, testCase := range []struct { - hash hash.Hash - name string - method domain.CodeChallengeMethod - isValid bool + for _, tc := range []struct { + hash hash.Hash + in domain.CodeChallengeMethod + name string + expError bool }{{ - name: "invalid", - method: domain.CodeChallengeMethodS256, - hash: md5.New(), - isValid: false, + name: "invalid", + in: domain.CodeChallengeMethodS256, + hash: md5.New(), + expError: true, }, { - name: "MD5", - method: domain.CodeChallengeMethodMD5, - hash: md5.New(), - isValid: true, + name: "MD5", + in: domain.CodeChallengeMethodMD5, + hash: md5.New(), + expError: false, }, { - name: "plain", - method: domain.CodeChallengeMethodPLAIN, - hash: nil, - isValid: true, + name: "plain", + in: domain.CodeChallengeMethodPLAIN, + hash: nil, + expError: false, }, { - name: "S1", - method: domain.CodeChallengeMethodS1, - hash: sha1.New(), - isValid: true, + name: "S1", + in: domain.CodeChallengeMethodS1, + hash: sha1.New(), + expError: false, }, { - name: "S256", - method: domain.CodeChallengeMethodS256, - hash: sha256.New(), - isValid: true, + name: "S256", + in: domain.CodeChallengeMethodS256, + hash: sha256.New(), + expError: false, }, { - name: "S512", - method: domain.CodeChallengeMethodS512, - hash: sha512.New(), - isValid: true, + name: "S512", + in: domain.CodeChallengeMethodS512, + hash: sha512.New(), + expError: false, }, { - name: "undefined", - method: domain.CodeChallengeMethodUndefined, - hash: nil, - isValid: false, + name: "undefined", + in: domain.CodeChallengeMethodUndefined, + hash: nil, + expError: true, }} { - testCase := testCase + tc := tc - t.Run(testCase.name, func(t *testing.T) { + t.Run(tc.name, func(t *testing.T) { t.Parallel() - if testCase.method == domain.CodeChallengeMethodPLAIN || - testCase.method == domain.CodeChallengeMethodUndefined { - assert.Equal(t, testCase.isValid, testCase.method.Validate(verifier, verifier)) + var codeChallenge string - return + switch tc.in { + case domain.CodeChallengeMethodUndefined, domain.CodeChallengeMethodPLAIN: + codeChallenge = verifier + default: + codeChallenge = base64.RawURLEncoding.EncodeToString(tc.hash.Sum([]byte(verifier))) } - assert.Equal(t, testCase.isValid, testCase.method.Validate(base64.RawURLEncoding.EncodeToString( - testCase.hash.Sum([]byte(verifier))), verifier)) + if result := tc.in.Validate(codeChallenge, verifier); result != !tc.expError { + t.Errorf("Validate(%s, %s) = %t, want %t", codeChallenge, verifier, result, tc.expError) + } }) } } diff --git a/internal/domain/config_test.go b/internal/domain/config_test.go new file mode 100644 index 0000000..0f89f1a --- /dev/null +++ b/internal/domain/config_test.go @@ -0,0 +1,29 @@ +package domain_test + +import ( + "testing" + + "source.toby3d.me/website/indieauth/internal/domain" +) + +func TestConfigServer_GetAddress(t *testing.T) { + t.Parallel() + + config := domain.TestConfig(t) + expResult := config.Server.Host + ":" + config.Server.Port + + if result := config.Server.GetAddress(); result != expResult { + t.Errorf("GetAddress() = %s, want %s", result, expResult) + } +} + +func TestConfigServer_GetRootURL(t *testing.T) { + t.Parallel() + + config := domain.TestConfig(t) + expResult := config.Server.Protocol + "://" + config.Server.Domain + ":" + config.Server.Port + "/" + + if result := config.Server.GetRootURL(); result != expResult { + t.Errorf("GetRootURL() = %s, want %s", result, expResult) + } +} diff --git a/internal/domain/email.go b/internal/domain/email.go index 4c57861..49f5941 100644 --- a/internal/domain/email.go +++ b/internal/domain/email.go @@ -7,8 +7,9 @@ import ( // Email represent email identifier. type Email struct { - user string - host string + user string + host string + subAddress string } var ErrEmailInvalid error = NewError(ErrorCodeInvalidRequest, "cannot parse email", "") @@ -20,10 +21,18 @@ func ParseEmail(src string) (*Email, error) { return nil, ErrEmailInvalid } - return &Email{ - user: parts[0], - host: parts[1], - }, nil + result := &Email{ + user: parts[0], + host: parts[1], + subAddress: "", + } + + if userParts := strings.SplitN(parts[0], `+`, 2); len(userParts) > 1 { + result.user = userParts[0] + result.subAddress = userParts[1] + } + + return result, nil } // TestEmail returns valid random generated email identifier. @@ -31,12 +40,17 @@ func TestEmail(tb testing.TB) *Email { tb.Helper() return &Email{ - user: "user", - host: "example.com", + user: "user", + subAddress: "", + host: "example.com", } } // String returns string representation of email identifier. func (e Email) String() string { - return e.user + "@" + e.host + if e.subAddress == "" { + return e.user + "@" + e.host + } + + return e.user + "+" + e.subAddress + "@" + e.host } diff --git a/internal/domain/email_test.go b/internal/domain/email_test.go index 3f0b76f..5697c67 100644 --- a/internal/domain/email_test.go +++ b/internal/domain/email_test.go @@ -1,52 +1,54 @@ package domain_test import ( + "fmt" "testing" - "github.com/stretchr/testify/assert" - "source.toby3d.me/website/indieauth/internal/domain" ) func TestParseEmail(t *testing.T) { t.Parallel() - for _, testCase := range []struct { - name string - input string - expError bool - expResult string + for _, tc := range []struct { + name string + in string + out string }{{ - name: "simple", - input: "user@example.com", - expError: false, - expResult: "user@example.com", + name: "simple", + in: "user@example.com", + out: "user@example.com", }, { - name: "subaddress", - input: "user+suffix@example.com", - expError: false, - expResult: "user+suffix@example.com", + name: "subAddress", + in: "user+suffix@example.com", + out: "user+suffix@example.com", }, { - name: "prefix", - input: "mailto:user@example.com", - expError: false, - expResult: "user@example.com", + name: "mailto prefix", + in: "mailto:user@example.com", + out: "user@example.com", }} { - testCase := testCase + tc := tc - t.Run(testCase.name, func(t *testing.T) { + t.Run(tc.name, func(t *testing.T) { t.Parallel() - result, err := domain.ParseEmail(testCase.input) - if testCase.expError { - assert.Error(t, err) - assert.Nil(t, result) - - return + result, err := domain.ParseEmail(tc.in) + if err != nil { + t.Fatalf("%+v", err) } - assert.NoError(t, err) - assert.Equal(t, testCase.expResult, result.String()) + if fmt.Sprint(result) != tc.out { + t.Errorf("ParseEmail(%s) = %s, want %s", tc.in, result, tc.out) + } }) } } + +func TestEmail_String(t *testing.T) { + t.Parallel() + + email := domain.TestEmail(t) + if result := email.String(); result != fmt.Sprint(email) { + t.Errorf("String() = %v, want %v", result, email) + } +} diff --git a/internal/domain/grant_type_test.go b/internal/domain/grant_type_test.go new file mode 100644 index 0000000..40220fe --- /dev/null +++ b/internal/domain/grant_type_test.go @@ -0,0 +1,96 @@ +package domain_test + +import ( + "testing" + + "source.toby3d.me/website/indieauth/internal/domain" +) + +func TestParseGrantType(t *testing.T) { + t.Parallel() + + for _, tc := range []struct { + in string + out domain.GrantType + }{{ + in: "authorization_code", + out: domain.GrantTypeAuthorizationCode, + }, { + in: "ticket", + out: domain.GrantTypeTicket, + }} { + tc := tc + + t.Run(tc.in, func(t *testing.T) { + t.Parallel() + + result, err := domain.ParseGrantType(tc.in) + if err != nil { + t.Fatalf("%+v", err) + } + + if result != tc.out { + t.Errorf("ParseGrantType(%s) = %v, want %v", tc.in, result, tc.out) + } + }) + } +} + +func TestGrantType_UnmarshalForm(t *testing.T) { + t.Parallel() + + input := []byte("authorization_code") + result := domain.GrantTypeUndefined + + if err := result.UnmarshalForm(input); err != nil { + t.Fatalf("%+v", err) + } + + if result != domain.GrantTypeAuthorizationCode { + t.Errorf("UnmarshalForm(%s) = %v, want %v", input, result, domain.GrantTypeAuthorizationCode) + } +} + +func TestGrantType_UnmarshalJSON(t *testing.T) { + t.Parallel() + + input := []byte(`"authorization_code"`) + result := domain.GrantTypeUndefined + + if err := result.UnmarshalJSON(input); err != nil { + t.Fatalf("%+v", err) + } + + if result != domain.GrantTypeAuthorizationCode { + t.Errorf("UnmarshalJSON(%s) = %v, want %v", input, result, domain.GrantTypeAuthorizationCode) + } +} + +func TestGrantType_String(t *testing.T) { + t.Parallel() + + for _, tc := range []struct { + name string + in domain.GrantType + out string + }{{ + name: "authorization_code", + in: domain.GrantTypeAuthorizationCode, + out: "authorization_code", + }, { + name: "ticket", + in: domain.GrantTypeTicket, + out: "ticket", + }} { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + result := tc.in.String() + if result != tc.out { + t.Errorf("String() = %v, want %v", result, tc.out) + } + }) + } +} diff --git a/internal/domain/me.go b/internal/domain/me.go index 5780e2e..df3f510 100644 --- a/internal/domain/me.go +++ b/internal/domain/me.go @@ -8,7 +8,6 @@ import ( "strings" "testing" - "github.com/stretchr/testify/require" http "github.com/valyala/fasthttp" "golang.org/x/xerrors" ) @@ -114,7 +113,9 @@ func TestMe(tb testing.TB, src string) *Me { tb.Helper() me, err := ParseMe(src) - require.NoError(tb, err) + if err != nil { + tb.Fatalf("%+v", err) + } return me } diff --git a/internal/domain/me_test.go b/internal/domain/me_test.go index 0e9f406..236ada1 100644 --- a/internal/domain/me_test.go +++ b/internal/domain/me_test.go @@ -1,11 +1,9 @@ package domain_test import ( + "fmt" "testing" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "source.toby3d.me/website/indieauth/internal/domain" ) @@ -13,63 +11,124 @@ import ( func TestParseMe(t *testing.T) { t.Parallel() - for _, testCase := range []struct { - name string - input string - isValid bool + for _, tc := range []struct { + name string + in string + expError bool }{{ - name: "valid", - input: "https://example.com/", - isValid: true, + name: "valid", + in: "https://example.com/", + expError: false, }, { - name: "valid with path", - input: "https://example.com/username", - isValid: true, + name: "valid path", + in: "https://example.com/username", + expError: false, }, { - name: "valid with query", - input: "https://example.com/users?id=100", - isValid: true, + name: "valid query", + in: "https://example.com/users?id=100", + expError: false, }, { - name: "missing scheme", - input: "example.com", - isValid: false, + name: "missing scheme", + in: "example.com", + expError: true, }, { - name: "invalid scheme", - input: "mailto:user@example.com", - isValid: false, + name: "invalid scheme", + in: "mailto:user@example.com", + expError: true, }, { - name: "contains a double-dot path segment", - input: "https://example.com/foo/../bar", - isValid: false, + name: "contains double-dot path", + in: "https://example.com/foo/../bar", + expError: true, }, { - name: "contains a fragment", - input: "https://example.com/#me", - isValid: false, + name: "contains fragment", + in: "https://example.com/#me", + expError: true, }, { - name: "contains a username and password", - input: "https://user:pass@example.com/", - isValid: false, + name: "contains user", + in: "https://user:pass@example.com/", + expError: true, }, { - name: "contains a port", - input: "https://example.com:8443/", - isValid: false, + name: "contains port", + in: "https://example.com:8443/", + expError: true, }, { - name: "host is an IP address", - input: "https://172.28.92.51/", - isValid: false, + name: "host is an IP address", + in: "https://172.28.92.51/", + expError: true, }} { - testCase := testCase + tc := tc - t.Run(testCase.name, func(t *testing.T) { + t.Run(tc.name, func(t *testing.T) { t.Parallel() - result, err := domain.ParseMe(testCase.input) - if testCase.isValid { - require.NoError(t, err) - assert.Equal(t, testCase.input, result.String()) - } else { - assert.Error(t, err) + _, err := domain.ParseMe(tc.in) + + switch { + case err != nil && !tc.expError: + t.Errorf("ParseMe(%s) = %+v, want nil", tc.in, err) + case err == nil && tc.expError: + t.Errorf("ParseMe(%s) = %+v, want error", tc.in, err) } }) } } + +func TestMe_UnmarshalForm(t *testing.T) { + t.Parallel() + + me := domain.TestMe(t, "https://user.example.com/") + input := []byte(fmt.Sprint(me)) + result := new(domain.Me) + + if err := result.UnmarshalForm(input); err != nil { + t.Fatalf("%+v", err) + } + + if fmt.Sprint(result) != fmt.Sprint(me) { + t.Errorf("UnmarshalForm(%s) = %v, want %v", input, result, me) + } +} + +func TestMe_UnmarshalJSON(t *testing.T) { + t.Parallel() + + me := domain.TestMe(t, "https://user.example.com/") + input := []byte(fmt.Sprintf(`"%s"`, me)) + result := new(domain.Me) + + if err := result.UnmarshalJSON(input); err != nil { + t.Fatalf("%+v", err) + } + + if fmt.Sprint(result) != fmt.Sprint(me) { + t.Errorf("UnmarshalJSON(%s) = %v, want %v", input, result, me) + } +} + +func TestMe_MarshalJSON(t *testing.T) { + t.Parallel() + + me := domain.TestMe(t, "https://user.example.com/") + + result, err := me.MarshalJSON() + if err != nil { + t.Fatalf("%+v", err) + } + + if string(result) != fmt.Sprintf(`"%s"`, me) { + t.Errorf("MarshalJSON() = %s, want %s", result, fmt.Sprintf(`"%s"`, me)) + } +} + +// TODO(toby3d): TestMe_URI + +// TODO(toby3d): TestMe_URL + +func TestMe_String(t *testing.T) { + t.Parallel() + + me := domain.TestMe(t, "https://user.example.com/") + if result := me.String(); result != fmt.Sprint(me) { + t.Errorf("Strig() = %s, want %s", result, fmt.Sprint(me)) + } +} diff --git a/internal/domain/provider.go b/internal/domain/provider.go index 8d3d768..587c98b 100644 --- a/internal/domain/provider.go +++ b/internal/domain/provider.go @@ -23,7 +23,7 @@ type Provider struct { //nolint: gochecknoglobals var ( - DefaultProviderDirect = Provider{ + ProviderDirect = Provider{ AuthURL: "/authorize", Name: "IndieAuth", Photo: path.Join("static", "icon.svg"), @@ -34,7 +34,7 @@ var ( URL: "/", } - DefaultProviderTwitter = Provider{ + ProviderTwitter = Provider{ AuthURL: "https://twitter.com/i/oauth2/authorize", Name: "Twitter", Photo: path.Join("static", "providers", "twitter.svg"), @@ -48,7 +48,7 @@ var ( URL: "https://twitter.com/", } - DefaultProviderGitHub = Provider{ + ProviderGitHub = Provider{ AuthURL: "https://github.com/login/oauth/authorize", Name: "GitHub", Photo: path.Join("static", "providers", "github.svg"), @@ -62,7 +62,7 @@ var ( URL: "https://github.com/", } - DefaultProviderGitLab = Provider{ + ProviderGitLab = Provider{ AuthURL: "https://gitlab.com/oauth/authorize", Name: "GitLab", Photo: path.Join("static", "providers", "gitlab.svg"), @@ -75,7 +75,7 @@ var ( URL: "https://gitlab.com/", } - DefaultProviderMastodon = Provider{ + ProviderMastodon = Provider{ AuthURL: "https://mstdn.io/oauth/authorize", Name: "Mastodon", Photo: path.Join("static", "providers", "mastodon.svg"), @@ -90,8 +90,9 @@ var ( ) // AuthCodeURL returns URL for authorize user in RelMeAuth client. -func (p Provider) AuthCodeURL(state string) *URL { +func (p Provider) AuthCodeURL(state string) string { u := http.AcquireURI() + defer http.ReleaseURI(u) u.Update(p.AuthURL) for k, v := range map[string]string{ @@ -104,5 +105,5 @@ func (p Provider) AuthCodeURL(state string) *URL { u.QueryArgs().Set(k, v) } - return &URL{URI: u} + return u.String() } diff --git a/internal/domain/response_type_test.go b/internal/domain/response_type_test.go new file mode 100644 index 0000000..e88aded --- /dev/null +++ b/internal/domain/response_type_test.go @@ -0,0 +1,96 @@ +package domain_test + +import ( + "testing" + + "source.toby3d.me/website/indieauth/internal/domain" +) + +func TestResponseTypeType(t *testing.T) { + t.Parallel() + + for _, tc := range []struct { + in string + out domain.ResponseType + }{{ + in: "id", + out: domain.ResponseTypeID, + }, { + in: "code", + out: domain.ResponseTypeCode, + }} { + tc := tc + + t.Run(tc.in, func(t *testing.T) { + t.Parallel() + + result, err := domain.ParseResponseType(tc.in) + if err != nil { + t.Fatalf("%+v", err) + } + + if result != tc.out { + t.Errorf("ParseResponseType(%s) = %v, want %v", tc.in, result, tc.out) + } + }) + } +} + +func TestResponseType_UnmarshalForm(t *testing.T) { + t.Parallel() + + input := []byte("code") + result := domain.ResponseTypeUndefined + + if err := result.UnmarshalForm(input); err != nil { + t.Fatalf("%+v", err) + } + + if result != domain.ResponseTypeCode { + t.Errorf("UnmarshalForm(%s) = %v, want %v", input, result, domain.ResponseTypeCode) + } +} + +func TestResponseType_UnmarshalJSON(t *testing.T) { + t.Parallel() + + input := []byte(`"code"`) + result := domain.ResponseTypeUndefined + + if err := result.UnmarshalJSON(input); err != nil { + t.Fatalf("%+v", err) + } + + if result != domain.ResponseTypeCode { + t.Errorf("UnmarshalJSON(%s) = %v, want %v", input, result, domain.ResponseTypeCode) + } +} + +func TestResponseType_String(t *testing.T) { + t.Parallel() + + for _, tc := range []struct { + name string + in domain.ResponseType + out string + }{{ + name: "id", + in: domain.ResponseTypeID, + out: "id", + }, { + name: "code", + in: domain.ResponseTypeCode, + out: "code", + }} { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + result := tc.in.String() + if result != tc.out { + t.Errorf("String() = %v, want %v", result, tc.out) + } + }) + } +} diff --git a/internal/domain/scope.go b/internal/domain/scope.go index 7c36a76..59ac413 100644 --- a/internal/domain/scope.go +++ b/internal/domain/scope.go @@ -2,7 +2,6 @@ package domain import ( "fmt" - "sort" "strconv" "strings" ) @@ -139,8 +138,6 @@ func (s Scopes) MarshalJSON() ([]byte, error) { scopes[i] = s[i].String() } - sort.Strings(scopes) - return []byte(strconv.Quote(strings.Join(scopes, " "))), nil } diff --git a/internal/domain/scope_test.go b/internal/domain/scope_test.go index d488d5b..bed95c6 100644 --- a/internal/domain/scope_test.go +++ b/internal/domain/scope_test.go @@ -1,54 +1,199 @@ package domain_test import ( + "fmt" + "reflect" "testing" - "github.com/goccy/go-json" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "source.toby3d.me/website/indieauth/internal/domain" ) -/* TODO(toby3d): enable this after form package patch +func TestParseScope(t *testing.T) { + t.Parallel() + + for _, tc := range []struct { + in string + out domain.Scope + }{{ + in: "create", + out: domain.ScopeCreate, + }, { + in: "delete", + out: domain.ScopeDelete, + }, { + in: "draft", + out: domain.ScopeDraft, + }, { + in: "media", + out: domain.ScopeMedia, + }, { + in: "update", + out: domain.ScopeUpdate, + }, { + in: "block", + out: domain.ScopeBlock, + }, { + in: "channels", + out: domain.ScopeChannels, + }, { + in: "follow", + out: domain.ScopeFollow, + }, { + in: "mute", + out: domain.ScopeMute, + }, { + in: "read", + out: domain.ScopeRead, + }, { + in: "profile", + out: domain.ScopeProfile, + }, { + in: "email", + out: domain.ScopeEmail, + }} { + tc := tc + + t.Run(tc.in, func(t *testing.T) { + t.Parallel() + + result, err := domain.ParseScope(tc.in) + if err != nil { + t.Fatalf("%+v", err) + } + + if result != tc.out { + t.Errorf("ParseScope(%s) = %v, want %v", tc.in, result, tc.out) + } + }) + } +} + func TestScopes_UnmarshalForm(t *testing.T) { t.Parallel() - args := http.AcquireArgs() - defer http.ReleaseArgs(args) - args.Set("scope", "read update delete") + input := []byte("profile email") + results := make(domain.Scopes, 0) - result := struct { - Scope domain.Scopes - }{ - Scope: make(domain.Scopes, 0), + if err := results.UnmarshalForm(input); err != nil { + t.Fatalf("%+v", err) } - require.NoError(t, form.Unmarshal(args, &result)) - assert.Equal(t, "read update delete", result.Scope.String()) + expResults := domain.Scopes{domain.ScopeProfile, domain.ScopeEmail} + if !reflect.DeepEqual(results, expResults) { + t.Errorf("UnmarshalForm(%s) = %s, want %s", input, results, expResults) + } } -*/ func TestScopes_UnmarshalJSON(t *testing.T) { t.Parallel() - result := struct { - Scope domain.Scopes `json:"scope"` - }{} - require.NoError(t, json.Unmarshal([]byte(`{"scope":"read update delete"}`), &result)) - assert.Equal(t, domain.Scopes{domain.ScopeRead, domain.ScopeUpdate, domain.ScopeDelete}, result.Scope) + input := []byte(`"profile email"`) + results := make(domain.Scopes, 0) + + if err := results.UnmarshalJSON(input); err != nil { + t.Fatalf("%+v", err) + } + + expResults := domain.Scopes{domain.ScopeProfile, domain.ScopeEmail} + if !reflect.DeepEqual(results, expResults) { + t.Errorf("UnmarshalJSON(%s) = %s, want %s", input, results, expResults) + } } func TestScopes_MarshalJSON(t *testing.T) { t.Parallel() - result, err := json.Marshal(map[string]domain.Scopes{ - "scope": { - domain.ScopeEmail, - domain.ScopeProfile, - domain.ScopeRead, - }, - }) - require.NoError(t, err) - assert.Equal(t, `{"scope":"email profile read"}`, string(result)) + scopes := domain.Scopes{domain.ScopeEmail, domain.ScopeProfile} + + result, err := scopes.MarshalJSON() + if err != nil { + t.Fatalf("%+v", err) + } + + if string(result) != fmt.Sprintf(`"%s"`, scopes) { + t.Errorf("MarshalJSON() = %s, want %s", result, fmt.Sprintf(`"%s"`, scopes)) + } +} + +func TestScope_String(t *testing.T) { + t.Parallel() + + //nolint: paralleltest // NOTE(toby3d): false positive, tc.in is used. + for _, tc := range []struct { + in domain.Scope + out string + }{{ + in: domain.ScopeCreate, + out: "create", + }, { + in: domain.ScopeDelete, + out: "delete", + }, { + in: domain.ScopeDraft, + out: "draft", + }, { + in: domain.ScopeMedia, + out: "media", + }, { + in: domain.ScopeUpdate, + out: "update", + }, { + in: domain.ScopeBlock, + out: "block", + }, { + in: domain.ScopeChannels, + out: "channels", + }, { + in: domain.ScopeFollow, + out: "follow", + }, { + in: domain.ScopeMute, + out: "mute", + }, { + in: domain.ScopeRead, + out: "read", + }, { + in: domain.ScopeProfile, + out: "profile", + }, { + in: domain.ScopeEmail, + out: "email", + }} { + tc := tc + + t.Run(fmt.Sprint(tc.in), func(t *testing.T) { + t.Parallel() + + if result := tc.in.String(); result != tc.out { + t.Errorf("String() = %s, want %s", result, tc.out) + } + }) + } +} + +func TestScopes_String(t *testing.T) { + t.Parallel() + + scopes := domain.Scopes{domain.ScopeProfile, domain.ScopeEmail} + if result := scopes.String(); result != fmt.Sprint(scopes) { + t.Errorf("String() = %s, want %s", result, scopes) + } +} + +func TestScopes_IsEmpty(t *testing.T) { + t.Parallel() + + scopes := domain.Scopes{domain.ScopeUndefined} + if result := scopes.IsEmpty(); !result { + t.Errorf("IsEmpty() = %t, want %t", result, true) + } +} + +func TestScopes_Has(t *testing.T) { + t.Parallel() + + scopes := domain.Scopes{domain.ScopeProfile, domain.ScopeEmail} + if result := scopes.Has(domain.ScopeEmail); !result { + t.Errorf("Has(%s) = %t, want %t", domain.ScopeEmail, result, true) + } } diff --git a/internal/domain/session.go b/internal/domain/session.go index cf6a0ac..a01cec7 100644 --- a/internal/domain/session.go +++ b/internal/domain/session.go @@ -3,8 +3,6 @@ package domain import ( "testing" - "github.com/stretchr/testify/require" - "source.toby3d.me/website/indieauth/internal/random" ) @@ -24,7 +22,9 @@ func TestSession(tb testing.TB) *Session { tb.Helper() code, err := random.String(24) - require.NoError(tb, err) + if err != nil { + tb.Fatalf("%+v", err) + } return &Session{ ClientID: TestClientID(tb), diff --git a/internal/domain/token.go b/internal/domain/token.go index a567864..853fdc3 100644 --- a/internal/domain/token.go +++ b/internal/domain/token.go @@ -7,7 +7,6 @@ import ( "github.com/lestrrat-go/jwx/jwa" "github.com/lestrrat-go/jwx/jwt" - "github.com/stretchr/testify/require" http "github.com/valyala/fasthttp" "source.toby3d.me/website/indieauth/internal/random" @@ -90,7 +89,9 @@ func TestToken(tb testing.TB) *Token { tb.Helper() nonce, err := random.String(22) - require.NoError(tb, err) + if err != nil { + tb.Fatalf("%+v", err) + } t := jwt.New() cid := TestClientID(tb) @@ -116,7 +117,9 @@ func TestToken(tb testing.TB) *Token { t.Set("nonce", nonce) accessToken, err := jwt.Sign(t, jwa.HS256, []byte("hackme")) - require.NoError(tb, err) + if err != nil { + tb.Fatalf("%+v", err) + } return &Token{ ClientID: cid, diff --git a/internal/domain/token_test.go b/internal/domain/token_test.go new file mode 100644 index 0000000..a5b618b --- /dev/null +++ b/internal/domain/token_test.go @@ -0,0 +1,65 @@ +package domain_test + +import ( + "bytes" + "fmt" + "testing" + "time" + + http "github.com/valyala/fasthttp" + + "source.toby3d.me/website/indieauth/internal/domain" +) + +func TestNewToken(t *testing.T) { + t.Parallel() + + expResult := domain.TestToken(t) + opts := domain.NewTokenOptions{ + Algorithm: "", + NonceLength: 0, + Issuer: expResult.ClientID, + Expiration: 1 * time.Hour, + Scope: expResult.Scope, + Subject: expResult.Me, + Secret: []byte("hackme"), + } + + result, err := domain.NewToken(opts) + if err != nil { + t.Fatalf("%+v", err) + } + + if fmt.Sprint(result.ClientID) != fmt.Sprint(expResult.ClientID) || + fmt.Sprint(result.Me) != fmt.Sprint(expResult.Me) || + fmt.Sprint(result.Scope) != fmt.Sprint(expResult.Scope) { + t.Errorf("NewToken(%+v) = %+v, want %+v", opts, result, expResult) + } +} + +func TestToken_SetAuthHeader(t *testing.T) { + t.Parallel() + + token := domain.TestToken(t) + expResult := []byte("Bearer " + token.AccessToken) + + req := http.AcquireRequest() + defer http.ReleaseRequest(req) + token.SetAuthHeader(req) + + result := req.Header.Peek(http.HeaderAuthorization) + if result == nil || !bytes.Equal(result, expResult) { + t.Errorf("SetAuthHeader(%+v) = %s, want %s", req, result, expResult) + } +} + +func TestToken_String(t *testing.T) { + t.Parallel() + + token := domain.TestToken(t) + expResult := "Bearer " + token.AccessToken + + if result := token.String(); result != expResult { + t.Errorf("String() = %s, want %s", result, expResult) + } +} diff --git a/internal/domain/url_test.go b/internal/domain/url_test.go new file mode 100644 index 0000000..1374357 --- /dev/null +++ b/internal/domain/url_test.go @@ -0,0 +1,51 @@ +package domain_test + +import ( + "fmt" + "testing" + + "source.toby3d.me/website/indieauth/internal/domain" +) + +func TestParseURL(t *testing.T) { + t.Parallel() + + input := "https://user:pass@example.com:8443/users?id=100#me" + if _, err := domain.ParseURL(input); err != nil { + t.Errorf("ParseURL(%s) = %+v, want nil", input, err) + } +} + +func TestURL_UnmarshalForm(t *testing.T) { + t.Parallel() + + u := domain.TestURL(t, "https://user:pass@example.com:8443/users?id=100#me") + input := []byte(fmt.Sprint(u)) + result := new(domain.URL) + + if err := result.UnmarshalForm(input); err != nil { + t.Fatalf("%+v", err) + } + + if fmt.Sprint(result) != fmt.Sprint(u) { + t.Errorf("UnmarshalForm(%s) = %v, want %v", input, result, u) + } +} + +func TestURL_UnmarshalJSON(t *testing.T) { + t.Parallel() + + u := domain.TestURL(t, "https://user:pass@example.com:8443/users?id=100#me") + input := []byte(fmt.Sprintf(`"%s"`, u)) + result := new(domain.URL) + + if err := result.UnmarshalJSON(input); err != nil { + t.Fatalf("%+v", err) + } + + if fmt.Sprint(result) != fmt.Sprint(u) { + t.Errorf("UnmarshalJSON(%s) = %v, want %v", input, result, u) + } +} + +// TODO(toby3d): TestURL_URL