From 5140ec4e474c88e4bf489dd4c59f8b6598e96df6 Mon Sep 17 00:00:00 2001 From: Maxim Lebedev Date: Sun, 6 Aug 2023 05:53:40 +0600 Subject: [PATCH] :technologist: Allow unmarshlers for custom slice types --- form.go | 73 ++++++++++++++++++++++++++++++------------ form_test.go | 90 ++++++++++++++++++++++++++++++++++++---------------- 2 files changed, 115 insertions(+), 48 deletions(-) diff --git a/form.go b/form.go index cd2fba7..457d2ba 100644 --- a/form.go +++ b/form.go @@ -6,11 +6,10 @@ import ( "bytes" "fmt" "io" + "net/url" "reflect" "strconv" "strings" - - http "github.com/valyala/fasthttp" ) type ( @@ -21,9 +20,10 @@ type ( Unmarshaler interface { UnmarshalForm(v []byte) error } + Decoder struct { + args url.Values tag string - args *http.Args } ) @@ -38,9 +38,7 @@ func NewDecoder(r io.Reader) *Decoder { defer buf.Reset() _, _ = buf.ReadFrom(r) - - args := http.AcquireArgs() - args.ParseBytes(buf.Bytes()) + args, _ := url.ParseQuery(buf.String()) return &Decoder{ tag: "form", @@ -94,40 +92,75 @@ func (d Decoder) Decode(dst any) (err error) { } func (d Decoder) decode(key string, dst reflect.Value, opts tagOptions) error { - src := http.AcquireArgs() - defer http.ReleaseArgs(src) - d.args.CopyTo(src) + src := d.args if keyIndex := strings.LastIndex(key, ","); keyIndex != -1 { if index, err := strconv.Atoi(key[keyIndex+1:]); err == nil { key = key[:keyIndex] - src.Reset() - src.SetBytesV(key, d.args.PeekMulti(key)[index]) + src = make(url.Values) + src.Set(key, d.args[key][index]) } } switch dst.Kind() { case reflect.Bool: - dst.SetBool(src.GetBool(key)) + out, err := strconv.ParseBool(src.Get(key)) + if err != nil { + return err + } + + dst.SetBool(out) case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - dst.SetInt(int64(src.GetUfloatOrZero(key))) + out, err := strconv.ParseInt(src.Get(key), 10, 64) + if err != nil { + return err + } + + dst.SetInt(out) case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - dst.SetUint(uint64(src.GetUintOrZero(key))) + out, err := strconv.ParseUint(src.Get(key), 10, 64) + if err != nil { + return err + } + + dst.SetUint(out) case reflect.Float32, reflect.Float64: - dst.SetFloat(src.GetUfloatOrZero(key)) + out, err := strconv.ParseFloat(src.Get(key), 64) + if err != nil { + return err + } + + dst.SetFloat(out) // case reflect.Array: // TODO(toby3d) // case reflect.Interface: // TODO(toby3d) case reflect.Slice: // NOTE(toby3d): copy raw []byte value as is if dst.Type().Elem().Kind() == reflect.Uint8 { - dst.SetBytes(src.Peek(key)) + dst.SetBytes([]byte(src.Get(key))) + + return nil + } + + // NOTE(toby3d): if contains UnmarshalForm method + for i := 0; i < dst.Addr().NumMethod(); i++ { + if dst.Addr().Type().Method(i).Name != methodName { + continue + } + + in := make([]reflect.Value, 1) + in[0] = reflect.ValueOf([]byte(src.Get(key))) + + out := dst.Addr().Method(i).Call(in) + if len(out) > 0 && out[0].Interface() != nil && !opts.Contains(tagOmitempty) { + return out[0].Interface().(error) + } return nil } if dst.IsNil() { - slice := d.args.PeekMulti(key) + slice := d.args[key] dst.Set(reflect.MakeSlice(dst.Type(), len(slice), cap(slice))) } @@ -137,7 +170,7 @@ func (d Decoder) decode(key string, dst reflect.Value, opts tagOptions) error { } } case reflect.String: - dst.SetString(string(src.Peek(key))) + dst.SetString(string(src.Get(key))) case reflect.Pointer: if dst.IsNil() { dst.Set(reflect.New(dst.Type().Elem())) @@ -150,7 +183,7 @@ func (d Decoder) decode(key string, dst reflect.Value, opts tagOptions) error { } in := make([]reflect.Value, 1) - in[0] = reflect.ValueOf(src.Peek(key)) + in[0] = reflect.ValueOf([]byte(src.Get(key))) out := dst.Method(i).Call(in) if len(out) > 0 && out[0].Interface() != nil && !opts.Contains(tagOmitempty) { @@ -171,7 +204,7 @@ func (d Decoder) decode(key string, dst reflect.Value, opts tagOptions) error { } in := make([]reflect.Value, 1) - in[0] = reflect.ValueOf(src.Peek(key)) + in[0] = reflect.ValueOf([]byte(src.Get(key))) out := dst.Addr().Method(i).Call(in) if len(out) > 0 && out[0].Interface() != nil && !opts.Contains(tagOmitempty) { diff --git a/form_test.go b/form_test.go index 62f57e2..995f06a 100644 --- a/form_test.go +++ b/form_test.go @@ -2,38 +2,42 @@ package form_test import ( "errors" + "net/url" + "strings" "testing" "github.com/google/go-cmp/cmp" - http "github.com/valyala/fasthttp" "source.toby3d.me/toby3d/form" ) type ( TestResult struct { + Skip any `form:"-"` + PtrStruct *Struct `form:"ptrStruct"` + PtrStructs *Structs `form:"ptrStructs"` + NullStruct NullStruct `form:"nullstruct,omitempty"` + Struct Struct `form:"struct"` + Empty string `form:"empty"` + String string `form:"string"` + NotFormTag string `json:"notFormTag"` ArrayStruct []Struct `form:"arrayStruct[]"` ArrayPtrStruct []*Struct `form:"arrayPtrStruct[]"` - Bytes []byte `form:"bytes"` // TODO(toby3d) Ints []int `form:"ints[]"` - Struct Struct `form:"struct"` - NullStruct NullStruct `form:"nullstruct,omitempty"` - PtrStruct *Struct `form:"ptrStruct"` - Skip any `form:"-"` - // Interface any `form:"interface"` // TODO(toby3d) - Empty string `form:"empty"` - NotFormTag string `json:"notFormTag"` - String string `form:"string"` - Float float32 `form:"float"` - Uint uint `form:"uint"` - Int int `form:"int"` - Bool bool `form:"bool"` + Structs Structs `form:"structs"` + Bytes []byte `form:"bytes"` + Uint uint `form:"uint"` + Int int `form:"int"` + Float float32 `form:"float"` + Bool bool `form:"bool"` } Struct struct { uid string `form:"-"` } + Structs []Struct + NullStruct struct { uid string `form:"-"` } @@ -45,9 +49,7 @@ func TestUnmarshal(t *testing.T) { t.Run("valid", func(t *testing.T) { t.Parallel() - args := http.AcquireArgs() - defer http.ReleaseArgs(args) - args.Parse(`skip=dontTouchMe` + + args, err := url.ParseQuery(`skip=dontTouchMe` + `&bool=true` + `&string=hello+world` + `&int=42` + @@ -63,10 +65,15 @@ func TestUnmarshal(t *testing.T) { `&ints[]=240` + `&ints[]=420` + `&bytes=sampletext` + - `¬FormTag=dontParseMe`) + `¬FormTag=dontParseMe` + + `&structs=123+abc` + + `&ptrStructs=bca+321`) + if err != nil { + t.Fatal(err) + } - var in TestResult - if err := form.Unmarshal(args.QueryString(), &in); err != nil { + in := new(TestResult) + if err := form.Unmarshal([]byte(args.Encode()), in); err != nil { t.Fatal(err) } @@ -93,31 +100,58 @@ func TestUnmarshal(t *testing.T) { Bytes: []byte("sampletext"), NotFormTag: "", NullStruct: NullStruct{uid: ""}, + Structs: Structs{ + {uid: "123"}, + {uid: "abc"}, + }, + PtrStructs: &Structs{ + {uid: "bca"}, + {uid: "321"}, + }, } opts := []cmp.Option{ cmp.AllowUnexported(Struct{}, NullStruct{}), } - if !cmp.Equal(out, in, opts...) { - t.Errorf("Unmarshal(%s, &in)\n%+s", args.QueryString(), cmp.Diff(out, in, opts...)) + if !cmp.Equal(&out, in, opts...) { + t.Errorf("Unmarshal(%s, &in)\n%+s", args.Encode(), cmp.Diff(&out, in, opts...)) } }) t.Run("invalid", func(t *testing.T) { t.Parallel() - args := http.AcquireArgs() - defer http.ReleaseArgs(args) - args.Parse("arrayStruct[]=wtf") + args, err := url.ParseQuery("arrayStruct[]=wtf") + if err != nil { + t.Fatal(err) + } - var in TestResult - if err := form.Unmarshal(args.QueryString(), &in); err == nil { - t.Errorf("Unmarshal(%s, &in) = %#+v, want error", args.QueryString(), err) + in := new(TestResult) + if err := form.Unmarshal([]byte(args.Encode()), in); err == nil { + t.Errorf("Unmarshal(%s, &in) = %#+v, want error", args.Encode(), err) } }) } +func (s *Structs) UnmarshalForm(v []byte) error { + for _, f := range strings.Fields(string(v)) { + *s = append(*s, Struct{uid: f}) + } + + return nil +} + +func (s Structs) GoString() string { + out := make([]string, len(s)) + + for i := range s { + out[i] = s[i].uid + } + + return "Structs(" + strings.Join(out, ", ") + ")" +} + func (s *Struct) UnmarshalForm(v []byte) error { src := string(v) switch src {