Compare commits

...

3 Commits

Author SHA1 Message Date
Maxim Lebedev 5140ec4e47
🧑‍💻 Allow unmarshlers for custom slice types 2023-08-06 05:53:40 +06:00
Maxim Lebedev 47e2700618
🧑‍💻 Added omitempty option support 2022-05-24 21:27:12 +05:00
Maxim Lebedev 9ef1398b80
🧑‍💻 Improved recursive decoding 2022-05-22 17:55:50 +05:00
6 changed files with 383 additions and 140 deletions

40
error.go Normal file
View File

@ -0,0 +1,40 @@
package form
import (
"reflect"
)
type (
UnmarshalTypeError struct {
Type reflect.Type
Value string
Struct string
Field string
Offset int64
}
InvalidUnmarshalError struct {
Type reflect.Type
}
)
func (e UnmarshalTypeError) Error() string {
if e.Struct != "" || e.Field != "" {
return "form: cannot unmarshal " + e.Value + " into Go struct field " + e.Struct + "." +
e.Field + " of type " + e.Type.String()
}
return "form: cannot unmarshal " + e.Value + " into Go value of type " + e.Type.String()
}
func (e InvalidUnmarshalError) Error() string {
if e.Type == nil {
return "form: Unmarshal(nil)"
}
if e.Type.Kind() != reflect.Pointer {
return "form: Unmarshal(non-pointer " + e.Type.String() + "}"
}
return "form: Unmarshal(nil " + e.Type.String() + ")"
}

212
form.go
View File

