// Source: github.com/abemedia/go-don@v0.2.2-0.20240329015135-be88e32bb73b/decoder/compile.go

     1  package decoder
     2  
     import (
	"encoding"
	"errors"
	"reflect"
	"strconv"
	"unsafe"

	"github.com/abemedia/go-don/internal/byteconv"
)
    11  
    12  type decoder = func(reflect.Value, Getter) error
    13  
    14  func noopDecoder(reflect.Value, Getter) error { return nil }
    15  
    // unmarshalerType is the reflect.Type of the encoding.TextUnmarshaler
// interface; compile uses it to find fields that decode themselves from text.
var unmarshalerType = reflect.TypeOf(new(encoding.TextUnmarshaler)).Elem()
    17  
    18  //nolint:cyclop,funlen
    19  func compile(typ reflect.Type, tagKey string, isPtr bool) (decoder, error) {
    20  	decoders := []decoder{}
    21  
    22  	for i := 0; i < typ.NumField(); i++ {
    23  		f := typ.Field(i)
    24  		if f.PkgPath != "" {
    25  			continue // skip unexported fields
    26  		}
    27  
    28  		t, k, ptr := typeKind(f.Type)
    29  
    30  		tag, ok := f.Tag.Lookup(tagKey)
    31  		if !ok && k != reflect.Struct {
    32  			continue
    33  		}
    34  
    35  		if reflect.PointerTo(t).Implements(unmarshalerType) {
    36  			decoders = append(decoders, decodeTextUnmarshaler(get(ptr, i, t), tag))
    37  			continue
    38  		}
    39  
    40  		switch k {
    41  		case reflect.Struct:
    42  			dec, err := compile(t, tagKey, ptr)
    43  			if err != nil {
    44  				return nil, err
    45  			}
    46  			index := i
    47  			decoders = append(decoders, func(v reflect.Value, m Getter) error {
    48  				return dec(v.Field(index), m)
    49  			})
    50  		case reflect.String:
    51  			decoders = append(decoders, decodeString(set[string](ptr, i, t), tag))
    52  		case reflect.Int:
    53  			decoders = append(decoders, decodeInt(set[int](ptr, i, t), tag, strconv.IntSize))
    54  		case reflect.Int8:
    55  			decoders = append(decoders, decodeInt(set[int8](ptr, i, t), tag, 8))
    56  		case reflect.Int16:
    57  			decoders = append(decoders, decodeInt(set[int16](ptr, i, t), tag, 16))
    58  		case reflect.Int32:
    59  			decoders = append(decoders, decodeInt(set[int32](ptr, i, t), tag, 32))
    60  		case reflect.Int64:
    61  			decoders = append(decoders, decodeInt(set[int64](ptr, i, t), tag, 64))
    62  		case reflect.Uint:
    63  			decoders = append(decoders, decodeUint(set[uint](ptr, i, t), tag, strconv.IntSize))
    64  		case reflect.Uint8:
    65  			decoders = append(decoders, decodeUint(set[uint8](ptr, i, t), tag, 8))
    66  		case reflect.Uint16:
    67  			decoders = append(decoders, decodeUint(set[uint16](ptr, i, t), tag, 16))
    68  		case reflect.Uint32:
    69  			decoders = append(decoders, decodeUint(set[uint32](ptr, i, t), tag, 32))
    70  		case reflect.Uint64:
    71  			decoders = append(decoders, decodeUint(set[uint64](ptr, i, t), tag, 64))
    72  		case reflect.Float32:
    73  			decoders = append(decoders, decodeFloat(set[float32](ptr, i, t), tag, 32))
    74  		case reflect.Float64:
    75  			decoders = append(decoders, decodeFloat(set[float64](ptr, i, t), tag, 64))
    76  		case reflect.Bool:
    77  			decoders = append(decoders, decodeBool(set[bool](ptr, i, t), tag))
    78  		case reflect.Slice:
    79  			switch t.Elem().Kind() {
    80  			case reflect.String:
    81  				decoders = append(decoders, decodeStrings(set[[]string](ptr, i, t), tag))
    82  			case reflect.Uint8:
    83  				decoders = append(decoders, decodeBytes(set[[]byte](ptr, i, t), tag))
    84  			}
    85  		default:
    86  			return nil, ErrUnsupportedType
    87  		}
    88  	}
    89  
    90  	if len(decoders) == 0 {
    91  		return nil, ErrTagNotFound
    92  	}
    93  
    94  	return func(v reflect.Value, d Getter) error {
    95  		if isPtr {
    96  			if v.IsNil() {
    97  				v.Set(reflect.New(typ))
    98  			}
    99  			v = v.Elem()
   100  		}
   101  
   102  		for _, dec := range decoders {
   103  			if err := dec(v, d); err != nil {
   104  				return err
   105  			}
   106  		}
   107  
   108  		return nil
   109  	}, nil
   110  }
   111  
   // typeKind strips a single level of pointer indirection from t. It returns
