diff --git a/form.go b/form.go index 46086f8..cd2fba7 100644 --- a/form.go +++ b/form.go @@ -28,8 +28,9 @@ type ( ) const ( - tagIgnore = "-" - methodName = "UnmarshalForm" + tagIgnore = "-" + tagOmitempty = "omitempty" + methodName = "UnmarshalForm" ) func NewDecoder(r io.Reader) *Decoder { @@ -89,10 +90,10 @@ func (d Decoder) Decode(dst any) (err error) { } }() - return d.decode("", src) + return d.decode("", src, "") } -func (d Decoder) decode(key string, dst reflect.Value) error { +func (d Decoder) decode(key string, dst reflect.Value, opts tagOptions) error { src := http.AcquireArgs() defer http.ReleaseArgs(src) d.args.CopyTo(src) @@ -131,7 +132,7 @@ func (d Decoder) decode(key string, dst reflect.Value) error { } for i := 0; i < dst.Len(); i++ { - if err := d.decode(fmt.Sprintf("%s,%d", key, i), dst.Index(i)); err != nil { + if err := d.decode(fmt.Sprintf("%s,%d", key, i), dst.Index(i), ""); err != nil { return err } } @@ -152,14 +153,14 @@ func (d Decoder) decode(key string, dst reflect.Value) error { in[0] = reflect.ValueOf(src.Peek(key)) out := dst.Method(i).Call(in) - if len(out) > 0 && out[0].Interface() != nil { + if len(out) > 0 && out[0].Interface() != nil && !opts.Contains(tagOmitempty) { return out[0].Interface().(error) } return nil } - if err := d.decode(key, dst.Elem()); err != nil { + if err := d.decode(key, dst.Elem(), ""); err != nil { return err } case reflect.Struct: @@ -173,7 +174,7 @@ func (d Decoder) decode(key string, dst reflect.Value) error { in[0] = reflect.ValueOf(src.Peek(key)) out := dst.Addr().Method(i).Call(in) - if len(out) > 0 && out[0].Interface() != nil { + if len(out) > 0 && out[0].Interface() != nil && !opts.Contains(tagOmitempty) { return out[0].Interface().(error) } @@ -181,8 +182,8 @@ func (d Decoder) decode(key string, dst reflect.Value) error { } for i := 0; i < dst.NumField(); i++ { - if name, _ := parseTag(string(dst.Type().Field(i).Tag.Get(d.tag))); name != tagIgnore { - if err := d.decode(name, dst.Field(i)); err != nil { + if name, opts := parseTag(string(dst.Type().Field(i).Tag.Get(d.tag))); name != tagIgnore { + if err := d.decode(name, dst.Field(i), opts); err != nil { return err } } diff --git a/form_test.go b/form_test.go index b919e8a..62f57e2 100644 --- a/form_test.go +++ b/form_test.go @@ -12,13 +12,14 @@ import ( type ( TestResult struct { - ArrayStruct []Struct `form:"arrayStruct[]"` - ArrayPtrStruct []*Struct `form:"arrayPtrStruct[]"` - Bytes []byte `form:"bytes"` // TODO(toby3d) - Ints []int `form:"ints[]"` - Struct Struct `form:"struct"` - PtrStruct *Struct `form:"ptrStruct"` - Skip any `form:"-"` + ArrayStruct []Struct `form:"arrayStruct[]"` + ArrayPtrStruct []*Struct `form:"arrayPtrStruct[]"` + Bytes []byte `form:"bytes"` // TODO(toby3d) + Ints []int `form:"ints[]"` + Struct Struct `form:"struct"` + NullStruct NullStruct `form:"nullstruct,omitempty"` + PtrStruct *Struct `form:"ptrStruct"` + Skip any `form:"-"` // Interface any `form:"interface"` // TODO(toby3d) Empty string `form:"empty"` NotFormTag string `json:"notFormTag"` @@ -32,68 +33,89 @@ type ( Struct struct { uid string `form:"-"` } -) -const testData string = `skip=dontTouchMe` + - `&bool=true` + - `&string=hello+world` + - `&int=42` + - `&uint=420` + - `&float=4.2` + - // `&interface=a1b2c3` + // TODO(toby3d) - `&struct=abc` + - `&ptrStruct=123` + - `&arrayStruct[]=abc` + - `&arrayStruct[]=123` + - `&arrayPtrStruct[]=321` + - `&arrayPtrStruct[]=bca` + - `&ints[]=240` + - `&ints[]=420` + - `&bytes=sampletext` + - `¬FormTag=dontParseMe` + NullStruct struct { + uid string `form:"-"` + } +) func TestUnmarshal(t *testing.T) { t.Parallel() - args := http.AcquireArgs() - args.Parse(testData) + t.Run("valid", func(t *testing.T) { + t.Parallel() - var in TestResult - if err := form.Unmarshal(args.QueryString(), &in); err != nil { - t.Fatal(err) - } + args := http.AcquireArgs() + defer http.ReleaseArgs(args) + args.Parse(`skip=dontTouchMe` + + `&bool=true` + + `&string=hello+world` + + `&int=42` + + `&uint=420` + + `&float=4.2` + + // `&interface=a1b2c3` + // TODO(toby3d) + `&struct=abc` + + `&ptrStruct=123` + + `&arrayStruct[]=abc` + + `&arrayStruct[]=123` + + `&arrayPtrStruct[]=321` + + `&arrayPtrStruct[]=bca` + + `&ints[]=240` + + `&ints[]=420` + + `&bytes=sampletext` + + `¬FormTag=dontParseMe`) - out := TestResult{ - Skip: nil, - Bool: true, - Float: 4.2, - Int: 42, - // Interface: []byte("a1b2c3"), // TODO(toby3d) - PtrStruct: &Struct{uid: "123"}, - String: "hello world", - Struct: Struct{uid: "abc"}, - Uint: 420, - ArrayStruct: []Struct{ - {uid: "abc"}, - {uid: "123"}, - }, - ArrayPtrStruct: []*Struct{ - {uid: "321"}, - {uid: "bca"}, - }, - Ints: []int{240, 420}, - Empty: "", - Bytes: []byte("sampletext"), - NotFormTag: "", - } + var in TestResult + if err := form.Unmarshal(args.QueryString(), &in); err != nil { + t.Fatal(err) + } - opts := []cmp.Option{ - cmp.AllowUnexported(Struct{}), - } + out := TestResult{ + Skip: nil, + Bool: true, + Float: 4.2, + Int: 42, + // Interface: []byte("a1b2c3"), // TODO(toby3d) + PtrStruct: &Struct{uid: "123"}, + String: "hello world", + Struct: Struct{uid: "abc"}, + Uint: 420, + ArrayStruct: []Struct{ + {uid: "abc"}, + {uid: "123"}, + }, + ArrayPtrStruct: []*Struct{ + {uid: "321"}, + {uid: "bca"}, + }, + Ints: []int{240, 420}, + Empty: "", + Bytes: []byte("sampletext"), + NotFormTag: "", + NullStruct: NullStruct{uid: ""}, + } - if !cmp.Equal(out, in, opts...) { - t.Errorf("Unmarshal(%s, &in)\n%+s", args.QueryString(), cmp.Diff(out, in, opts...)) - } + opts := []cmp.Option{ + cmp.AllowUnexported(Struct{}, NullStruct{}), + } + + if !cmp.Equal(out, in, opts...) { + t.Errorf("Unmarshal(%s, &in)\n%+s", args.QueryString(), cmp.Diff(out, in, opts...)) + } + }) + + t.Run("invalid", func(t *testing.T) { + t.Parallel() + + args := http.AcquireArgs() + defer http.ReleaseArgs(args) + args.Parse("arrayStruct[]=wtf") + + var in TestResult + if err := form.Unmarshal(args.QueryString(), &in); err == nil { + t.Errorf("Unmarshal(%s, &in) = %#+v, want error", args.QueryString(), err) + } + }) } func (s *Struct) UnmarshalForm(v []byte) error { @@ -101,11 +123,21 @@ func (s *Struct) UnmarshalForm(v []byte) error { switch src { case "123", "abc", "321", "bca": s.uid = string(v) - default: - return errors.New("Struct: dough!") + + return nil } - return nil + return errors.New("Struct: dough!") +} + +func (ns *NullStruct) UnmarshalForm(v []byte) error { + if src := string(v); src != "" { + ns.uid = src + + return nil + } + + return errors.New("NullStruct: dough!") } func (s Struct) GoString() string { return "Struct(" + s.uid + ")" }