diff --git a/internal/encoding/form/form.go b/internal/encoding/form/form.go new file mode 100644 index 0000000..573b968 --- /dev/null +++ b/internal/encoding/form/form.go @@ -0,0 +1,95 @@ +// Package form implements encoding and decoding of urlencoded form. The mapping +// between form and Go values is described by `form:"query_name"` struct tags. +package form + +import ( + "errors" + "reflect" + + http "github.com/valyala/fasthttp" +) + +type ( + // Unmarshaler is the interface implemented by types that can unmarshal + // a form description of themselves. The input can be assumed to be a + // valid encoding of a form value. UnmarshalForm must copy the form data + // if it wishes to retain the data after returning. + // + // By convention, to approximate the behavior of Unmarshal itself, + // Unmarshalers implement UnmarshalForm([]byte("null")) as a no-op. + Unmarshaler interface { + UnmarshalForm(v []byte) error + } + + // A Decoder reads and decodes form values from an *fasthttp.Args. + Decoder struct { + source *http.Args + } +) + +const tagName string = "form" + +// NewDecoder returns a new decoder that reads from *fasthttp.Args. +func NewDecoder(args *http.Args) *Decoder { + return &Decoder{ + source: args, + } +} + +// Decode reads the next form-encoded value from its input and stores it in the +// value pointed to by v. +func (dec *Decoder) Decode(v interface{}) error { + dst := reflect.ValueOf(v).Elem() + if !dst.IsValid() { + return errors.New("invalid input") + } + + st := reflect.TypeOf(v).Elem() + + for i := 0; i < dst.NumField(); i++ { + field := st.Field(i) + + // NOTE(toby3d): get tag value as query name + tagValue, ok := field.Tag.Lookup(tagName) + if !ok || tagValue == "" || tagValue == "-" || !dec.source.Has(tagValue) { + continue + } + + // NOTE(toby3d): read struct field type + switch field.Type.Kind() { + case reflect.String: + dst.Field(i).SetString(string(dec.source.Peek(tagValue))) + case reflect.Int: + dst.Field(i).SetInt(int64(dec.source.GetUintOrZero(tagValue))) + case reflect.Float64: + dst.Field(i).SetFloat(dec.source.GetUfloatOrZero(tagValue)) + case reflect.Bool: + dst.Field(i).SetBool(dec.source.GetBool(tagValue)) + case reflect.Ptr: // NOTE(toby3d): pointer to another struct + // NOTE(toby3d): check what custom unmarshal method exists + beforeFunc := dst.Field(i).MethodByName("UnmarshalForm") + if beforeFunc.IsNil() { + continue + } + + dst.Field(i).Set(reflect.New(field.Type.Elem())) + beforeFunc.Call([]reflect.Value{reflect.ValueOf(dec.source.Peek(tagValue))}) + case reflect.Slice: + switch field.Type.Elem().Kind() { + case reflect.Uint8: // NOTE(toby3d): bytes slice + dst.Field(i).SetBytes(dec.source.Peek(tagValue)) + case reflect.String: // NOTE(toby3d): string slice + values := dec.source.PeekMulti(tagValue) + slice := reflect.MakeSlice(field.Type, len(values), len(values)) + + for j, v := range values { + slice.Index(j).SetString(string(v)) + } + + dst.Field(i).Set(slice) + } + } + } + + return nil +} diff --git a/internal/encoding/form/form_test.go b/internal/encoding/form/form_test.go new file mode 100644 index 0000000..e0a3906 --- /dev/null +++ b/internal/encoding/form/form_test.go @@ -0,0 +1,83 @@ +package form_test + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + http "github.com/valyala/fasthttp" + + "source.toby3d.me/website/oauth/internal/encoding/form" +) + +type ( + ResponseType string + + URI struct { + *http.URI + } + + TestResult struct { + State []byte `form:"state"` + Scope []string `form:"scope[]"` + ClientID *URI `form:"client_id"` + RedirectURI *URI `form:"redirect_uri"` + Me *URI `form:"me"` + ResponseType ResponseType `form:"response_type"` + CodeChallenge string `form:"code_challenge"` + CodeChallengeMethod string `form:"code_challenge_method"` + } +) + +const testData string = `response_type=code` + // NOTE(toby3d): string type alias + `&state=1234567890` + // NOTE(toby3d): raw value + // NOTE(toby3d): custom URL types + `&client_id=https://app.example.com/` + + `&redirect_uri=https://app.example.com/redirect` + + `&me=https://user.example.net/` + + // NOTE(toby3d): plain strings + `&code_challenge=OfYAxt8zU2dAPDWQxTAUIteRzMsoj9QBdMIVEDOErUo` + + `&code_challenge_method=S256` + + // NOTE(toby3d): multiple values + `&scope[]=profile` + + `&scope[]=create` + + `&scope[]=update` + + `&scope[]=delete` + +func TestDecode(t *testing.T) { + t.Parallel() + + args := http.AcquireArgs() + clientId, redirectUri, me := http.AcquireURI(), http.AcquireURI(), http.AcquireURI() + + t.Cleanup(func() { + http.ReleaseURI(me) + http.ReleaseURI(redirectUri) + http.ReleaseURI(clientId) + http.ReleaseArgs(args) + }) + + require.NoError(t, clientId.Parse(nil, []byte("https://app.example.com/"))) + require.NoError(t, redirectUri.Parse(nil, []byte("https://app.example.com/redirect"))) + require.NoError(t, me.Parse(nil, []byte("https://user.example.net/"))) + args.Parse(testData) + + result := new(TestResult) + require.NoError(t, form.NewDecoder(args).Decode(result)) + assert.Equal(t, &TestResult{ + ClientID: &URI{URI: clientId}, + Me: &URI{URI: me}, + RedirectURI: &URI{URI: redirectUri}, + State: []byte("1234567890"), + Scope: []string{"profile", "create", "update", "delete"}, + CodeChallengeMethod: "S256", + CodeChallenge: "OfYAxt8zU2dAPDWQxTAUIteRzMsoj9QBdMIVEDOErUo", + ResponseType: "code", + }, result) +} + +func (src *URI) UnmarshalForm(v []byte) error { + src.URI = http.AcquireURI() + + return src.Parse(nil, v) +}