From 9ef1398b80b12286f182cc2b3fb769d2184299e0 Mon Sep 17 00:00:00 2001 From: Maxim Lebedev Date: Sun, 22 May 2022 17:55:50 +0500 Subject: [PATCH] :technologist: Improved recursive decoding --- error.go | 40 ++++++++++++ form.go | 176 +++++++++++++++++++++++++++++++++------------------ form_test.go | 140 ++++++++++++++++++++++++---------------- go.mod | 11 ++-- go.sum | 21 ++---- tags.go | 31 +++++++++ 6 files changed, 281 insertions(+), 138 deletions(-) create mode 100644 error.go create mode 100644 tags.go diff --git a/error.go b/error.go new file mode 100644 index 0000000..3c17dc6 --- /dev/null +++ b/error.go @@ -0,0 +1,40 @@ +package form + +import ( + "reflect" +) + +type ( + UnmarshalTypeError struct { + Type reflect.Type + Value string + Struct string + Field string + Offset int64 + } + + InvalidUnmarshalError struct { + Type reflect.Type + } +) + +func (e UnmarshalTypeError) Error() string { + if e.Struct != "" || e.Field != "" { + return "form: cannot unmarshal " + e.Value + " into Go struct field " + e.Struct + "." + + e.Field + " of type " + e.Type.String() + } + + return "form: cannot unmarshal " + e.Value + " into Go value of type " + e.Type.String() +} + +func (e InvalidUnmarshalError) Error() string { + if e.Type == nil { + return "form: Unmarshal(nil)" + } + + if e.Type.Kind() != reflect.Pointer { + return "form: Unmarshal(non-pointer " + e.Type.String() + "}" + } + + return "form: Unmarshal(nil " + e.Type.String() + ")" +} diff --git a/form.go b/form.go index 050c02b..46086f8 100644 --- a/form.go +++ b/form.go @@ -3,9 +3,12 @@ package form import ( - "errors" + "bytes" "fmt" + "io" "reflect" + "strconv" + "strings" http "github.com/valyala/fasthttp" ) @@ -15,25 +18,32 @@ type ( // a form description of themselves. The input can be assumed to be a // valid encoding of a form value. UnmarshalForm must copy the form data // if it wishes to retain the data after returning. 
- // - // By convention, to approximate the behavior of Unmarshal itself, - // Unmarshalers implement UnmarshalForm([]byte("null")) as a no-op. Unmarshaler interface { UnmarshalForm(v []byte) error } - - // A Decoder reads and decodes form values from an *fasthttp.Args. Decoder struct { - source *http.Args + tag string + args *http.Args } ) -const tagName string = "form" +const ( + tagIgnore = "-" + methodName = "UnmarshalForm" +) + +func NewDecoder(r io.Reader) *Decoder { + buf := new(bytes.Buffer) + defer buf.Reset() + + _, _ = buf.ReadFrom(r) + + args := http.AcquireArgs() + args.ParseBytes(buf.Bytes()) -// NewDecoder returns a new decoder that reads from *fasthttp.Args. -func NewDecoder(args *http.Args) *Decoder { return &Decoder{ - source: args, + tag: "form", + args: args, } } @@ -57,21 +67,16 @@ func NewDecoder(args *http.Args) *Decoder { // the keys (either the struct field name or its tag), preferring an exact match // but also accepting a case-insensitive match. By default, object keys which // don't have a corresponding struct field are ignored. -func Unmarshal(src *http.Args, dst interface{}) error { - if err := NewDecoder(src).Decode(dst); err != nil { - return fmt.Errorf("unmarshal: %w", err) - } - - return nil +func Unmarshal(data []byte, v any) error { + return NewDecoder(bytes.NewReader(data)).Decode(v) } -// Decode reads the next form-encoded value from its input and stores it in the -// value pointed to by v. 
-//nolint: funlen -func (dec *Decoder) Decode(src interface{}) (err error) { - v := reflect.ValueOf(src).Elem() - if !v.IsValid() { - return errors.New("invalid input") +func (d Decoder) Decode(dst any) (err error) { + src := reflect.ValueOf(dst) + if !src.IsValid() || src.Kind() != reflect.Pointer || src.Elem().Kind() != reflect.Struct { + return &InvalidUnmarshalError{ + Type: reflect.TypeOf(dst), + } } defer func() { @@ -84,54 +89,105 @@ func (dec *Decoder) Decode(src interface{}) (err error) { } }() - t := reflect.TypeOf(src).Elem() + return d.decode("", src) +} - for i := 0; i < v.NumField(); i++ { - ft := t.Field(i) +func (d Decoder) decode(key string, dst reflect.Value) error { + src := http.AcquireArgs() + defer http.ReleaseArgs(src) + d.args.CopyTo(src) - // NOTE(toby3d): get tag value as query name - tagValue, ok := ft.Tag.Lookup(tagName) - if !ok || tagValue == "" || tagValue == "-" || !dec.source.Has(tagValue) { - continue + if keyIndex := strings.LastIndex(key, ","); keyIndex != -1 { + if index, err := strconv.Atoi(key[keyIndex+1:]); err == nil { + key = key[:keyIndex] + + src.Reset() + src.SetBytesV(key, d.args.PeekMulti(key)[index]) + } + } + + switch dst.Kind() { + case reflect.Bool: + dst.SetBool(src.GetBool(key)) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + dst.SetInt(int64(src.GetUfloatOrZero(key))) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + dst.SetUint(uint64(src.GetUintOrZero(key))) + case reflect.Float32, reflect.Float64: + dst.SetFloat(src.GetUfloatOrZero(key)) + // case reflect.Array: // TODO(toby3d) + // case reflect.Interface: // TODO(toby3d) + case reflect.Slice: + // NOTE(toby3d): copy raw []byte value as is + if dst.Type().Elem().Kind() == reflect.Uint8 { + dst.SetBytes(src.Peek(key)) + + return nil } - field := v.Field(i) + if dst.IsNil() { + slice := d.args.PeekMulti(key) + dst.Set(reflect.MakeSlice(dst.Type(), len(slice), cap(slice))) + } - // 
NOTE(toby3d): read struct field type - switch ft.Type.Kind() { - case reflect.String: - field.SetString(string(dec.source.Peek(tagValue))) - case reflect.Int: - field.SetInt(int64(dec.source.GetUintOrZero(tagValue))) - case reflect.Float64: - field.SetFloat(dec.source.GetUfloatOrZero(tagValue)) - case reflect.Bool: - field.SetBool(dec.source.GetBool(tagValue)) - case reflect.Ptr: // NOTE(toby3d): pointer to another struct - // NOTE(toby3d): check what custom unmarshal method exists - unmarshalFunc := field.MethodByName("UnmarshalForm") - if unmarshalFunc.IsZero() { + for i := 0; i < dst.Len(); i++ { + if err := d.decode(fmt.Sprintf("%s,%d", key, i), dst.Index(i)); err != nil { + return err + } + } + case reflect.String: + dst.SetString(string(src.Peek(key))) + case reflect.Pointer: + if dst.IsNil() { + dst.Set(reflect.New(dst.Type().Elem())) + } + + // NOTE(toby3d): if contains UnmarshalForm method + for i := 0; i < dst.NumMethod(); i++ { + if dst.Type().Method(i).Name != methodName { continue } - field.Set(reflect.New(ft.Type.Elem())) // NOTE(toby3d): initialize zero value - unmarshalFunc.Call([]reflect.Value{reflect.ValueOf(dec.source.Peek(tagValue))}) - case reflect.Slice: - switch ft.Type.Elem().Kind() { - case reflect.Uint8: // NOTE(toby3d): bytes slice - field.SetBytes(dec.source.Peek(tagValue)) - case reflect.String: // NOTE(toby3d): string slice - values := dec.source.PeekMulti(tagValue) - slice := reflect.MakeSlice(ft.Type, len(values), len(values)) + in := make([]reflect.Value, 1) + in[0] = reflect.ValueOf(src.Peek(key)) - for j, vv := range values { - slice.Index(j).SetString(string(vv)) + out := dst.Method(i).Call(in) + if len(out) > 0 && out[0].Interface() != nil { + return out[0].Interface().(error) + } + + return nil + } + + if err := d.decode(key, dst.Elem()); err != nil { + return err + } + case reflect.Struct: + // NOTE(toby3d): if contains UnmarshalForm method + for i := 0; i < dst.Addr().NumMethod(); i++ { + if dst.Addr().Type().Method(i).Name 
!= methodName { + continue + } + + in := make([]reflect.Value, 1) + in[0] = reflect.ValueOf(src.Peek(key)) + + out := dst.Addr().Method(i).Call(in) + if len(out) > 0 && out[0].Interface() != nil { + return out[0].Interface().(error) + } + + return nil + } + + for i := 0; i < dst.NumField(); i++ { + if name, _ := parseTag(string(dst.Type().Field(i).Tag.Get(d.tag))); name != tagIgnore { + if err := d.decode(name, dst.Field(i)); err != nil { + return err } - - field.Set(slice) } } } - return + return nil } diff --git a/form_test.go b/form_test.go index 048b263..b919e8a 100644 --- a/form_test.go +++ b/form_test.go @@ -1,83 +1,111 @@ package form_test import ( + "errors" "testing" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/google/go-cmp/cmp" http "github.com/valyala/fasthttp" "source.toby3d.me/toby3d/form" ) type ( - ResponseType string - - URI struct { - *http.URI `form:"-"` + TestResult struct { + ArrayStruct []Struct `form:"arrayStruct[]"` + ArrayPtrStruct []*Struct `form:"arrayPtrStruct[]"` + Bytes []byte `form:"bytes"` // TODO(toby3d) + Ints []int `form:"ints[]"` + Struct Struct `form:"struct"` + PtrStruct *Struct `form:"ptrStruct"` + Skip any `form:"-"` + // Interface any `form:"interface"` // TODO(toby3d) + Empty string `form:"empty"` + NotFormTag string `json:"notFormTag"` + String string `form:"string"` + Float float32 `form:"float"` + Uint uint `form:"uint"` + Int int `form:"int"` + Bool bool `form:"bool"` } - TestResult struct { - State []byte `form:"state"` - Scope []string `form:"scope[]"` - ClientID *URI `form:"client_id"` - RedirectURI *URI `form:"redirect_uri"` - Me *URI `form:"me"` - ResponseType ResponseType `form:"response_type"` - CodeChallenge string `form:"code_challenge"` - CodeChallengeMethod string `form:"code_challenge_method"` + Struct struct { + uid string `form:"-"` } ) -const testData string = `response_type=code` + // NOTE(toby3d): string type alias - `&state=1234567890` + // NOTE(toby3d): 
raw value - // NOTE(toby3d): custom URL types - `&client_id=https://app.example.com/` + - `&redirect_uri=https://app.example.com/redirect` + - `&me=https://user.example.net/` + - // NOTE(toby3d): plain strings - `&code_challenge=OfYAxt8zU2dAPDWQxTAUIteRzMsoj9QBdMIVEDOErUo` + - `&code_challenge_method=S256` + - // NOTE(toby3d): multiple values - `&scope[]=profile` + - `&scope[]=create` + - `&scope[]=update` + - `&scope[]=delete` +const testData string = `skip=dontTouchMe` + + `&bool=true` + + `&string=hello+world` + + `&int=42` + + `&uint=420` + + `&float=4.2` + + // `&interface=a1b2c3` + // TODO(toby3d) + `&struct=abc` + + `&ptrStruct=123` + + `&arrayStruct[]=abc` + + `&arrayStruct[]=123` + + `&arrayPtrStruct[]=321` + + `&arrayPtrStruct[]=bca` + + `&ints[]=240` + + `&ints[]=420` + + `&bytes=sampletext` + + `&notFormTag=dontParseMe` func TestUnmarshal(t *testing.T) { t.Parallel() args := http.AcquireArgs() - clientId, redirectUri, me := http.AcquireURI(), http.AcquireURI(), http.AcquireURI() - - t.Cleanup(func() { - http.ReleaseURI(me) - http.ReleaseURI(redirectUri) - http.ReleaseURI(clientId) - http.ReleaseArgs(args) - }) - - require.NoError(t, clientId.Parse(nil, []byte("https://app.example.com/"))) - require.NoError(t, redirectUri.Parse(nil, []byte("https://app.example.com/redirect"))) - require.NoError(t, me.Parse(nil, []byte("https://user.example.net/"))) args.Parse(testData) - result := new(TestResult) - require.NoError(t, form.Unmarshal(args, result)) - assert.Equal(t, &TestResult{ - ClientID: &URI{URI: clientId}, - Me: &URI{URI: me}, - RedirectURI: &URI{URI: redirectUri}, - State: []byte("1234567890"), - Scope: []string{"profile", "create", "update", "delete"}, - CodeChallengeMethod: "S256", - CodeChallenge: "OfYAxt8zU2dAPDWQxTAUIteRzMsoj9QBdMIVEDOErUo", - ResponseType: "code", - }, result) + var in TestResult + if err := form.Unmarshal(args.QueryString(), &in); err != nil { + t.Fatal(err) + } + + out := TestResult{ + Skip: nil, + Bool: true, + Float: 4.2, + 
Int: 42, + // Interface: []byte("a1b2c3"), // TODO(toby3d) + PtrStruct: &Struct{uid: "123"}, + String: "hello world", + Struct: Struct{uid: "abc"}, + Uint: 420, + ArrayStruct: []Struct{ + {uid: "abc"}, + {uid: "123"}, + }, + ArrayPtrStruct: []*Struct{ + {uid: "321"}, + {uid: "bca"}, + }, + Ints: []int{240, 420}, + Empty: "", + Bytes: []byte("sampletext"), + NotFormTag: "", + } + + opts := []cmp.Option{ + cmp.AllowUnexported(Struct{}), + } + + if !cmp.Equal(out, in, opts...) { + t.Errorf("Unmarshal(%s, &in)\n%+s", args.QueryString(), cmp.Diff(out, in, opts...)) + } } -func (src *URI) UnmarshalForm(v []byte) error { - src.URI = http.AcquireURI() +func (s *Struct) UnmarshalForm(v []byte) error { + src := string(v) + switch src { + case "123", "abc", "321", "bca": + s.uid = string(v) + default: + return errors.New("Struct: dough!") + } - return src.Parse(nil, v) + return nil } + +func (s Struct) GoString() string { return "Struct(" + s.uid + ")" } diff --git a/go.mod b/go.mod index 0c2d294..b38ed88 100644 --- a/go.mod +++ b/go.mod @@ -1,17 +1,14 @@ module source.toby3d.me/toby3d/form -go 1.17 +go 1.18 require ( - github.com/stretchr/testify v1.7.0 - github.com/valyala/fasthttp v1.34.0 + github.com/google/go-cmp v0.5.8 + github.com/valyala/fasthttp v1.37.0 ) require ( github.com/andybalholm/brotli v1.0.4 // indirect - github.com/davecgh/go-spew v1.1.0 // indirect - github.com/klauspost/compress v1.15.1 // indirect - github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/klauspost/compress v1.15.4 // indirect github.com/valyala/bytebufferpool v1.0.0 // indirect - gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c // indirect ) diff --git a/go.sum b/go.sum index dc8d478..c914254 100644 --- a/go.sum +++ b/go.sum @@ -1,19 +1,14 @@ github.com/andybalholm/brotli v1.0.4 h1:V7DdXeJtZscaqfNuAdSRuRFzuiKlHSC/Zh3zl9qY3JY= github.com/andybalholm/brotli v1.0.4/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig= -github.com/davecgh/go-spew v1.1.0 
h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= -github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/google/go-cmp v0.5.8 h1:e6P7q2lk1O+qJJb4BtCQXlK8vWEO8V1ZeuEdJNOqZyg= +github.com/google/go-cmp v0.5.8/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/klauspost/compress v1.15.0/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47etQ6QUkpK6sk= -github.com/klauspost/compress v1.15.1 h1:y9FcTHGyrebwfP0ZZqFiaxTaiDnUrGkJkI+f583BL1A= -github.com/klauspost/compress v1.15.1/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47etQ6QUkpK6sk= -github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= -github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= -github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/klauspost/compress v1.15.4 h1:1kn4/7MepF/CHmYub99/nNX8az0IJjfSOU/jbnTVfqQ= +github.com/klauspost/compress v1.15.4/go.mod h1:PhcZ0MbTNciWF3rruxRgKxI5NkcHHrHUDtV4Yw2GlzU= github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= -github.com/valyala/fasthttp v1.34.0 h1:d3AAQJ2DRcxJYHm7OXNXtXt2as1vMDfxeIcFvhmGGm4= -github.com/valyala/fasthttp v1.34.0/go.mod h1:epZA5N+7pY6ZaEKRmstzOuYJx9HI8DI1oaCGZpdH4h0= +github.com/valyala/fasthttp v1.37.0 h1:7WHCyI7EAkQMVmrfBhWTCOaeROb1aCBiTopx63LkMbE= +github.com/valyala/fasthttp v1.37.0/go.mod h1:t/G+3rLek+CyY9bnIE+YlMRddxVAAGjhxndDB4i4C0I= github.com/valyala/tcplisten v1.0.0/go.mod h1:T0xQ8SeCZGxckz9qRXTfG43PvQ/mcWh7FwZEA7Ioqkc= golang.org/x/crypto v0.0.0-20220214200702-86341886e292/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/net 
v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= @@ -28,7 +23,3 @@ golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuX golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= -gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/tags.go b/tags.go new file mode 100644 index 0000000..648bc8f --- /dev/null +++ b/tags.go @@ -0,0 +1,31 @@ +package form + +import ( + "strings" +) + +type tagOptions string + +const delim rune = ',' + +func parseTag(tag string) (string, tagOptions) { + tag, opt, _ := strings.Cut(tag, string(delim)) + + return tag, tagOptions(opt) +} + +func (o tagOptions) Contains(optionName string) bool { + if len(o) == 0 { + return false + } + + s := string(o) + for s != "" { + var name string + if name, s, _ = strings.Cut(s, string(delim)); name == optionName { + return true + } + } + + return false +}