// the resulting type, that type's kind, and whether a pointer was stripped.
// Deeper pointers (e.g. **T) keep one level, so their kind is reflect.Pointer.
func typeKind(t reflect.Type) (reflect.Type, reflect.Kind, bool) {
	if t.Kind() != reflect.Pointer {
		return t, t.Kind(), false
	}

	elem := t.Elem()
	return elem, elem.Kind(), true
}
   124  
   // set returns a setter that writes a T into field i of an addressable struct
// value. The write goes through unsafe, so it bypasses reflect's settability
// checks and requires T to share the field's memory layout.
// When ptr is true, field i is a *T-like pointer (t is its element type). The
// setter allocates a new element if the pointer is nil and writes through it.
func set[T any](ptr bool, i int, t reflect.Type) func(reflect.Value, T) {
	if !ptr {
		return func(v reflect.Value, d T) {
			dst := (*T)(unsafe.Pointer(v.Field(i).UnsafeAddr()))
			*dst = d
		}
	}

	return func(v reflect.Value, d T) {
		field := v.Field(i)
		if field.IsNil() {
			field.Set(reflect.New(t))
		}
		dst := (*T)(unsafe.Pointer(field.Elem().UnsafeAddr()))
		*dst = d
	}
}
   140  
   // get returns an accessor yielding a pointer Value for field i of an
// addressable struct value. For a non-pointer field it returns the field's
// address. For a pointer field (ptr true, t its element type) it returns the
// pointer itself, allocating a new element first if the pointer is nil.
func get(ptr bool, i int, t reflect.Type) func(v reflect.Value) reflect.Value {
	if !ptr {
		return func(v reflect.Value) reflect.Value { return v.Field(i).Addr() }
	}

	return func(v reflect.Value) reflect.Value {
		field := v.Field(i)
		if !field.IsNil() {
			return field
		}
		field.Set(reflect.New(t))
		return field
	}
}
   156  
   157  func decodeTextUnmarshaler(get func(reflect.Value) reflect.Value, k string) decoder {
   158  	return func(v reflect.Value, g Getter) error {
   159  		if s := g.Get(k); s != "" {
   160  			return get(v).Interface().(encoding.TextUnmarshaler).UnmarshalText(byteconv.Atob(s)) //nolint:forcetypeassert
   161  		}
   162  		return nil
   163  	}
   164  }
   165  
   166  func decodeString(set func(reflect.Value, string), k string) decoder {
   167  	return func(v reflect.Value, g Getter) error {
   168  		if s := g.Get(k); s != "" {
   169  			set(v, s)
   170  		}
   171  		return nil
   172  	}
   173  }
   174  
   175  func decodeInt[T int | int8 | int16 | int32 | int64](set func(reflect.Value, T), k string, bits int) decoder {
   176  	return func(v reflect.Value, g Getter) error {
   177  		if s := g.Get(k); s != "" {
   178  			n, err := strconv.ParseInt(s, 10, bits)
   179  			if err != nil {
   180  				return err
   181  			}
   182  			set(v, T(n))
   183  		}
   184  		return nil
   185  	}
   186  }
   187  
   188  func decodeFloat[T float32 | float64](set func(reflect.Value, T), k string, bits int) decoder {
   189  	return func(v reflect.Value, g Getter) error {
   190  		if s := g.Get(k); s != "" {
   191  			f, err := strconv.ParseFloat(s, bits)
   192  			if err != nil {
   193  				return err
   194  			}
   195  			set(v, T(f))
   196  		}
   197  		return nil
   198  	}
   199  }
   200  
   201  func decodeUint[T uint | uint8 | uint16 | uint32 | uint64](set func(reflect.Value, T), k string, bits int) decoder {
   202  	return func(v reflect.Value, g Getter) error {
   203  		if s := g.Get(k); s != "" {
   204  			n, err := strconv.ParseUint(s, 10, bits)
   205  			if err != nil {
   206  				return err
   207  			}
   208  			set(v, T(n))
   209  		}
   210  		return nil
   211  	}
   212  }
   213  
   214  func decodeBool(set func(reflect.Value, bool), k string) decoder {
   215  	return func(v reflect.Value, g Getter) error {
   216  		if s := g.Get(k); s != "" {
   217  			b, err := strconv.ParseBool(s)
   218  			if err != nil {
   219  				return err
   220  			}
   221  			set(v, b)
   222  		}
   223  		return nil
   224  	}
   225  }
   226  
   227  func decodeBytes(set func(reflect.Value, []byte), k string) decoder {
   228  	return func(v reflect.Value, g Getter) error {
   229  		if s := g.Get(k); s != "" {
   230  			set(v, byteconv.Atob(s))
   231  		}
   232  		return nil
   233  	}
   234  }
   235  
   236  func decodeStrings(set func(reflect.Value, []string), k string) decoder {
   237  	return func(v reflect.Value, g Getter) error {
   238  		if s := g.Values(k); s != nil {
   239  			set(v, s)
   240  		}
   241  		return nil
   242  	}
   243  }