github.com/aavshr/aws-sdk-go@v1.41.3/private/protocol/xml/xmlutil/unmarshal.go (about)

     1  package xmlutil
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/base64"
     6  	"encoding/xml"
     7  	"fmt"
     8  	"io"
     9  	"reflect"
    10  	"strconv"
    11  	"strings"
    12  	"time"
    13  
    14  	"github.com/aavshr/aws-sdk-go/aws/awserr"
    15  	"github.com/aavshr/aws-sdk-go/private/protocol"
    16  )
    17  
    18  // UnmarshalXMLError unmarshals the XML error from the stream into the value
    19  // type specified. The value must be a pointer. If the message fails to
    20  // unmarshal, the message content will be included in the returned error as a
    21  // awserr.UnmarshalError.
    22  func UnmarshalXMLError(v interface{}, stream io.Reader) error {
    23  	var errBuf bytes.Buffer
    24  	body := io.TeeReader(stream, &errBuf)
    25  
    26  	err := xml.NewDecoder(body).Decode(v)
    27  	if err != nil && err != io.EOF {
    28  		return awserr.NewUnmarshalError(err,
    29  			"failed to unmarshal error message", errBuf.Bytes())
    30  	}
    31  
    32  	return nil
    33  }
    34  
    35  // UnmarshalXML deserializes an xml.Decoder into the container v. V
    36  // needs to match the shape of the XML expected to be decoded.
    37  // If the shape doesn't match unmarshaling will fail.
    38  func UnmarshalXML(v interface{}, d *xml.Decoder, wrapper string) error {
    39  	n, err := XMLToStruct(d, nil)
    40  	if err != nil {
    41  		return err
    42  	}
    43  	if n.Children != nil {
    44  		for _, root := range n.Children {
    45  			for _, c := range root {
    46  				if wrappedChild, ok := c.Children[wrapper]; ok {
    47  					c = wrappedChild[0] // pull out wrapped element
    48  				}
    49  
    50  				err = parse(reflect.ValueOf(v), c, "")
    51  				if err != nil {
    52  					if err == io.EOF {
    53  						return nil
    54  					}
    55  					return err
    56  				}
    57  			}
    58  		}
    59  		return nil
    60  	}
    61  	return nil
    62  }
    63  
    64  // parse deserializes any value from the XMLNode. The type tag is used to infer the type, or reflect
    65  // will be used to determine the type from r.
    66  func parse(r reflect.Value, node *XMLNode, tag reflect.StructTag) error {
    67  	xml := tag.Get("xml")
    68  	if len(xml) != 0 {
    69  		name := strings.SplitAfterN(xml, ",", 2)[0]
    70  		if name == "-" {
    71  			return nil
    72  		}
    73  	}
    74  
    75  	rtype := r.Type()
    76  	if rtype.Kind() == reflect.Ptr {
    77  		rtype = rtype.Elem() // check kind of actual element type
    78  	}
    79  
    80  	t := tag.Get("type")
    81  	if t == "" {
    82  		switch rtype.Kind() {
    83  		case reflect.Struct:
    84  			// also it can't be a time object
    85  			if _, ok := r.Interface().(*time.Time); !ok {
    86  				t = "structure"
    87  			}
    88  		case reflect.Slice:
    89  			// also it can't be a byte slice
    90  			if _, ok := r.Interface().([]byte); !ok {
    91  				t = "list"
    92  			}
    93  		case reflect.Map:
    94  			t = "map"
    95  		}
    96  	}
    97  
    98  	switch t {
    99  	case "structure":
   100  		if field, ok := rtype.FieldByName("_"); ok {
   101  			tag = field.Tag
   102  		}
   103  		return parseStruct(r, node, tag)
   104  	case "list":
   105  		return parseList(r, node, tag)
   106  	case "map":
   107  		return parseMap(r, node, tag)
   108  	default:
   109  		return parseScalar(r, node, tag)
   110  	}
   111  }
   112  
   113  // parseStruct deserializes a structure and its fields from an XMLNode. Any nested
   114  // types in the structure will also be deserialized.
   115  func parseStruct(r reflect.Value, node *XMLNode, tag reflect.StructTag) error {
   116  	t := r.Type()
   117  	if r.Kind() == reflect.Ptr {
   118  		if r.IsNil() { // create the structure if it's nil
   119  			s := reflect.New(r.Type().Elem())
   120  			r.Set(s)
   121  			r = s
   122  		}
   123  
   124  		r = r.Elem()
   125  		t = t.Elem()
   126  	}
   127  
   128  	// unwrap any payloads
   129  	if payload := tag.Get("payload"); payload != "" {
   130  		field, _ := t.FieldByName(payload)
   131  		return parseStruct(r.FieldByName(payload), node, field.Tag)
   132  	}
   133  
   134  	for i := 0; i < t.NumField(); i++ {
   135  		field := t.Field(i)
   136  		if c := field.Name[0:1]; strings.ToLower(c) == c {
   137  			continue // ignore unexported fields
   138  		}
   139  
   140  		// figure out what this field is called
   141  		name := field.Name
   142  		if field.Tag.Get("flattened") != "" && field.Tag.Get("locationNameList") != "" {
   143  			name = field.Tag.Get("locationNameList")
   144  		} else if locName := field.Tag.Get("locationName"); locName != "" {
   145  			name = locName
   146  		}
   147  
   148  		// try to find the field by name in elements
   149  		elems := node.Children[name]
   150  
   151  		if elems == nil { // try to find the field in attributes
   152  			if val, ok := node.findElem(name); ok {
   153  				elems = []*XMLNode{{Text: val}}
   154  			}
   155  		}
   156  
   157  		member := r.FieldByName(field.Name)
   158  		for _, elem := range elems {
   159  			err := parse(member, elem, field.Tag)
   160  			if err != nil {
   161  				return err
   162  			}
   163  		}
   164  	}
   165  	return nil
   166  }
   167  
   168  // parseList deserializes a list of values from an XML node. Each list entry
   169  // will also be deserialized.
   170  func parseList(r reflect.Value, node *XMLNode, tag reflect.StructTag) error {
   171  	t := r.Type()
   172  
   173  	if tag.Get("flattened") == "" { // look at all item entries
   174  		mname := "member"
   175  		if name := tag.Get("locationNameList"); name != "" {
   176  			mname = name
   177  		}
   178  
   179  		if Children, ok := node.Children[mname]; ok {
   180  			if r.IsNil() {
   181  				r.Set(reflect.MakeSlice(t, len(Children), len(Children)))
   182  			}
   183  
   184  			for i, c := range Children {
   185  				err := parse(r.Index(i), c, "")
   186  				if err != nil {
   187  					return err
   188  				}
   189  			}
   190  		}
   191  	} else { // flattened list means this is a single element
   192  		if r.IsNil() {
   193  			r.Set(reflect.MakeSlice(t, 0, 0))
   194  		}
   195  
   196  		childR := reflect.Zero(t.Elem())
   197  		r.Set(reflect.Append(r, childR))
   198  		err := parse(r.Index(r.Len()-1), node, "")
   199  		if err != nil {
   200  			return err
   201  		}
   202  	}
   203  
   204  	return nil
   205  }
   206  
   207  // parseMap deserializes a map from an XMLNode. The direct children of the XMLNode
   208  // will also be deserialized as map entries.
   209  func parseMap(r reflect.Value, node *XMLNode, tag reflect.StructTag) error {
   210  	if r.IsNil() {
   211  		r.Set(reflect.MakeMap(r.Type()))
   212  	}
   213  
   214  	if tag.Get("flattened") == "" { // look at all child entries
   215  		for _, entry := range node.Children["entry"] {
   216  			parseMapEntry(r, entry, tag)
   217  		}
   218  	} else { // this element is itself an entry
   219  		parseMapEntry(r, node, tag)
   220  	}
   221  
   222  	return nil
   223  }
   224  
   225  // parseMapEntry deserializes a map entry from a XML node.
   226  func parseMapEntry(r reflect.Value, node *XMLNode, tag reflect.StructTag) error {
   227  	kname, vname := "key", "value"
   228  	if n := tag.Get("locationNameKey"); n != "" {
   229  		kname = n
   230  	}
   231  	if n := tag.Get("locationNameValue"); n != "" {
   232  		vname = n
   233  	}
   234  
   235  	keys, ok := node.Children[kname]
   236  	values := node.Children[vname]
   237  	if ok {
   238  		for i, key := range keys {
   239  			keyR := reflect.ValueOf(key.Text)
   240  			value := values[i]
   241  			valueR := reflect.New(r.Type().Elem()).Elem()
   242  
   243  			parse(valueR, value, "")
   244  			r.SetMapIndex(keyR, valueR)
   245  		}
   246  	}
   247  	return nil
   248  }
   249  
   250  // parseScaller deserializes an XMLNode value into a concrete type based on the
   251  // interface type of r.
   252  //
   253  // Error is returned if the deserialization fails due to invalid type conversion,
   254  // or unsupported interface type.
   255  func parseScalar(r reflect.Value, node *XMLNode, tag reflect.StructTag) error {
   256  	switch r.Interface().(type) {
   257  	case *string:
   258  		r.Set(reflect.ValueOf(&node.Text))
   259  		return nil
   260  	case []byte:
   261  		b, err := base64.StdEncoding.DecodeString(node.Text)
   262  		if err != nil {
   263  			return err
   264  		}
   265  		r.Set(reflect.ValueOf(b))
   266  	case *bool:
   267  		v, err := strconv.ParseBool(node.Text)
   268  		if err != nil {
   269  			return err
   270  		}
   271  		r.Set(reflect.ValueOf(&v))
   272  	case *int64:
   273  		v, err := strconv.ParseInt(node.Text, 10, 64)
   274  		if err != nil {
   275  			return err
   276  		}
   277  		r.Set(reflect.ValueOf(&v))
   278  	case *float64:
   279  		v, err := strconv.ParseFloat(node.Text, 64)
   280  		if err != nil {
   281  			return err
   282  		}
   283  		r.Set(reflect.ValueOf(&v))
   284  	case *time.Time:
   285  		format := tag.Get("timestampFormat")
   286  		if len(format) == 0 {
   287  			format = protocol.ISO8601TimeFormatName
   288  		}
   289  
   290  		t, err := protocol.ParseTime(format, node.Text)
   291  		if err != nil {
   292  			return err
   293  		}
   294  		r.Set(reflect.ValueOf(&t))
   295  	default:
   296  		return fmt.Errorf("unsupported value: %v (%s)", r.Interface(), r.Type())
   297  	}
   298  	return nil
   299  }