github.com/blend/go-sdk@v1.20240719.1/reflectutil/patch_strings.go (about) 1 /* 2 3 Copyright (c) 2024 - Present. Blend Labs, Inc. All rights reserved 4 Use of this source code is governed by a MIT license that can be found in the LICENSE file. 5 6 */ 7 8 package reflectutil 9 10 import ( 11 "encoding/base64" 12 "reflect" 13 "strconv" 14 "strings" 15 "time" 16 17 "github.com/blend/go-sdk/ex" 18 ) 19 20 // PatchStrings options. 21 const ( 22 // FieldTagEnv is the struct tag for what environment variable to use to populate a field. 23 FieldTagEnv = "env" 24 // FieldFlagCSV is a field tag flag (say that 10 times fast). 25 FieldFlagCSV = "csv" 26 // FieldFlagBase64 is a field tag flag (say that 10 times fast). 27 FieldFlagBase64 = "base64" 28 // FieldFlagBytes is a field tag flag (say that 10 times fast). 29 FieldFlagBytes = "bytes" 30 ) 31 32 // PatchStringer is a type that handles unmarshalling a map of strings into itself. 33 type PatchStringer interface { 34 PatchStrings(map[string]string) error 35 } 36 37 // PatchStringsFuncer is a type that handles unmarshalling a map of strings into itself. 38 type PatchStringsFuncer interface { 39 PatchStringsFunc(func(string) (string, bool)) error 40 } 41 42 // PatchStrings patches an object with a given map of data matched with tags of a given name or the name of the field. 43 func PatchStrings(tagName string, data map[string]string, obj interface{}) error { 44 // check if the type implements marshaler. 45 if typed, isTyped := obj.(PatchStringer); isTyped { 46 return typed.PatchStrings(data) 47 } 48 49 return PatchStringsFunc(tagName, func(key string) (string, bool) { value, ok := data[key]; return value, ok }, obj) 50 } 51 52 // PatchStringsFunc patches an object with a given map of data matched with tags of a given name or the name of the field. 53 func PatchStringsFunc(tagName string, getData func(string) (string, bool), obj interface{}) (err error) { 54 defer func() { 55 if r := recover(); r != nil { 56 err = ex.New(r) 57 } 58 }() 59 60 // check if the type implements marshaler. 61 if typed, isTyped := obj.(PatchStringsFuncer); isTyped { 62 return typed.PatchStringsFunc(getData) 63 } 64 65 objMeta := reflectType(obj) 66 objValue := reflectValue(obj) 67 68 typeDuration := reflect.TypeOf(time.Duration(time.Nanosecond)) 69 70 var field reflect.StructField 71 var fieldType reflect.Type 72 var fieldValue reflect.Value 73 var tag string 74 var pieces []string 75 var dataField string 76 var dataValue string 77 var dataFieldValue interface{} 78 var hasDataValue bool 79 80 var isCSV bool 81 var isBytes bool 82 var isBase64 bool 83 var assigned bool 84 85 for x := 0; x < objMeta.NumField(); x++ { 86 isCSV = false 87 isBytes = false 88 isBase64 = false 89 90 field = objMeta.Field(x) 91 fieldValue = objValue.FieldByName(field.Name) 92 93 // Treat structs as nested values. 94 if field.Type.Kind() == reflect.Struct { 95 if err = PatchStringsFunc(tagName, getData, objValue.Field(x).Addr().Interface()); err != nil { 96 return err 97 } 98 continue 99 } 100 101 tag = field.Tag.Get(tagName) 102 if len(tag) > 0 { 103 pieces = strings.Split(tag, ",") 104 dataField = pieces[0] 105 if len(pieces) > 1 { 106 for y := 1; y < len(pieces); y++ { 107 if pieces[y] == FieldFlagCSV { 108 isCSV = true 109 } else if pieces[y] == FieldFlagBase64 { 110 isBase64 = true 111 } else if pieces[y] == FieldFlagBytes { 112 isBytes = true 113 } 114 } 115 } 116 117 dataValue, hasDataValue = getData(dataField) 118 if !hasDataValue { 119 continue 120 } 121 122 if isCSV { 123 dataFieldValue = strings.Split(dataValue, ",") 124 } else if isBase64 { 125 dataFieldValue, err = base64.StdEncoding.DecodeString(dataValue) 126 if err != nil { 127 return 128 } 129 } else if isBytes { 130 dataFieldValue = []byte(dataValue) 131 } else { 132 errWithFieldName := func(err error) error { 133 return ex.New(err, ex.OptMessagef("key: %q", dataField)) 134 } 135 136 // figure out the rootmost type (i.e. deref ****ptr etc.) 137 fieldType = followType(field.Type) 138 switch fieldType { 139 case typeDuration: 140 dataFieldValue, err = time.ParseDuration(dataValue) 141 if err != nil { 142 err = errWithFieldName(err) 143 return 144 } 145 default: 146 switch fieldType.Kind() { 147 case reflect.Bool: 148 if hasDataValue { 149 dataFieldValue = parseBool(dataValue) 150 } else { 151 continue 152 } 153 case reflect.Float32: 154 if dataValue == "" { 155 continue 156 } 157 dataFieldValue, err = strconv.ParseFloat(dataValue, 32) 158 if err != nil { 159 err = errWithFieldName(err) 160 return 161 } 162 case reflect.Float64: 163 if dataValue == "" { 164 continue 165 } 166 dataFieldValue, err = strconv.ParseFloat(dataValue, 64) 167 if err != nil { 168 err = errWithFieldName(err) 169 return 170 } 171 case reflect.Int8: 172 if dataValue == "" { 173 continue 174 } 175 dataFieldValue, err = strconv.ParseInt(dataValue, 10, 8) 176 if err != nil { 177 err = errWithFieldName(err) 178 return 179 } 180 case reflect.Int16: 181 if dataValue == "" { 182 continue 183 } 184 dataFieldValue, err = strconv.ParseInt(dataValue, 10, 16) 185 if err != nil { 186 err = errWithFieldName(err) 187 return 188 } 189 case reflect.Int32: 190 if dataValue == "" { 191 continue 192 } 193 dataFieldValue, err = strconv.ParseInt(dataValue, 10, 32) 194 if err != nil { 195 err = errWithFieldName(err) 196 return 197 } 198 case reflect.Int: 199 if dataValue == "" { 200 continue 201 } 202 dataFieldValue, err = strconv.ParseInt(dataValue, 10, 64) 203 if err != nil { 204 err = errWithFieldName(err) 205 return 206 } 207 case reflect.Int64: 208 if dataValue == "" { 209 continue 210 } 211 dataFieldValue, err = strconv.ParseInt(dataValue, 10, 64) 212 if err != nil { 213 err = errWithFieldName(err) 214 return 215 } 216 case reflect.Uint8: 217 if dataValue == "" { 218 continue 219 } 220 dataFieldValue, err = strconv.ParseUint(dataValue, 10, 8) 221 if err != nil { 222 err = errWithFieldName(err) 223 return 224 } 225 case reflect.Uint16: 226 if dataValue == "" { 227 continue 228 } 229 dataFieldValue, err = strconv.ParseUint(dataValue, 10, 8) 230 if err != nil { 231 err = errWithFieldName(err) 232 return 233 } 234 case reflect.Uint32: 235 if dataValue == "" { 236 continue 237 } 238 dataFieldValue, err = strconv.ParseUint(dataValue, 10, 32) 239 if err != nil { 240 err = errWithFieldName(err) 241 return 242 } 243 case reflect.Uint64: 244 if dataValue == "" { 245 continue 246 } 247 dataFieldValue, err = strconv.ParseUint(dataValue, 10, 64) 248 if err != nil { 249 err = errWithFieldName(err) 250 return 251 } 252 case reflect.Uint, reflect.Uintptr: 253 if dataValue == "" { 254 continue 255 } 256 dataFieldValue, err = strconv.ParseUint(dataValue, 10, 64) 257 if err != nil { 258 err = errWithFieldName(err) 259 return 260 } 261 case reflect.String: 262 dataFieldValue = dataValue 263 default: 264 err = ex.New("map strings into; unhandled assignment", ex.OptMessagef("type: %q", fieldType.String())) 265 return 266 } 267 } 268 } 269 270 value := reflectValue(dataFieldValue) 271 if !value.IsValid() { 272 err = ex.New("invalid value", ex.OptMessagef("%s `%s`", objMeta.Name(), field.Name)) 273 return 274 } 275 276 assigned, err = tryAssignment(fieldValue, value) 277 if err != nil { 278 return 279 } 280 if !assigned { 281 err = ex.New("cannot set field", ex.OptMessagef("%s `%s`", objMeta.Name(), field.Name)) 282 return 283 } 284 } 285 } 286 return nil 287 } 288 289 func followType(t reflect.Type) reflect.Type { 290 for t.Kind() == reflect.Ptr || t.Kind() == reflect.Interface { 291 t = t.Elem() 292 } 293 return t 294 } 295 296 func reflectValue(obj interface{}) reflect.Value { 297 v := reflect.ValueOf(obj) 298 for v.Kind() == reflect.Ptr || v.Kind() == reflect.Interface { 299 v = v.Elem() 300 } 301 return v 302 } 303 304 func reflectType(obj interface{}) reflect.Type { 305 t := reflect.TypeOf(obj) 306 for t.Kind() == reflect.Ptr { 307 t = t.Elem() 308 } 309 310 return t 311 } 312 313 func parseBool(str string) bool { 314 strLower := strings.ToLower(str) 315 switch strLower { 316 case "true", "1", "yes": 317 return true 318 } 319 return false 320 }