@ -3,11 +3,13 @@
package form
import (
"errors"
"bytes"
"fmt"
"io"
"net/url"
"reflect"
http "github.com/valyala/fasthttp"
"strconv"
"strings"
)
type (
@ -15,25 +17,32 @@ type (
// a form description of themselves. The input can be assumed to be a
// valid encoding of a form value. UnmarshalForm must copy the form data
// if it wishes to retain the data after returning.
//
// By convention, to approximate the behavior of Unmarshal itself,
// Unmarshalers implement UnmarshalForm([]byte("null")) as a no-op.
Unmarshaler interface {
UnmarshalForm(v []byte) error
}
// A Decoder reads and decodes form values from an *fasthttp.Args.
Decoder struct {
source *http.Args
args url.Values
tag string
}
)
const tagName string = "form"
const (
tagIgnore = "-"
tagOmitempty = "omitempty"
methodName = "UnmarshalForm"
)
func NewDecoder(r io.Reader) *Decoder {
buf := new(bytes.Buffer)
defer buf.Reset()
_, _ = buf.ReadFrom(r)
args, _ := url.ParseQuery(buf.String())
// NewDecoder returns a new decoder that reads from *fasthttp.Args.
func NewDecoder(args *http.Args) *Decoder {
return &Decoder{
source: args,
tag: "form",
args: args,
}
}
@ -57,21 +66,16 @@ func NewDecoder(args *http.Args) *Decoder {
// the keys (either the struct field name or its tag), preferring an exact match
// but also accepting a case-insensitive match. By default, object keys which
// don't have a corresponding struct field are ignored.
func Unmarshal(src *http.Args, dst interface{}) error {
if err := NewDecoder(src).Decode(dst); err != nil {
return fmt.Errorf("unmarshal: %w", err)
}
return nil
func Unmarshal(data []byte, v any) error {
return NewDecoder(bytes.NewReader(data)).Decode(v)
}
// Decode reads the next form-encoded value from its input and stores it in the
// value pointed to by v.
//nolint: funlen
func (dec *Decoder) Decode(src interface{}) (err error) {
v := reflect.ValueOf(src).Elem()
if !v.IsValid() {
return errors.New("invalid input")
func (d Decoder) Decode(dst any) (err error) {
src := reflect.ValueOf(dst)
if !src.IsValid() || src.Kind() != reflect.Pointer || src.Elem().Kind() != reflect.Struct {
return &InvalidUnmarshalError{
Type: reflect.TypeOf(dst),
}
}
defer func() {
@ -84,54 +88,140 @@ func (dec *Decoder) Decode(src interface{}) (err error) {
}
}()
t := reflect.TypeOf(src).Elem()
return d.decode("", src, "")
}
for i := 0; i < v.NumField(); i++ {
ft := t.Field(i)
func (d Decoder) decode(key string, dst reflect.Value, opts tagOptions) error {
src := d.args
// NOTE(toby3d): get tag value as query name
tagValue, ok := ft.Tag.Lookup(tagName)
if !ok || tagValue == "" || tagValue == "-" || !dec.source.Has(tagValue) {
continue
if keyIndex := strings.LastIndex(key, ","); keyIndex != -1 {
if index, err := strconv.Atoi(key[keyIndex+1:]); err == nil {
key = key[:keyIndex]
src = make(url.Values)
src.Set(key, d.args[key][index])
}
}
switch dst.Kind() {
case reflect.Bool:
out, err := strconv.ParseBool(src.Get(key))
if err != nil {
return err
}
field := v.Field(i)
dst.SetBool(out)
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
out, err := strconv.ParseInt(src.Get(key), 10, 64)
if err != nil {
return err
}
// NOTE(toby3d): read struct field type
switch ft.Type.Kind() {
case reflect.String:
field.SetString(string(dec.source.Peek(tagValue)))
case reflect.Int:
field.SetInt(int64(dec.source.GetUintOrZero(tagValue)))
case reflect.Float64:
field.SetFloat(dec.source.GetUfloatOrZero(tagValue))
case reflect.Bool:
field.SetBool(dec.source.GetBool(tagValue))
case reflect.Ptr: // NOTE(toby3d): pointer to another struct
// NOTE(toby3d): check what custom unmarshal method exists
unmarshalFunc := field.MethodByName("UnmarshalForm")
if unmarshalFunc.IsZero() {
dst.SetInt(out)
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
out, err := strconv.ParseUint(src.Get(key), 10, 64)
if err != nil {
return err
}
dst.SetUint(out)
case reflect.Float32, reflect.Float64:
out, err := strconv.ParseFloat(src.Get(key), 64)
if err != nil {
return err
}
dst.SetFloat(out)
// case reflect.Array: // TODO(toby3d)
// case reflect.Interface: // TODO(toby3d)
case reflect.Slice:
// NOTE(toby3d): copy raw []byte value as is
if dst.Type().Elem().Kind() == reflect.Uint8 {
dst.SetBytes([]byte(src.Get(key)))
return nil
}
// NOTE(toby3d): if contains UnmarshalForm method
for i := 0; i < dst.Addr().NumMethod(); i++ {
if dst.Addr().Type().Method(i).Name != methodName {
continue
}
field.Set(reflect.New(ft.Type.Elem())) // NOTE(toby3d): initialize zero value
unmarshalFunc.Call([]reflect.Value{reflect.ValueOf(dec.source.Peek(tagValue))})
case reflect.Slice:
switch ft.Type.Elem().Kind() {
case reflect.Uint8: // NOTE(toby3d): bytes slice
field.SetBytes(dec.source.Peek(tagValue))
case reflect.String: // NOTE(toby3d): string slice
values := dec.source.PeekMulti(tagValue)
slice := reflect.MakeSlice(ft.Type, len(values), len(values))
in := make([]reflect.Value, 1)
in[0] = reflect.ValueOf([]byte(src.Get(key)))
for j, vv := range values {
slice.Index(j).SetString(string(vv))
out := dst.Addr().Method(i).Call(in)
if len(out) > 0 && out[0].Interface() != nil && !opts.Contains(tagOmitempty) {
return out[0].Interface().(error)
}
return nil
}
if dst.IsNil() {
slice := d.args[key]
dst.Set(reflect.MakeSlice(dst.Type(), len(slice), cap(slice)))
}
for i := 0; i < dst.Len(); i++ {
if err := d.decode(fmt.Sprintf("%s,%d", key, i), dst.Index(i), ""); err != nil {
return err
}
}
case reflect.String:
dst.SetString(string(src.Get(key)))
case reflect.Pointer:
if dst.IsNil() {
dst.Set(reflect.New(dst.Type().Elem()))
}
// NOTE(toby3d): if contains UnmarshalForm method
for i := 0; i < dst.NumMethod(); i++ {
if dst.Type().Method(i).Name != methodName {
continue
}
in := make([]reflect.Value, 1)
in[0] = reflect.ValueOf([]byte(src.Get(key)))
out := dst.Method(i).Call(in)
if len(out) > 0 && out[0].Interface() != nil && !opts.Contains(tagOmitempty) {
return out[0].Interface().(error)
}
return nil
}
if err := d.decode(key, dst.Elem(), ""); err != nil {
return err
}
case reflect.Struct:
// NOTE(toby3d): if contains UnmarshalForm method
for i := 0; i < dst.Addr().NumMethod(); i++ {
if dst.Addr().Type().Method(i).Name != methodName {
continue
}
in := make([]reflect.Value, 1)
in[0] = reflect.ValueOf([]byte(src.Get(key)))
out := dst.Addr().Method(i).Call(in)
if len(out) > 0 && out[0].Interface() != nil && !opts.Contains(tagOmitempty) {
return out[0].Interface().(error)
}
return nil
}
for i := 0; i < dst.NumField(); i++ {
if name, opts := parseTag(string(dst.Type().Field(i).Tag.Get(d.tag))); name != tagIgnore {
if err := d.decode(name, dst.Field(i), opts); err != nil {
return err
}
field.Set(slice)
}
}
}
return
return nil
}

View File

@ -1,83 +1,177 @@
package form_test
import (
"errors"
"net/url"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
http "github.com/valyala/fasthttp"
"github.com/google/go-cmp/cmp"
"source.toby3d.me/toby3d/form"
)
type (
ResponseType string
URI struct {
*http.URI `form:"-"`
TestResult struct {
Skip any `form:"-"`
PtrStruct *Struct `form:"ptrStruct"`
PtrStructs *Structs `form:"ptrStructs"`
NullStruct NullStruct `form:"nullstruct,omitempty"`
Struct Struct `form:"struct"`
Empty string `form:"empty"`
String string `form:"string"`
NotFormTag string `json:"notFormTag"`
ArrayStruct []Struct `form:"arrayStruct[]"`
ArrayPtrStruct []*Struct `form:"arrayPtrStruct[]"`
Ints []int `form:"ints[]"`
Structs Structs `form:"structs"`
Bytes []byte `form:"bytes"`
Uint uint `form:"uint"`
Int int `form:"int"`
Float float32 `form:"float"`
Bool bool `form:"bool"`
}
TestResult struct {
State []byte `form:"state"`
Scope []string `form:"scope[]"`
ClientID *URI `form:"client_id"`
RedirectURI *URI `form:"redirect_uri"`
Me *URI `form:"me"`
ResponseType ResponseType `form:"response_type"`
CodeChallenge string `form:"code_challenge"`
CodeChallengeMethod string `form:"code_challenge_method"`
Struct struct {
uid string `form:"-"`
}
Structs []Struct
NullStruct struct {
uid string `form:"-"`
}
)
const testData string = `response_type=code` + // NOTE(toby3d): string type alias
`&state=1234567890` + // NOTE(toby3d): raw value
// NOTE(toby3d): custom URL types
`&client_id=https://app.example.com/` +
`&redirect_uri=https://app.example.com/redirect` +
`&me=https://user.example.net/` +
// NOTE(toby3d): plain strings
`&code_challenge=OfYAxt8zU2dAPDWQxTAUIteRzMsoj9QBdMIVEDOErUo` +
`&code_challenge_method=S256` +
// NOTE(toby3d): multiple values
`&scope[]=profile` +
`&scope[]=create` +
`&scope[]=update` +
`&scope[]=delete`
func TestUnmarshal(t *testing.T) {
t.Parallel()
args := http.AcquireArgs()
clientId, redirectUri, me := http.AcquireURI(), http.AcquireURI(), http.AcquireURI()
t.Run("valid", func(t *testing.T) {
t.Parallel()
t.Cleanup(func() {
http.ReleaseURI(me)
http.ReleaseURI(redirectUri)
http.ReleaseURI(clientId)
http.ReleaseArgs(args)
args, err := url.ParseQuery(`skip=dontTouchMe` +
`&bool=true` +
`&string=hello+world` +
`&int=42` +
`&uint=420` +
`&float=4.2` +
// `&interface=a1b2c3` + // TODO(toby3d)
`&struct=abc` +
`&ptrStruct=123` +
`&arrayStruct[]=abc` +
`&arrayStruct[]=123` +
`&arrayPtrStruct[]=321` +
`&arrayPtrStruct[]=bca` +
`&ints[]=240` +
`&ints[]=420` +
`&bytes=sampletext` +
`&notFormTag=dontParseMe` +
`&structs=123+abc` +
`&ptrStructs=bca+321`)
if err != nil {
t.Fatal(err)
}
in := new(TestResult)
if err := form.Unmarshal([]byte(args.Encode()), in); err != nil {
t.Fatal(err)
}
out := TestResult{
Skip: nil,
Bool: true,
Float: 4.2,
Int: 42,
// Interface: []byte("a1b2c3"), // TODO(toby3d)
PtrStruct: &Struct{uid: "123"},
String: "hello world",
Struct: Struct{uid: "abc"},
Uint: 420,
ArrayStruct: []Struct{
{uid: "abc"},
{uid: "123"},
},
ArrayPtrStruct: []*Struct{
{uid: "321"},
{uid: "bca"},
},
Ints: []int{240, 420},
Empty: "",
Bytes: []byte("sampletext"),
NotFormTag: "",
NullStruct: NullStruct{uid: ""},
Structs: Structs{
{uid: "123"},
{uid: "abc"},
},
PtrStructs: &Structs{
{uid: "bca"},
{uid: "321"},
},
}
opts := []cmp.Option{
cmp.AllowUnexported(Struct{}, NullStruct{}),
}
if !cmp.Equal(&out, in, opts...) {
t.Errorf("Unmarshal(%s, &in)\n%+s", args.Encode(), cmp.Diff(&out, in, opts...))
}
})
require.NoError(t, clientId.Parse(nil, []byte("https://app.example.com/")))
require.NoError(t, redirectUri.Parse(nil, []byte("https://app.example.com/redirect")))
require.NoError(t, me.Parse(nil, []byte("https://user.example.net/")))
args.Parse(testData)
t.Run("invalid", func(t *testing.T) {
t.Parallel()
result := new(TestResult)
require.NoError(t, form.Unmarshal(args, result))
assert.Equal(t, &TestResult{
ClientID: &URI{URI: clientId},
Me: &URI{URI: me},
RedirectURI: &URI{URI: redirectUri},
State: []byte("1234567890"),
Scope: []string{"profile", "create", "update", "delete"},
CodeChallengeMethod: "S256",
CodeChallenge: "OfYAxt8zU2dAPDWQxTAUIteRzMsoj9QBdMIVEDOErUo",
ResponseType: "code",
}, result)
args, err := url.ParseQuery("arrayStruct[]=wtf")
if err != nil {
t.Fatal(err)
}
in := new(TestResult)
if err := form.Unmarshal([]byte(args.Encode()), in); err == nil {
t.Errorf("Unmarshal(%s, &in) = %#+v, want error", args.Encode(), err)
}
})
}
func (src *URI) UnmarshalForm(v []byte) error {
src.URI = http.AcquireURI()
func (s *Structs) UnmarshalForm(v []byte) error {
for _, f := range strings.Fields(string(v)) {
*s = append(*s, Struct{uid: f})
}
return src.Parse(nil, v)
return nil
}
func (s Structs) GoString() string {
out := make([]string, len(s))
for i := range s {
out[i] = s[i].uid
}
return "Structs(" + strings.Join(out, ", ") + ")"
}
func (s *Struct) UnmarshalForm(v []byte) error {
src := string(v)
switch src {
case "123", "abc", "321", "bca":
s.uid = string(v)
return nil
}
return errors.New("Struct: dough!")
}
func (ns *NullStruct) UnmarshalForm(v []byte) error {
if src := string(v); src != "" {
ns.uid = src
return nil
}
return errors.New("NullStruct: dough!")
}
func (s Struct) GoString() string { return "Struct(" + s.uid + ")" }

11
go.mod
View File

@ -1,17 +1,14 @@
module source.toby3d.me/toby3d/form
go 1.17
go 1.18
require (
github.com/stretchr/testify v1.7.0
github.com/valyala/fasthttp v1.34.0
github.com/google/go-cmp v0.5.8
github.com/valyala/fasthttp v1.37.0
)
require (
github.com/andybalholm/brotli v1.0.4 // indirect
github.com/davecgh/go-spew v1.1.0 // indirect
github.com/klauspost/compress v1.15.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/klauspost/compress v1.15.4 // indirect
github.com/valyala/bytebufferpool v1.0.0 // indirect
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c // indirect
)

21
go.sum
View File

@ -1,19 +1,14 @@
github.com/andybalholm/brotli v1.0.4 h1:V7DdXeJtZscaqfNuAdSRuRFzuiKlHSC/Zh3zl9qY3JY=
github.com/andybalholm/brotli v1.0.4/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig=
github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/google/go-cmp v0.5.8 h1:e6P7q2lk1O+qJJb4BtCQXlK8vWEO8V1ZeuEdJNOqZyg=
github.com/google/go-cmp v0.5.8/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/klauspost/compress v1.15.0/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47etQ6QUkpK6sk=
github.com/klauspost/compress v1.15.1 h1:y9FcTHGyrebwfP0ZZqFiaxTaiDnUrGkJkI+f583BL1A=
github.com/klauspost/compress v1.15.1/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47etQ6QUkpK6sk=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/klauspost/compress v1.15.4 h1:1kn4/7MepF/CHmYub99/nNX8az0IJjfSOU/jbnTVfqQ=
github.com/klauspost/compress v1.15.4/go.mod h1:PhcZ0MbTNciWF3rruxRgKxI5NkcHHrHUDtV4Yw2GlzU=
github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=
github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
github.com/valyala/fasthttp v1.34.0 h1:d3AAQJ2DRcxJYHm7OXNXtXt2as1vMDfxeIcFvhmGGm4=
github.com/valyala/fasthttp v1.34.0/go.mod h1:epZA5N+7pY6ZaEKRmstzOuYJx9HI8DI1oaCGZpdH4h0=
github.com/valyala/fasthttp v1.37.0 h1:7WHCyI7EAkQMVmrfBhWTCOaeROb1aCBiTopx63LkMbE=
github.com/valyala/fasthttp v1.37.0/go.mod h1:t/G+3rLek+CyY9bnIE+YlMRddxVAAGjhxndDB4i4C0I=
github.com/valyala/tcplisten v1.0.0/go.mod h1:T0xQ8SeCZGxckz9qRXTfG43PvQ/mcWh7FwZEA7Ioqkc=
golang.org/x/crypto v0.0.0-20220214200702-86341886e292/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
@ -28,7 +23,3 @@ golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuX
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

31
tags.go Normal file
View File

@ -0,0 +1,31 @@
package form
import (
"strings"
)
type tagOptions string
const delim rune = ','
func parseTag(tag string) (string, tagOptions) {
tag, opt, _ := strings.Cut(tag, string(delim))
return tag, tagOptions(opt)
}
func (o tagOptions) Contains(optionName string) bool {
if len(o) == 0 {
return false
}
s := string(o)
for s != "" {
var name string
if name, s, _ = strings.Cut(s, string(delim)); name == optionName {
return true
}
}
return false
}