🧑‍💻 Improved recursive decoding

This commit is contained in:
Maxim Lebedev 2022-05-22 17:55:50 +05:00
parent f9ed5be2c0
commit 9ef1398b80
Signed by: toby3d
GPG Key ID: 1F14E25B7C119FC5
6 changed files with 281 additions and 138 deletions

40
error.go Normal file
View File

@ -0,0 +1,40 @@
package form
import (
"reflect"
)
type (
UnmarshalTypeError struct {
Type reflect.Type
Value string
Struct string
Field string
Offset int64
}
InvalidUnmarshalError struct {
Type reflect.Type
}
)
func (e UnmarshalTypeError) Error() string {
if e.Struct != "" || e.Field != "" {
return "form: cannot unmarshal " + e.Value + " into Go struct field " + e.Struct + "." +
e.Field + " of type " + e.Type.String()
}
return "form: cannot unmarshal " + e.Value + " into Go value of type " + e.Type.String()
}
func (e InvalidUnmarshalError) Error() string {
if e.Type == nil {
return "form: Unmarshal(nil)"
}
if e.Type.Kind() != reflect.Pointer {
return "form: Unmarshal(non-pointer " + e.Type.String() + "}"
}
return "form: Unmarshal(nil " + e.Type.String() + ")"
}

176
form.go
View File

