github.com/aavshr/aws-sdk-go@v1.41.3/private/protocol/json/jsonutil/unmarshal.go (about)

     1  package jsonutil
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/base64"
     6  	"encoding/json"
     7  	"fmt"
     8  	"io"
     9  	"math/big"
    10  	"reflect"
    11  	"strings"
    12  	"time"
    13  
    14  	"github.com/aavshr/aws-sdk-go/aws"
    15  	"github.com/aavshr/aws-sdk-go/aws/awserr"
    16  	"github.com/aavshr/aws-sdk-go/private/protocol"
    17  )
    18  
    19  var millisecondsFloat = new(big.Float).SetInt64(1e3)
    20  
    21  // UnmarshalJSONError unmarshal's the reader's JSON document into the passed in
    22  // type. The value to unmarshal the json document into must be a pointer to the
    23  // type.
    24  func UnmarshalJSONError(v interface{}, stream io.Reader) error {
    25  	var errBuf bytes.Buffer
    26  	body := io.TeeReader(stream, &errBuf)
    27  
    28  	err := json.NewDecoder(body).Decode(v)
    29  	if err != nil {
    30  		msg := "failed decoding error message"
    31  		if err == io.EOF {
    32  			msg = "error message missing"
    33  			err = nil
    34  		}
    35  		return awserr.NewUnmarshalError(err, msg, errBuf.Bytes())
    36  	}
    37  
    38  	return nil
    39  }
    40  
    41  // UnmarshalJSON reads a stream and unmarshals the results in object v.
    42  func UnmarshalJSON(v interface{}, stream io.Reader) error {
    43  	var out interface{}
    44  
    45  	decoder := json.NewDecoder(stream)
    46  	decoder.UseNumber()
    47  	err := decoder.Decode(&out)
    48  	if err == io.EOF {
    49  		return nil
    50  	} else if err != nil {
    51  		return err
    52  	}
    53  
    54  	return unmarshaler{}.unmarshalAny(reflect.ValueOf(v), out, "")
    55  }
    56  
    57  // UnmarshalJSONCaseInsensitive reads a stream and unmarshals the result into the
    58  // object v. Ignores casing for structure members.
    59  func UnmarshalJSONCaseInsensitive(v interface{}, stream io.Reader) error {
    60  	var out interface{}
    61  
    62  	decoder := json.NewDecoder(stream)
    63  	decoder.UseNumber()
    64  	err := decoder.Decode(&out)
    65  	if err == io.EOF {
    66  		return nil
    67  	} else if err != nil {
    68  		return err
    69  	}
    70  
    71  	return unmarshaler{
    72  		caseInsensitive: true,
    73  	}.unmarshalAny(reflect.ValueOf(v), out, "")
    74  }
    75  
    76  type unmarshaler struct {
    77  	caseInsensitive bool
    78  }
    79  
    80  func (u unmarshaler) unmarshalAny(value reflect.Value, data interface{}, tag reflect.StructTag) error {
    81  	vtype := value.Type()
    82  	if vtype.Kind() == reflect.Ptr {
    83  		vtype = vtype.Elem() // check kind of actual element type
    84  	}
    85  
    86  	t := tag.Get("type")
    87  	if t == "" {
    88  		switch vtype.Kind() {
    89  		case reflect.Struct:
    90  			// also it can't be a time object
    91  			if _, ok := value.Interface().(*time.Time); !ok {
    92  				t = "structure"
    93  			}
    94  		case reflect.Slice:
    95  			// also it can't be a byte slice
    96  			if _, ok := value.Interface().([]byte); !ok {
    97  				t = "list"
    98  			}
    99  		case reflect.Map:
   100  			// cannot be a JSONValue map
   101  			if _, ok := value.Interface().(aws.JSONValue); !ok {
   102  				t = "map"
   103  			}
   104  		}
   105  	}
   106  
   107  	switch t {
   108  	case "structure":
   109  		if field, ok := vtype.FieldByName("_"); ok {
   110  			tag = field.Tag
   111  		}
   112  		return u.unmarshalStruct(value, data, tag)
   113  	case "list":
   114  		return u.unmarshalList(value, data, tag)
   115  	case "map":
   116  		return u.unmarshalMap(value, data, tag)
   117  	default:
   118  		return u.unmarshalScalar(value, data, tag)
   119  	}
   120  }
   121  
   122  func (u unmarshaler) unmarshalStruct(value reflect.Value, data interface{}, tag reflect.StructTag) error {
   123  	if data == nil {
   124  		return nil
   125  	}
   126  	mapData, ok := data.(map[string]interface{})
   127  	if !ok {
   128  		return fmt.Errorf("JSON value is not a structure (%#v)", data)
   129  	}
   130  
   131  	t := value.Type()
   132  	if value.Kind() == reflect.Ptr {
   133  		if value.IsNil() { // create the structure if it's nil
   134  			s := reflect.New(value.Type().Elem())
   135  			value.Set(s)
   136  			value = s
   137  		}
   138  
   139  		value = value.Elem()
   140  		t = t.Elem()
   141  	}
   142  
   143  	// unwrap any payloads
   144  	if payload := tag.Get("payload"); payload != "" {
   145  		field, _ := t.FieldByName(payload)
   146  		return u.unmarshalAny(value.FieldByName(payload), data, field.Tag)
   147  	}
   148  
   149  	for i := 0; i < t.NumField(); i++ {
   150  		field := t.Field(i)
   151  		if field.PkgPath != "" {
   152  			continue // ignore unexported fields
   153  		}
   154  
   155  		// figure out what this field is called
   156  		name := field.Name
   157  		if locName := field.Tag.Get("locationName"); locName != "" {
   158  			name = locName
   159  		}
   160  		if u.caseInsensitive {
   161  			if _, ok := mapData[name]; !ok {
   162  				// Fallback to uncased name search if the exact name didn't match.
   163  				for kn, v := range mapData {
   164  					if strings.EqualFold(kn, name) {
   165  						mapData[name] = v
   166  					}
   167  				}
   168  			}
   169  		}
   170  
   171  		member := value.FieldByIndex(field.Index)
   172  		err := u.unmarshalAny(member, mapData[name], field.Tag)
   173  		if err != nil {
   174  			return err
   175  		}
   176  	}
   177  	return nil
   178  }
   179  
   180  func (u unmarshaler) unmarshalList(value reflect.Value, data interface{}, tag reflect.StructTag) error {
   181  	if data == nil {
   182  		return nil
   183  	}
   184  	listData, ok := data.([]interface{})
   185  	if !ok {
   186  		return fmt.Errorf("JSON value is not a list (%#v)", data)
   187  	}
   188  
   189  	if value.IsNil() {
   190  		l := len(listData)
   191  		value.Set(reflect.MakeSlice(value.Type(), l, l))
   192  	}
   193  
   194  	for i, c := range listData {
   195  		err := u.unmarshalAny(value.Index(i), c, "")
   196  		if err != nil {
   197  			return err
   198  		}
   199  	}
   200  
   201  	return nil
   202  }
   203  
   204  func (u unmarshaler) unmarshalMap(value reflect.Value, data interface{}, tag reflect.StructTag) error {
   205  	if data == nil {
   206  		return nil
   207  	}
   208  	mapData, ok := data.(map[string]interface{})
   209  	if !ok {
   210  		return fmt.Errorf("JSON value is not a map (%#v)", data)
   211  	}
   212  
   213  	if value.IsNil() {
   214  		value.Set(reflect.MakeMap(value.Type()))
   215  	}
   216  
   217  	for k, v := range mapData {
   218  		kvalue := reflect.ValueOf(k)
   219  		vvalue := reflect.New(value.Type().Elem()).Elem()
   220  
   221  		u.unmarshalAny(vvalue, v, "")
   222  		value.SetMapIndex(kvalue, vvalue)
   223  	}
   224  
   225  	return nil
   226  }
   227  
   228  func (u unmarshaler) unmarshalScalar(value reflect.Value, data interface{}, tag reflect.StructTag) error {
   229  
   230  	switch d := data.(type) {
   231  	case nil:
   232  		return nil // nothing to do here
   233  	case string:
   234  		switch value.Interface().(type) {
   235  		case *string:
   236  			value.Set(reflect.ValueOf(&d))
   237  		case []byte:
   238  			b, err := base64.StdEncoding.DecodeString(d)
   239  			if err != nil {
   240  				return err
   241  			}
   242  			value.Set(reflect.ValueOf(b))
   243  		case *time.Time:
   244  			format := tag.Get("timestampFormat")
   245  			if len(format) == 0 {
   246  				format = protocol.ISO8601TimeFormatName
   247  			}
   248  
   249  			t, err := protocol.ParseTime(format, d)
   250  			if err != nil {
   251  				return err
   252  			}
   253  			value.Set(reflect.ValueOf(&t))
   254  		case aws.JSONValue:
   255  			// No need to use escaping as the value is a non-quoted string.
   256  			v, err := protocol.DecodeJSONValue(d, protocol.NoEscape)
   257  			if err != nil {
   258  				return err
   259  			}
   260  			value.Set(reflect.ValueOf(v))
   261  		default:
   262  			return fmt.Errorf("unsupported value: %v (%s)", value.Interface(), value.Type())
   263  		}
   264  	case json.Number:
   265  		switch value.Interface().(type) {
   266  		case *int64:
   267  			// Retain the old behavior where we would just truncate the float64
   268  			// calling d.Int64() here could cause an invalid syntax error due to the usage of strconv.ParseInt
   269  			f, err := d.Float64()
   270  			if err != nil {
   271  				return err
   272  			}
   273  			di := int64(f)
   274  			value.Set(reflect.ValueOf(&di))
   275  		case *float64:
   276  			f, err := d.Float64()
   277  			if err != nil {
   278  				return err
   279  			}
   280  			value.Set(reflect.ValueOf(&f))
   281  		case *time.Time:
   282  			float, ok := new(big.Float).SetString(d.String())
   283  			if !ok {
   284  				return fmt.Errorf("unsupported float time representation: %v", d.String())
   285  			}
   286  			float = float.Mul(float, millisecondsFloat)
   287  			ms, _ := float.Int64()
   288  			t := time.Unix(0, ms*1e6).UTC()
   289  			value.Set(reflect.ValueOf(&t))
   290  		default:
   291  			return fmt.Errorf("unsupported value: %v (%s)", value.Interface(), value.Type())
   292  		}
   293  	case bool:
   294  		switch value.Interface().(type) {
   295  		case *bool:
   296  			value.Set(reflect.ValueOf(&d))
   297  		default:
   298  			return fmt.Errorf("unsupported value: %v (%s)", value.Interface(), value.Type())
   299  		}
   300  	default:
   301  		return fmt.Errorf("unsupported JSON value (%v)", data)
   302  	}
   303  	return nil
   304  }