🧑‍💻 Added omitempty option support

This commit is contained in:
Maxim Lebedev 2022-05-24 21:24:52 +05:00
parent 9ef1398b80
commit 47e2700618
Signed by: toby3d
GPG Key ID: 1F14E25B7C119FC5
2 changed files with 106 additions and 73 deletions

21
form.go
View File

@ -28,8 +28,9 @@ type (
) )
const ( const (
tagIgnore = "-" tagIgnore = "-"
methodName = "UnmarshalForm" tagOmitempty = "omitempty"
methodName = "UnmarshalForm"
) )
func NewDecoder(r io.Reader) *Decoder { func NewDecoder(r io.Reader) *Decoder {
@ -89,10 +90,10 @@ func (d Decoder) Decode(dst any) (err error) {
} }
}() }()
return d.decode("", src) return d.decode("", src, "")
} }
func (d Decoder) decode(key string, dst reflect.Value) error { func (d Decoder) decode(key string, dst reflect.Value, opts tagOptions) error {
src := http.AcquireArgs() src := http.AcquireArgs()
defer http.ReleaseArgs(src) defer http.ReleaseArgs(src)
d.args.CopyTo(src) d.args.CopyTo(src)
@ -131,7 +132,7 @@ func (d Decoder) decode(key string, dst reflect.Value) error {
} }
for i := 0; i < dst.Len(); i++ { for i := 0; i < dst.Len(); i++ {
if err := d.decode(fmt.Sprintf("%s,%d", key, i), dst.Index(i)); err != nil { if err := d.decode(fmt.Sprintf("%s,%d", key, i), dst.Index(i), ""); err != nil {
return err return err
} }
} }
@ -152,14 +153,14 @@ func (d Decoder) decode(key string, dst reflect.Value) error {
in[0] = reflect.ValueOf(src.Peek(key)) in[0] = reflect.ValueOf(src.Peek(key))
out := dst.Method(i).Call(in) out := dst.Method(i).Call(in)
if len(out) > 0 && out[0].Interface() != nil { if len(out) > 0 && out[0].Interface() != nil && !opts.Contains(tagOmitempty) {
return out[0].Interface().(error) return out[0].Interface().(error)
} }
return nil return nil
} }
if err := d.decode(key, dst.Elem()); err != nil { if err := d.decode(key, dst.Elem(), ""); err != nil {
return err return err
} }
case reflect.Struct: case reflect.Struct:
@ -173,7 +174,7 @@ func (d Decoder) decode(key string, dst reflect.Value) error {
in[0] = reflect.ValueOf(src.Peek(key)) in[0] = reflect.ValueOf(src.Peek(key))
out := dst.Addr().Method(i).Call(in) out := dst.Addr().Method(i).Call(in)
if len(out) > 0 && out[0].Interface() != nil { if len(out) > 0 && out[0].Interface() != nil && !opts.Contains(tagOmitempty) {
return out[0].Interface().(error) return out[0].Interface().(error)
} }
@ -181,8 +182,8 @@ func (d Decoder) decode(key string, dst reflect.Value) error {
} }
for i := 0; i < dst.NumField(); i++ { for i := 0; i < dst.NumField(); i++ {
if name, _ := parseTag(string(dst.Type().Field(i).Tag.Get(d.tag))); name != tagIgnore { if name, opts := parseTag(string(dst.Type().Field(i).Tag.Get(d.tag))); name != tagIgnore {
if err := d.decode(name, dst.Field(i)); err != nil { if err := d.decode(name, dst.Field(i), opts); err != nil {
return err return err
} }
} }

View File

@ -12,13 +12,14 @@ import (
type ( type (
TestResult struct { TestResult struct {
ArrayStruct []Struct `form:"arrayStruct[]"` ArrayStruct []Struct `form:"arrayStruct[]"`
ArrayPtrStruct []*Struct `form:"arrayPtrStruct[]"` ArrayPtrStruct []*Struct `form:"arrayPtrStruct[]"`
Bytes []byte `form:"bytes"` // TODO(toby3d) Bytes []byte `form:"bytes"` // TODO(toby3d)
Ints []int `form:"ints[]"` Ints []int `form:"ints[]"`
Struct Struct `form:"struct"` Struct Struct `form:"struct"`
PtrStruct *Struct `form:"ptrStruct"` NullStruct NullStruct `form:"nullstruct,omitempty"`
Skip any `form:"-"` PtrStruct *Struct `form:"ptrStruct"`
Skip any `form:"-"`
// Interface any `form:"interface"` // TODO(toby3d) // Interface any `form:"interface"` // TODO(toby3d)
Empty string `form:"empty"` Empty string `form:"empty"`
NotFormTag string `json:"notFormTag"` NotFormTag string `json:"notFormTag"`
@ -32,68 +33,89 @@ type (
Struct struct { Struct struct {
uid string `form:"-"` uid string `form:"-"`
} }
)
const testData string = `skip=dontTouchMe` + NullStruct struct {
`&bool=true` + uid string `form:"-"`
`&string=hello+world` + }
`&int=42` + )
`&uint=420` +
`&float=4.2` +
// `&interface=a1b2c3` + // TODO(toby3d)
`&struct=abc` +
`&ptrStruct=123` +
`&arrayStruct[]=abc` +
`&arrayStruct[]=123` +
`&arrayPtrStruct[]=321` +
`&arrayPtrStruct[]=bca` +
`&ints[]=240` +
`&ints[]=420` +
`&bytes=sampletext` +
`&notFormTag=dontParseMe`
func TestUnmarshal(t *testing.T) { func TestUnmarshal(t *testing.T) {
t.Parallel() t.Parallel()
args := http.AcquireArgs() t.Run("valid", func(t *testing.T) {
args.Parse(testData) t.Parallel()
var in TestResult args := http.AcquireArgs()
if err := form.Unmarshal(args.QueryString(), &in); err != nil { defer http.ReleaseArgs(args)
t.Fatal(err) args.Parse(`skip=dontTouchMe` +
} `&bool=true` +
`&string=hello+world` +
`&int=42` +
`&uint=420` +
`&float=4.2` +
// `&interface=a1b2c3` + // TODO(toby3d)
`&struct=abc` +
`&ptrStruct=123` +
`&arrayStruct[]=abc` +
`&arrayStruct[]=123` +
`&arrayPtrStruct[]=321` +
`&arrayPtrStruct[]=bca` +
`&ints[]=240` +
`&ints[]=420` +
`&bytes=sampletext` +
`&notFormTag=dontParseMe`)
out := TestResult{ var in TestResult
Skip: nil, if err := form.Unmarshal(args.QueryString(), &in); err != nil {
Bool: true, t.Fatal(err)
Float: 4.2, }
Int: 42,
// Interface: []byte("a1b2c3"), // TODO(toby3d)
PtrStruct: &Struct{uid: "123"},
String: "hello world",
Struct: Struct{uid: "abc"},
Uint: 420,
ArrayStruct: []Struct{
{uid: "abc"},
{uid: "123"},
},
ArrayPtrStruct: []*Struct{
{uid: "321"},
{uid: "bca"},
},
Ints: []int{240, 420},
Empty: "",
Bytes: []byte("sampletext"),
NotFormTag: "",
}
opts := []cmp.Option{ out := TestResult{
cmp.AllowUnexported(Struct{}), Skip: nil,
} Bool: true,
Float: 4.2,
Int: 42,
// Interface: []byte("a1b2c3"), // TODO(toby3d)
PtrStruct: &Struct{uid: "123"},
String: "hello world",
Struct: Struct{uid: "abc"},
Uint: 420,
ArrayStruct: []Struct{
{uid: "abc"},
{uid: "123"},
},
ArrayPtrStruct: []*Struct{
{uid: "321"},
{uid: "bca"},
},
Ints: []int{240, 420},
Empty: "",
Bytes: []byte("sampletext"),
NotFormTag: "",
NullStruct: NullStruct{uid: ""},
}
if !cmp.Equal(out, in, opts...) { opts := []cmp.Option{
t.Errorf("Unmarshal(%s, &in)\n%+s", args.QueryString(), cmp.Diff(out, in, opts...)) cmp.AllowUnexported(Struct{}, NullStruct{}),
} }
if !cmp.Equal(out, in, opts...) {
t.Errorf("Unmarshal(%s, &in)\n%+s", args.QueryString(), cmp.Diff(out, in, opts...))
}
})
t.Run("invalid", func(t *testing.T) {
t.Parallel()
args := http.AcquireArgs()
defer http.ReleaseArgs(args)
args.Parse("arrayStruct[]=wtf")
var in TestResult
if err := form.Unmarshal(args.QueryString(), &in); err == nil {
t.Errorf("Unmarshal(%s, &in) = %#+v, want error", args.QueryString(), err)
}
})
} }
func (s *Struct) UnmarshalForm(v []byte) error { func (s *Struct) UnmarshalForm(v []byte) error {
@ -101,11 +123,21 @@ func (s *Struct) UnmarshalForm(v []byte) error {
switch src { switch src {
case "123", "abc", "321", "bca": case "123", "abc", "321", "bca":
s.uid = string(v) s.uid = string(v)
default:
return errors.New("Struct: dough!") return nil
} }
return nil return errors.New("Struct: dough!")
}
func (ns *NullStruct) UnmarshalForm(v []byte) error {
if src := string(v); src != "" {
ns.uid = src
return nil
}
return errors.New("NullStruct: dough!")
} }
func (s Struct) GoString() string { return "Struct(" + s.uid + ")" } func (s Struct) GoString() string { return "Struct(" + s.uid + ")" }