@ -3,9 +3,12 @@
package form
import (
"errors"
"bytes"
"fmt"
"io"
"reflect"
"strconv"
"strings"
http "github.com/valyala/fasthttp"
)
@ -15,25 +18,32 @@ type (
// a form description of themselves. The input can be assumed to be a
// valid encoding of a form value. UnmarshalForm must copy the form data
// if it wishes to retain the data after returning.
//
// By convention, to approximate the behavior of Unmarshal itself,
// Unmarshalers implement UnmarshalForm([]byte("null")) as a no-op.
Unmarshaler interface {
UnmarshalForm(v []byte) error
}
// A Decoder reads and decodes form values from an *fasthttp.Args.
Decoder struct {
source *http.Args
tag string
args *http.Args
}
)
const tagName string = "form"
const (
tagIgnore = "-"
methodName = "UnmarshalForm"
)
func NewDecoder(r io.Reader) *Decoder {
buf := new(bytes.Buffer)
defer buf.Reset()
_, _ = buf.ReadFrom(r)
args := http.AcquireArgs()
args.ParseBytes(buf.Bytes())
// NewDecoder returns a new decoder that reads from *fasthttp.Args.
func NewDecoder(args *http.Args) *Decoder {
return &Decoder{
source: args,
tag: "form",
args: args,
}
}
@ -57,21 +67,16 @@ func NewDecoder(args *http.Args) *Decoder {
// the keys (either the struct field name or its tag), preferring an exact match
// but also accepting a case-insensitive match. By default, object keys which
// don't have a corresponding struct field are ignored.
func Unmarshal(src *http.Args, dst interface{}) error {
if err := NewDecoder(src).Decode(dst); err != nil {
return fmt.Errorf("unmarshal: %w", err)
}
return nil
func Unmarshal(data []byte, v any) error {
return NewDecoder(bytes.NewReader(data)).Decode(v)
}
// Decode reads the next form-encoded value from its input and stores it in the
// value pointed to by v.
//nolint: funlen
func (dec *Decoder) Decode(src interface{}) (err error) {
v := reflect.ValueOf(src).Elem()
if !v.IsValid() {
return errors.New("invalid input")
func (d Decoder) Decode(dst any) (err error) {
src := reflect.ValueOf(dst)
if !src.IsValid() || src.Kind() != reflect.Pointer || src.Elem().Kind() != reflect.Struct {
return &InvalidUnmarshalError{
Type: reflect.TypeOf(dst),
}
}
defer func() {
@ -84,54 +89,105 @@ func (dec *Decoder) Decode(src interface{}) (err error) {
}
}()
t := reflect.TypeOf(src).Elem()
return d.decode("", src)
}
for i := 0; i < v.NumField(); i++ {
ft := t.Field(i)
func (d Decoder) decode(key string, dst reflect.Value) error {
src := http.AcquireArgs()
defer http.ReleaseArgs(src)
d.args.CopyTo(src)
// NOTE(toby3d): get tag value as query name
tagValue, ok := ft.Tag.Lookup(tagName)
if !ok || tagValue == "" || tagValue == "-" || !dec.source.Has(tagValue) {
continue
if keyIndex := strings.LastIndex(key, ","); keyIndex != -1 {
if index, err := strconv.Atoi(key[keyIndex+1:]); err == nil {
key = key[:keyIndex]
src.Reset()
src.SetBytesV(key, d.args.PeekMulti(key)[index])
}
}
switch dst.Kind() {
case reflect.Bool:
dst.SetBool(src.GetBool(key))
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
dst.SetInt(int64(src.GetUfloatOrZero(key)))
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
dst.SetUint(uint64(src.GetUintOrZero(key)))
case reflect.Float32, reflect.Float64:
dst.SetFloat(src.GetUfloatOrZero(key))
// case reflect.Array: // TODO(toby3d)
// case reflect.Interface: // TODO(toby3d)
case reflect.Slice:
// NOTE(toby3d): copy raw []byte value as is
if dst.Type().Elem().Kind() == reflect.Uint8 {
dst.SetBytes(src.Peek(key))
return nil
}
field := v.Field(i)
if dst.IsNil() {
slice := d.args.PeekMulti(key)
dst.Set(reflect.MakeSlice(dst.Type(), len(slice), cap(slice)))
}
// NOTE(toby3d): read struct field type
switch ft.Type.Kind() {
case reflect.String:
field.SetString(string(dec.source.Peek(tagValue)))
case reflect.Int:
field.SetInt(int64(dec.source.GetUintOrZero(tagValue)))
case reflect.Float64:
field.SetFloat(dec.source.GetUfloatOrZero(tagValue))
case reflect.Bool:
field.SetBool(dec.source.GetBool(tagValue))
case reflect.Ptr: // NOTE(toby3d): pointer to another struct
// NOTE(toby3d): check what custom unmarshal method exists
unmarshalFunc := field.MethodByName("UnmarshalForm")
if unmarshalFunc.IsZero() {
for i := 0; i < dst.Len(); i++ {
if err := d.decode(fmt.Sprintf("%s,%d", key, i), dst.Index(i)); err != nil {
return err
}
}
case reflect.String:
dst.SetString(string(src.Peek(key)))
case reflect.Pointer:
if dst.IsNil() {
dst.Set(reflect.New(dst.Type().Elem()))
}
// NOTE(toby3d): if contains UnmarshalForm method
for i := 0; i < dst.NumMethod(); i++ {
if dst.Type().Method(i).Name != methodName {
continue
}
field.Set(reflect.New(ft.Type.Elem())) // NOTE(toby3d): initialize zero value
unmarshalFunc.Call([]reflect.Value{reflect.ValueOf(dec.source.Peek(tagValue))})
case reflect.Slice:
switch ft.Type.Elem().Kind() {
case reflect.Uint8: // NOTE(toby3d): bytes slice
field.SetBytes(dec.source.Peek(tagValue))
case reflect.String: // NOTE(toby3d): string slice
values := dec.source.PeekMulti(tagValue)
slice := reflect.MakeSlice(ft.Type, len(values), len(values))
in := make([]reflect.Value, 1)
in[0] = reflect.ValueOf(src.Peek(key))
for j, vv := range values {
slice.Index(j).SetString(string(vv))
out := dst.Method(i).Call(in)
if len(out) > 0 && out[0].Interface() != nil {
return out[0].Interface().(error)
}
return nil
}
if err := d.decode(key, dst.Elem()); err != nil {
return err
}
case reflect.Struct:
// NOTE(toby3d): if contains UnmarshalForm method
for i := 0; i < dst.Addr().NumMethod(); i++ {
if dst.Addr().Type().Method(i).Name != methodName {
continue
}
in := make([]reflect.Value, 1)
in[0] = reflect.ValueOf(src.Peek(key))
out := dst.Addr().Method(i).Call(in)
if len(out) > 0 && out[0].Interface() != nil {
return out[0].Interface().(error)
}
return nil
}
for i := 0; i < dst.NumField(); i++ {
if name, _ := parseTag(string(dst.Type().Field(i).Tag.Get(d.tag))); name != tagIgnore {
if err := d.decode(name, dst.Field(i)); err != nil {
return err
}
field.Set(slice)
}
}
}
return
return nil
}

View File

@ -1,83 +1,111 @@
package form_test
import (
"errors"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/google/go-cmp/cmp"
http "github.com/valyala/fasthttp"
"source.toby3d.me/toby3d/form"
)
type (
ResponseType string
URI struct {
*http.URI `form:"-"`
TestResult struct {
ArrayStruct []Struct `form:"arrayStruct[]"`
ArrayPtrStruct []*Struct `form:"arrayPtrStruct[]"`
Bytes []byte `form:"bytes"` // TODO(toby3d)
Ints []int `form:"ints[]"`
Struct Struct `form:"struct"`
PtrStruct *Struct `form:"ptrStruct"`
Skip any `form:"-"`
// Interface any `form:"interface"` // TODO(toby3d)
Empty string `form:"empty"`
NotFormTag string `json:"notFormTag"`
String string `form:"string"`
Float float32 `form:"float"`
Uint uint `form:"uint"`
Int int `form:"int"`
Bool bool `form:"bool"`
}
TestResult struct {
State []byte `form:"state"`
Scope []string `form:"scope[]"`
ClientID *URI `form:"client_id"`
RedirectURI *URI `form:"redirect_uri"`
Me *URI `form:"me"`
ResponseType ResponseType `form:"response_type"`
CodeChallenge string `form:"code_challenge"`
CodeChallengeMethod string `form:"code_challenge_method"`
Struct struct {
uid string `form:"-"`
}
)
const testData string = `response_type=code` + // NOTE(toby3d): string type alias
`&state=1234567890` + // NOTE(toby3d): raw value
// NOTE(toby3d): custom URL types
`&client_id=https://app.example.com/` +
`&redirect_uri=https://app.example.com/redirect` +
`&me=https://user.example.net/` +
// NOTE(toby3d): plain strings
`&code_challenge=OfYAxt8zU2dAPDWQxTAUIteRzMsoj9QBdMIVEDOErUo` +
`&code_challenge_method=S256` +
// NOTE(toby3d): multiple values
`&scope[]=profile` +
`&scope[]=create` +
`&scope[]=update` +
`&scope[]=delete`
const testData string = `skip=dontTouchMe` +
`&bool=true` +
`&string=hello+world` +
`&int=42` +
`&uint=420` +
`&float=4.2` +
// `&interface=a1b2c3` + // TODO(toby3d)
`&struct=abc` +
`&ptrStruct=123` +
`&arrayStruct[]=abc` +
`&arrayStruct[]=123` +
`&arrayPtrStruct[]=321` +
`&arrayPtrStruct[]=bca` +
`&ints[]=240` +
`&ints[]=420` +
`&bytes=sampletext` +
`&notFormTag=dontParseMe`
func TestUnmarshal(t *testing.T) {
t.Parallel()
args := http.AcquireArgs()
clientId, redirectUri, me := http.AcquireURI(), http.AcquireURI(), http.AcquireURI()
t.Cleanup(func() {
http.ReleaseURI(me)
http.ReleaseURI(redirectUri)
http.ReleaseURI(clientId)
http.ReleaseArgs(args)
})
require.NoError(t, clientId.Parse(nil, []byte("https://app.example.com/")))
require.NoError(t, redirectUri.Parse(nil, []byte("https://app.example.com/redirect")))
require.NoError(t, me.Parse(nil, []byte("https://user.example.net/")))
args.Parse(testData)
result := new(TestResult)
require.NoError(t, form.Unmarshal(args, result))
assert.Equal(t, &TestResult{
ClientID: &URI{URI: clientId},
Me: &URI{URI: me},
RedirectURI: &URI{URI: redirectUri},
State: []byte("1234567890"),
Scope: []string{"profile", "create", "update", "delete"},
CodeChallengeMethod: "S256",
CodeChallenge: "OfYAxt8zU2dAPDWQxTAUIteRzMsoj9QBdMIVEDOErUo",
ResponseType: "code",
}, result)
var in TestResult
if err := form.Unmarshal(args.QueryString(), &in); err != nil {
t.Fatal(err)
}
out := TestResult{
Skip: nil,
Bool: true,
Float: 4.2,
Int: 42,
// Interface: []byte("a1b2c3"), // TODO(toby3d)
PtrStruct: &Struct{uid: "123"},
String: "hello world",
Struct: Struct{uid: "abc"},
Uint: 420,
ArrayStruct: []Struct{
{uid: "abc"},
{uid: "123"},
},
ArrayPtrStruct: []*Struct{
{uid: "321"},
{uid: "bca"},
},
Ints: []int{240, 420},
Empty: "",
Bytes: []byte("sampletext"),
NotFormTag: "",
}
opts := []cmp.Option{
cmp.AllowUnexported(Struct{}),
}
if !cmp.Equal(out, in, opts...) {
t.Errorf("Unmarshal(%s, &in)\n%+s", args.QueryString(), cmp.Diff(out, in, opts...))
}
}
func (src *URI) UnmarshalForm(v []byte) error {
src.URI = http.AcquireURI()
func (s *Struct) UnmarshalForm(v []byte) error {
src := string(v)
switch src {
case "123", "abc", "321", "bca":
s.uid = string(v)
default:
return errors.New("Struct: dough!")
}
return src.Parse(nil, v)
return nil
}
func (s Struct) GoString() string { return "Struct(" + s.uid + ")" }

11
go.mod
View File

@ -1,17 +1,14 @@
module source.toby3d.me/toby3d/form
go 1.17
go 1.18
require (
github.com/stretchr/testify v1.7.0
github.com/valyala/fasthttp v1.34.0
github.com/google/go-cmp v0.5.8
github.com/valyala/fasthttp v1.37.0
)
require (
github.com/andybalholm/brotli v1.0.4 // indirect
github.com/davecgh/go-spew v1.1.0 // indirect
github.com/klauspost/compress v1.15.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/klauspost/compress v1.15.4 // indirect
github.com/valyala/bytebufferpool v1.0.0 // indirect
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c // indirect
)

21
go.sum
View File

@ -1,19 +1,14 @@
github.com/andybalholm/brotli v1.0.4 h1:V7DdXeJtZscaqfNuAdSRuRFzuiKlHSC/Zh3zl9qY3JY=
github.com/andybalholm/brotli v1.0.4/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig=
github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/google/go-cmp v0.5.8 h1:e6P7q2lk1O+qJJb4BtCQXlK8vWEO8V1ZeuEdJNOqZyg=
github.com/google/go-cmp v0.5.8/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/klauspost/compress v1.15.0/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47etQ6QUkpK6sk=
github.com/klauspost/compress v1.15.1 h1:y9FcTHGyrebwfP0ZZqFiaxTaiDnUrGkJkI+f583BL1A=
github.com/klauspost/compress v1.15.1/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47etQ6QUkpK6sk=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/klauspost/compress v1.15.4 h1:1kn4/7MepF/CHmYub99/nNX8az0IJjfSOU/jbnTVfqQ=
github.com/klauspost/compress v1.15.4/go.mod h1:PhcZ0MbTNciWF3rruxRgKxI5NkcHHrHUDtV4Yw2GlzU=
github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=
github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
github.com/valyala/fasthttp v1.34.0 h1:d3AAQJ2DRcxJYHm7OXNXtXt2as1vMDfxeIcFvhmGGm4=
github.com/valyala/fasthttp v1.34.0/go.mod h1:epZA5N+7pY6ZaEKRmstzOuYJx9HI8DI1oaCGZpdH4h0=
github.com/valyala/fasthttp v1.37.0 h1:7WHCyI7EAkQMVmrfBhWTCOaeROb1aCBiTopx63LkMbE=
github.com/valyala/fasthttp v1.37.0/go.mod h1:t/G+3rLek+CyY9bnIE+YlMRddxVAAGjhxndDB4i4C0I=
github.com/valyala/tcplisten v1.0.0/go.mod h1:T0xQ8SeCZGxckz9qRXTfG43PvQ/mcWh7FwZEA7Ioqkc=
golang.org/x/crypto v0.0.0-20220214200702-86341886e292/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
@ -28,7 +23,3 @@ golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuX
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

31
tags.go Normal file
View File

@ -0,0 +1,31 @@
package form
import (
"strings"
)
type tagOptions string
const delim rune = ','
func parseTag(tag string) (string, tagOptions) {
tag, opt, _ := strings.Cut(tag, string(delim))
return tag, tagOptions(opt)
}
func (o tagOptions) Contains(optionName string) bool {
if len(o) == 0 {
return false
}
s := string(o)
for s != "" {
var name string
if name, s, _ = strings.Cut(s, string(delim)); name == optionName {
return true
}
}
return false
}