github.com/abemedia/go-don@v0.2.2-0.20240329015135-be88e32bb73b/decoder/compile.go (about) 1 package decoder 2 3 import ( 4 "encoding" 5 "reflect" 6 "strconv" 7 "unsafe" 8 9 "github.com/abemedia/go-don/internal/byteconv" 10 ) 11 12 type decoder = func(reflect.Value, Getter) error 13 14 func noopDecoder(reflect.Value, Getter) error { return nil } 15 16 var unmarshalerType = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem() 17 18 //nolint:cyclop,funlen 19 func compile(typ reflect.Type, tagKey string, isPtr bool) (decoder, error) { 20 decoders := []decoder{} 21 22 for i := 0; i < typ.NumField(); i++ { 23 f := typ.Field(i) 24 if f.PkgPath != "" { 25 continue // skip unexported fields 26 } 27 28 t, k, ptr := typeKind(f.Type) 29 30 tag, ok := f.Tag.Lookup(tagKey) 31 if !ok && k != reflect.Struct { 32 continue 33 } 34 35 if reflect.PointerTo(t).Implements(unmarshalerType) { 36 decoders = append(decoders, decodeTextUnmarshaler(get(ptr, i, t), tag)) 37 continue 38 } 39 40 switch k { 41 case reflect.Struct: 42 dec, err := compile(t, tagKey, ptr) 43 if err != nil { 44 return nil, err 45 } 46 index := i 47 decoders = append(decoders, func(v reflect.Value, m Getter) error { 48 return dec(v.Field(index), m) 49 }) 50 case reflect.String: 51 decoders = append(decoders, decodeString(set[string](ptr, i, t), tag)) 52 case reflect.Int: 53 decoders = append(decoders, decodeInt(set[int](ptr, i, t), tag, strconv.IntSize)) 54 case reflect.Int8: 55 decoders = append(decoders, decodeInt(set[int8](ptr, i, t), tag, 8)) 56 case reflect.Int16: 57 decoders = append(decoders, decodeInt(set[int16](ptr, i, t), tag, 16)) 58 case reflect.Int32: 59 decoders = append(decoders, decodeInt(set[int32](ptr, i, t), tag, 32)) 60 case reflect.Int64: 61 decoders = append(decoders, decodeInt(set[int64](ptr, i, t), tag, 64)) 62 case reflect.Uint: 63 decoders = append(decoders, decodeUint(set[uint](ptr, i, t), tag, strconv.IntSize)) 64 case reflect.Uint8: 65 decoders = append(decoders, decodeUint(set[uint8](ptr, i, t), tag, 8)) 66 case reflect.Uint16: 67 decoders = append(decoders, decodeUint(set[uint16](ptr, i, t), tag, 16)) 68 case reflect.Uint32: 69 decoders = append(decoders, decodeUint(set[uint32](ptr, i, t), tag, 32)) 70 case reflect.Uint64: 71 decoders = append(decoders, decodeUint(set[uint64](ptr, i, t), tag, 64)) 72 case reflect.Float32: 73 decoders = append(decoders, decodeFloat(set[float32](ptr, i, t), tag, 32)) 74 case reflect.Float64: 75 decoders = append(decoders, decodeFloat(set[float64](ptr, i, t), tag, 64)) 76 case reflect.Bool: 77 decoders = append(decoders, decodeBool(set[bool](ptr, i, t), tag)) 78 case reflect.Slice: 79 switch t.Elem().Kind() { 80 case reflect.String: 81 decoders = append(decoders, decodeStrings(set[[]string](ptr, i, t), tag)) 82 case reflect.Uint8: 83 decoders = append(decoders, decodeBytes(set[[]byte](ptr, i, t), tag)) 84 } 85 default: 86 return nil, ErrUnsupportedType 87 } 88 } 89 90 if len(decoders) == 0 { 91 return nil, ErrTagNotFound 92 } 93 94 return func(v reflect.Value, d Getter) error { 95 if isPtr { 96 if v.IsNil() { 97 v.Set(reflect.New(typ)) 98 } 99 v = v.Elem() 100 } 101 102 for _, dec := range decoders { 103 if err := dec(v, d); err != nil { 104 return err 105 } 106 } 107 108 return nil 109 }, nil 110 } 111 112 func typeKind(t reflect.Type) (reflect.Type, reflect.Kind, bool) { 113 var isPtr bool 114 115 k := t.Kind() 116 if k == reflect.Pointer { 117 t = t.Elem() 118 k = t.Kind() 119 isPtr = true 120 } 121 122 return t, k, isPtr 123 } 124 125 func set[T any](ptr bool, i int, t reflect.Type) func(reflect.Value, T) { 126 if ptr { 127 return func(v reflect.Value, d T) { 128 f := v.Field(i) 129 if f.IsNil() { 130 f.Set(reflect.New(t)) 131 } 132 *(*T)(unsafe.Pointer(f.Elem().UnsafeAddr())) = d 133 } 134 } 135 136 return func(v reflect.Value, d T) { 137 *(*T)(unsafe.Pointer(v.Field(i).UnsafeAddr())) = d 138 } 139 } 140 141 func get(ptr bool, i int, t reflect.Type) func(v reflect.Value) reflect.Value { 142 if ptr { 143 return func(v reflect.Value) reflect.Value { 144 f := v.Field(i) 145 if f.IsNil() { 146 f.Set(reflect.New(t)) 147 } 148 return f 149 } 150 } 151 152 return func(v reflect.Value) reflect.Value { 153 return v.Field(i).Addr() 154 } 155 } 156 157 func decodeTextUnmarshaler(get func(reflect.Value) reflect.Value, k string) decoder { 158 return func(v reflect.Value, g Getter) error { 159 if s := g.Get(k); s != "" { 160 return get(v).Interface().(encoding.TextUnmarshaler).UnmarshalText(byteconv.Atob(s)) //nolint:forcetypeassert 161 } 162 return nil 163 } 164 } 165 166 func decodeString(set func(reflect.Value, string), k string) decoder { 167 return func(v reflect.Value, g Getter) error { 168 if s := g.Get(k); s != "" { 169 set(v, s) 170 } 171 return nil 172 } 173 } 174 175 func decodeInt[T int | int8 | int16 | int32 | int64](set func(reflect.Value, T), k string, bits int) decoder { 176 return func(v reflect.Value, g Getter) error { 177 if s := g.Get(k); s != "" { 178 n, err := strconv.ParseInt(s, 10, bits) 179 if err != nil { 180 return err 181 } 182 set(v, T(n)) 183 } 184 return nil 185 } 186 } 187 188 func decodeFloat[T float32 | float64](set func(reflect.Value, T), k string, bits int) decoder { 189 return func(v reflect.Value, g Getter) error { 190 if s := g.Get(k); s != "" { 191 f, err := strconv.ParseFloat(s, bits) 192 if err != nil { 193 return err 194 } 195 set(v, T(f)) 196 } 197 return nil 198 } 199 } 200 201 func decodeUint[T uint | uint8 | uint16 | uint32 | uint64](set func(reflect.Value, T), k string, bits int) decoder { 202 return func(v reflect.Value, g Getter) error { 203 if s := g.Get(k); s != "" { 204 n, err := strconv.ParseUint(s, 10, bits) 205 if err != nil { 206 return err 207 } 208 set(v, T(n)) 209 } 210 return nil 211 } 212 } 213 214 func decodeBool(set func(reflect.Value, bool), k string) decoder { 215 return func(v reflect.Value, g Getter) error { 216 if s := g.Get(k); s != "" { 217 b, err := strconv.ParseBool(s) 218 if err != nil { 219 return err 220 } 221 set(v, b) 222 } 223 return nil 224 } 225 } 226 227 func decodeBytes(set func(reflect.Value, []byte), k string) decoder { 228 return func(v reflect.Value, g Getter) error { 229 if s := g.Get(k); s != "" { 230 set(v, byteconv.Atob(s)) 231 } 232 return nil 233 } 234 } 235 236 func decodeStrings(set func(reflect.Value, []string), k string) decoder { 237 return func(v reflect.Value, g Getter) error { 238 if s := g.Values(k); s != nil { 239 set(v, s) 240 } 241 return nil 242 } 243 }