🧑‍💻 Allow unmarshlers for custom slice types

This commit is contained in:
Maxim Lebedev 2023-08-06 05:53:40 +06:00
parent 47e2700618
commit 5140ec4e47
Signed by: toby3d
GPG Key ID: 1F14E25B7C119FC5
2 changed files with 115 additions and 48 deletions

73
form.go
View File

@ -6,11 +6,10 @@ import (
"bytes" "bytes"
"fmt" "fmt"
"io" "io"
"net/url"
"reflect" "reflect"
"strconv" "strconv"
"strings" "strings"
http "github.com/valyala/fasthttp"
) )
type ( type (
@ -21,9 +20,10 @@ type (
Unmarshaler interface { Unmarshaler interface {
UnmarshalForm(v []byte) error UnmarshalForm(v []byte) error
} }
Decoder struct { Decoder struct {
args url.Values
tag string tag string
args *http.Args
} }
) )
@ -38,9 +38,7 @@ func NewDecoder(r io.Reader) *Decoder {
defer buf.Reset() defer buf.Reset()
_, _ = buf.ReadFrom(r) _, _ = buf.ReadFrom(r)
args, _ := url.ParseQuery(buf.String())
args := http.AcquireArgs()
args.ParseBytes(buf.Bytes())
return &Decoder{ return &Decoder{
tag: "form", tag: "form",
@ -94,40 +92,75 @@ func (d Decoder) Decode(dst any) (err error) {
} }
func (d Decoder) decode(key string, dst reflect.Value, opts tagOptions) error { func (d Decoder) decode(key string, dst reflect.Value, opts tagOptions) error {
src := http.AcquireArgs() src := d.args
defer http.ReleaseArgs(src)
d.args.CopyTo(src)
if keyIndex := strings.LastIndex(key, ","); keyIndex != -1 { if keyIndex := strings.LastIndex(key, ","); keyIndex != -1 {
if index, err := strconv.Atoi(key[keyIndex+1:]); err == nil { if index, err := strconv.Atoi(key[keyIndex+1:]); err == nil {
key = key[:keyIndex] key = key[:keyIndex]
src.Reset() src = make(url.Values)
src.SetBytesV(key, d.args.PeekMulti(key)[index]) src.Set(key, d.args[key][index])
} }
} }
switch dst.Kind() { switch dst.Kind() {
case reflect.Bool: case reflect.Bool:
dst.SetBool(src.GetBool(key)) out, err := strconv.ParseBool(src.Get(key))
if err != nil {
return err
}
dst.SetBool(out)
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
dst.SetInt(int64(src.GetUfloatOrZero(key))) out, err := strconv.ParseInt(src.Get(key), 10, 64)
if err != nil {
return err
}
dst.SetInt(out)
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
dst.SetUint(uint64(src.GetUintOrZero(key))) out, err := strconv.ParseUint(src.Get(key), 10, 64)
if err != nil {
return err
}
dst.SetUint(out)
case reflect.Float32, reflect.Float64: case reflect.Float32, reflect.Float64:
dst.SetFloat(src.GetUfloatOrZero(key)) out, err := strconv.ParseFloat(src.Get(key), 64)
if err != nil {
return err
}
dst.SetFloat(out)
// case reflect.Array: // TODO(toby3d) // case reflect.Array: // TODO(toby3d)
// case reflect.Interface: // TODO(toby3d) // case reflect.Interface: // TODO(toby3d)
case reflect.Slice: case reflect.Slice:
// NOTE(toby3d): copy raw []byte value as is // NOTE(toby3d): copy raw []byte value as is
if dst.Type().Elem().Kind() == reflect.Uint8 { if dst.Type().Elem().Kind() == reflect.Uint8 {
dst.SetBytes(src.Peek(key)) dst.SetBytes([]byte(src.Get(key)))
return nil
}
// NOTE(toby3d): if contains UnmarshalForm method
for i := 0; i < dst.Addr().NumMethod(); i++ {
if dst.Addr().Type().Method(i).Name != methodName {
continue
}
in := make([]reflect.Value, 1)
in[0] = reflect.ValueOf([]byte(src.Get(key)))
out := dst.Addr().Method(i).Call(in)
if len(out) > 0 && out[0].Interface() != nil && !opts.Contains(tagOmitempty) {
return out[0].Interface().(error)
}
return nil return nil
} }
if dst.IsNil() { if dst.IsNil() {
slice := d.args.PeekMulti(key) slice := d.args[key]
dst.Set(reflect.MakeSlice(dst.Type(), len(slice), cap(slice))) dst.Set(reflect.MakeSlice(dst.Type(), len(slice), cap(slice)))
} }
@ -137,7 +170,7 @@ func (d Decoder) decode(key string, dst reflect.Value, opts tagOptions) error {
} }
} }
case reflect.String: case reflect.String:
dst.SetString(string(src.Peek(key))) dst.SetString(string(src.Get(key)))
case reflect.Pointer: case reflect.Pointer:
if dst.IsNil() { if dst.IsNil() {
dst.Set(reflect.New(dst.Type().Elem())) dst.Set(reflect.New(dst.Type().Elem()))
@ -150,7 +183,7 @@ func (d Decoder) decode(key string, dst reflect.Value, opts tagOptions) error {
} }
in := make([]reflect.Value, 1) in := make([]reflect.Value, 1)
in[0] = reflect.ValueOf(src.Peek(key)) in[0] = reflect.ValueOf([]byte(src.Get(key)))
out := dst.Method(i).Call(in) out := dst.Method(i).Call(in)
if len(out) > 0 && out[0].Interface() != nil && !opts.Contains(tagOmitempty) { if len(out) > 0 && out[0].Interface() != nil && !opts.Contains(tagOmitempty) {
@ -171,7 +204,7 @@ func (d Decoder) decode(key string, dst reflect.Value, opts tagOptions) error {
} }
in := make([]reflect.Value, 1) in := make([]reflect.Value, 1)
in[0] = reflect.ValueOf(src.Peek(key)) in[0] = reflect.ValueOf([]byte(src.Get(key)))
out := dst.Addr().Method(i).Call(in) out := dst.Addr().Method(i).Call(in)
if len(out) > 0 && out[0].Interface() != nil && !opts.Contains(tagOmitempty) { if len(out) > 0 && out[0].Interface() != nil && !opts.Contains(tagOmitempty) {

View File

@ -2,38 +2,42 @@ package form_test
import ( import (
"errors" "errors"
"net/url"
"strings"
"testing" "testing"
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
http "github.com/valyala/fasthttp"
"source.toby3d.me/toby3d/form" "source.toby3d.me/toby3d/form"
) )
type ( type (
TestResult struct { TestResult struct {
Skip any `form:"-"`
PtrStruct *Struct `form:"ptrStruct"`
PtrStructs *Structs `form:"ptrStructs"`
NullStruct NullStruct `form:"nullstruct,omitempty"`
Struct Struct `form:"struct"`
Empty string `form:"empty"`
String string `form:"string"`
NotFormTag string `json:"notFormTag"`
ArrayStruct []Struct `form:"arrayStruct[]"` ArrayStruct []Struct `form:"arrayStruct[]"`
ArrayPtrStruct []*Struct `form:"arrayPtrStruct[]"` ArrayPtrStruct []*Struct `form:"arrayPtrStruct[]"`
Bytes []byte `form:"bytes"` // TODO(toby3d)
Ints []int `form:"ints[]"` Ints []int `form:"ints[]"`
Struct Struct `form:"struct"` Structs Structs `form:"structs"`
NullStruct NullStruct `form:"nullstruct,omitempty"` Bytes []byte `form:"bytes"`
PtrStruct *Struct `form:"ptrStruct"` Uint uint `form:"uint"`
Skip any `form:"-"` Int int `form:"int"`
// Interface any `form:"interface"` // TODO(toby3d) Float float32 `form:"float"`
Empty string `form:"empty"` Bool bool `form:"bool"`
NotFormTag string `json:"notFormTag"`
String string `form:"string"`
Float float32 `form:"float"`
Uint uint `form:"uint"`
Int int `form:"int"`
Bool bool `form:"bool"`
} }
Struct struct { Struct struct {
uid string `form:"-"` uid string `form:"-"`
} }
Structs []Struct
NullStruct struct { NullStruct struct {
uid string `form:"-"` uid string `form:"-"`
} }
@ -45,9 +49,7 @@ func TestUnmarshal(t *testing.T) {
t.Run("valid", func(t *testing.T) { t.Run("valid", func(t *testing.T) {
t.Parallel() t.Parallel()
args := http.AcquireArgs() args, err := url.ParseQuery(`skip=dontTouchMe` +
defer http.ReleaseArgs(args)
args.Parse(`skip=dontTouchMe` +
`&bool=true` + `&bool=true` +
`&string=hello+world` + `&string=hello+world` +
`&int=42` + `&int=42` +
@ -63,10 +65,15 @@ func TestUnmarshal(t *testing.T) {
`&ints[]=240` + `&ints[]=240` +
`&ints[]=420` + `&ints[]=420` +
`&bytes=sampletext` + `&bytes=sampletext` +
`&notFormTag=dontParseMe`) `&notFormTag=dontParseMe` +
`&structs=123+abc` +
`&ptrStructs=bca+321`)
if err != nil {
t.Fatal(err)
}
var in TestResult in := new(TestResult)
if err := form.Unmarshal(args.QueryString(), &in); err != nil { if err := form.Unmarshal([]byte(args.Encode()), in); err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -93,31 +100,58 @@ func TestUnmarshal(t *testing.T) {
Bytes: []byte("sampletext"), Bytes: []byte("sampletext"),
NotFormTag: "", NotFormTag: "",
NullStruct: NullStruct{uid: ""}, NullStruct: NullStruct{uid: ""},
Structs: Structs{
{uid: "123"},
{uid: "abc"},
},
PtrStructs: &Structs{
{uid: "bca"},
{uid: "321"},
},
} }
opts := []cmp.Option{ opts := []cmp.Option{
cmp.AllowUnexported(Struct{}, NullStruct{}), cmp.AllowUnexported(Struct{}, NullStruct{}),
} }
if !cmp.Equal(out, in, opts...) { if !cmp.Equal(&out, in, opts...) {
t.Errorf("Unmarshal(%s, &in)\n%+s", args.QueryString(), cmp.Diff(out, in, opts...)) t.Errorf("Unmarshal(%s, &in)\n%+s", args.Encode(), cmp.Diff(&out, in, opts...))
} }
}) })
t.Run("invalid", func(t *testing.T) { t.Run("invalid", func(t *testing.T) {
t.Parallel() t.Parallel()
args := http.AcquireArgs() args, err := url.ParseQuery("arrayStruct[]=wtf")
defer http.ReleaseArgs(args) if err != nil {
args.Parse("arrayStruct[]=wtf") t.Fatal(err)
}
var in TestResult in := new(TestResult)
if err := form.Unmarshal(args.QueryString(), &in); err == nil { if err := form.Unmarshal([]byte(args.Encode()), in); err == nil {
t.Errorf("Unmarshal(%s, &in) = %#+v, want error", args.QueryString(), err) t.Errorf("Unmarshal(%s, &in) = %#+v, want error", args.Encode(), err)
} }
}) })
} }
func (s *Structs) UnmarshalForm(v []byte) error {
for _, f := range strings.Fields(string(v)) {
*s = append(*s, Struct{uid: f})
}
return nil
}
func (s Structs) GoString() string {
out := make([]string, len(s))
for i := range s {
out[i] = s[i].uid
}
return "Structs(" + strings.Join(out, ", ") + ")"
}
func (s *Struct) UnmarshalForm(v []byte) error { func (s *Struct) UnmarshalForm(v []byte) error {
src := string(v) src := string(v)
switch src { switch src {