github.com/aavshr/aws-sdk-go@v1.41.3/private/protocol/rest/unmarshal.go (about)

     1  package rest
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/base64"
     6  	"fmt"
     7  	"io"
     8  	"io/ioutil"
     9  	"net/http"
    10  	"reflect"
    11  	"strconv"
    12  	"strings"
    13  	"time"
    14  
    15  	"github.com/aavshr/aws-sdk-go/aws"
    16  	"github.com/aavshr/aws-sdk-go/aws/awserr"
    17  	"github.com/aavshr/aws-sdk-go/aws/request"
    18  	awsStrings "github.com/aavshr/aws-sdk-go/internal/strings"
    19  	"github.com/aavshr/aws-sdk-go/private/protocol"
    20  )
    21  
    22  // UnmarshalHandler is a named request handler for unmarshaling rest protocol requests
    23  var UnmarshalHandler = request.NamedHandler{Name: "awssdk.rest.Unmarshal", Fn: Unmarshal}
    24  
    25  // UnmarshalMetaHandler is a named request handler for unmarshaling rest protocol request metadata
    26  var UnmarshalMetaHandler = request.NamedHandler{Name: "awssdk.rest.UnmarshalMeta", Fn: UnmarshalMeta}
    27  
    28  // Unmarshal unmarshals the REST component of a response in a REST service.
    29  func Unmarshal(r *request.Request) {
    30  	if r.DataFilled() {
    31  		v := reflect.Indirect(reflect.ValueOf(r.Data))
    32  		if err := unmarshalBody(r, v); err != nil {
    33  			r.Error = err
    34  		}
    35  	}
    36  }
    37  
    38  // UnmarshalMeta unmarshals the REST metadata of a response in a REST service
    39  func UnmarshalMeta(r *request.Request) {
    40  	r.RequestID = r.HTTPResponse.Header.Get("X-Amzn-Requestid")
    41  	if r.RequestID == "" {
    42  		// Alternative version of request id in the header
    43  		r.RequestID = r.HTTPResponse.Header.Get("X-Amz-Request-Id")
    44  	}
    45  	if r.DataFilled() {
    46  		if err := UnmarshalResponse(r.HTTPResponse, r.Data, aws.BoolValue(r.Config.LowerCaseHeaderMaps)); err != nil {
    47  			r.Error = err
    48  		}
    49  	}
    50  }
    51  
    52  // UnmarshalResponse attempts to unmarshal the REST response headers to
    53  // the data type passed in. The type must be a pointer. An error is returned
    54  // with any error unmarshaling the response into the target datatype.
    55  func UnmarshalResponse(resp *http.Response, data interface{}, lowerCaseHeaderMaps bool) error {
    56  	v := reflect.Indirect(reflect.ValueOf(data))
    57  	return unmarshalLocationElements(resp, v, lowerCaseHeaderMaps)
    58  }
    59  
    60  func unmarshalBody(r *request.Request, v reflect.Value) error {
    61  	if field, ok := v.Type().FieldByName("_"); ok {
    62  		if payloadName := field.Tag.Get("payload"); payloadName != "" {
    63  			pfield, _ := v.Type().FieldByName(payloadName)
    64  			if ptag := pfield.Tag.Get("type"); ptag != "" && ptag != "structure" {
    65  				payload := v.FieldByName(payloadName)
    66  				if payload.IsValid() {
    67  					switch payload.Interface().(type) {
    68  					case []byte:
    69  						defer r.HTTPResponse.Body.Close()
    70  						b, err := ioutil.ReadAll(r.HTTPResponse.Body)
    71  						if err != nil {
    72  							return awserr.New(request.ErrCodeSerialization, "failed to decode REST response", err)
    73  						}
    74  
    75  						payload.Set(reflect.ValueOf(b))
    76  
    77  					case *string:
    78  						defer r.HTTPResponse.Body.Close()
    79  						b, err := ioutil.ReadAll(r.HTTPResponse.Body)
    80  						if err != nil {
    81  							return awserr.New(request.ErrCodeSerialization, "failed to decode REST response", err)
    82  						}
    83  
    84  						str := string(b)
    85  						payload.Set(reflect.ValueOf(&str))
    86  
    87  					default:
    88  						switch payload.Type().String() {
    89  						case "io.ReadCloser":
    90  							payload.Set(reflect.ValueOf(r.HTTPResponse.Body))
    91  
    92  						case "io.ReadSeeker":
    93  							b, err := ioutil.ReadAll(r.HTTPResponse.Body)
    94  							if err != nil {
    95  								return awserr.New(request.ErrCodeSerialization,
    96  									"failed to read response body", err)
    97  							}
    98  							payload.Set(reflect.ValueOf(ioutil.NopCloser(bytes.NewReader(b))))
    99  
   100  						default:
   101  							io.Copy(ioutil.Discard, r.HTTPResponse.Body)
   102  							r.HTTPResponse.Body.Close()
   103  							return awserr.New(request.ErrCodeSerialization,
   104  								"failed to decode REST response",
   105  								fmt.Errorf("unknown payload type %s", payload.Type()))
   106  						}
   107  					}
   108  				}
   109  			}
   110  		}
   111  	}
   112  
   113  	return nil
   114  }
   115  
   116  func unmarshalLocationElements(resp *http.Response, v reflect.Value, lowerCaseHeaderMaps bool) error {
   117  	for i := 0; i < v.NumField(); i++ {
   118  		m, field := v.Field(i), v.Type().Field(i)
   119  		if n := field.Name; n[0:1] == strings.ToLower(n[0:1]) {
   120  			continue
   121  		}
   122  
   123  		if m.IsValid() {
   124  			name := field.Tag.Get("locationName")
   125  			if name == "" {
   126  				name = field.Name
   127  			}
   128  
   129  			switch field.Tag.Get("location") {
   130  			case "statusCode":
   131  				unmarshalStatusCode(m, resp.StatusCode)
   132  
   133  			case "header":
   134  				err := unmarshalHeader(m, resp.Header.Get(name), field.Tag)
   135  				if err != nil {
   136  					return awserr.New(request.ErrCodeSerialization, "failed to decode REST response", err)
   137  				}
   138  
   139  			case "headers":
   140  				prefix := field.Tag.Get("locationName")
   141  				err := unmarshalHeaderMap(m, resp.Header, prefix, lowerCaseHeaderMaps)
   142  				if err != nil {
   143  					awserr.New(request.ErrCodeSerialization, "failed to decode REST response", err)
   144  				}
   145  			}
   146  		}
   147  	}
   148  
   149  	return nil
   150  }
   151  
   152  func unmarshalStatusCode(v reflect.Value, statusCode int) {
   153  	if !v.IsValid() {
   154  		return
   155  	}
   156  
   157  	switch v.Interface().(type) {
   158  	case *int64:
   159  		s := int64(statusCode)
   160  		v.Set(reflect.ValueOf(&s))
   161  	}
   162  }
   163  
   164  func unmarshalHeaderMap(r reflect.Value, headers http.Header, prefix string, normalize bool) error {
   165  	if len(headers) == 0 {
   166  		return nil
   167  	}
   168  	switch r.Interface().(type) {
   169  	case map[string]*string: // we only support string map value types
   170  		out := map[string]*string{}
   171  		for k, v := range headers {
   172  			if awsStrings.HasPrefixFold(k, prefix) {
   173  				if normalize == true {
   174  					k = strings.ToLower(k)
   175  				} else {
   176  					k = http.CanonicalHeaderKey(k)
   177  				}
   178  				out[k[len(prefix):]] = &v[0]
   179  			}
   180  		}
   181  		if len(out) != 0 {
   182  			r.Set(reflect.ValueOf(out))
   183  		}
   184  
   185  	}
   186  	return nil
   187  }
   188  
   189  func unmarshalHeader(v reflect.Value, header string, tag reflect.StructTag) error {
   190  	switch tag.Get("type") {
   191  	case "jsonvalue":
   192  		if len(header) == 0 {
   193  			return nil
   194  		}
   195  	case "blob":
   196  		if len(header) == 0 {
   197  			return nil
   198  		}
   199  	default:
   200  		if !v.IsValid() || (header == "" && v.Elem().Kind() != reflect.String) {
   201  			return nil
   202  		}
   203  	}
   204  
   205  	switch v.Interface().(type) {
   206  	case *string:
   207  		v.Set(reflect.ValueOf(&header))
   208  	case []byte:
   209  		b, err := base64.StdEncoding.DecodeString(header)
   210  		if err != nil {
   211  			return err
   212  		}
   213  		v.Set(reflect.ValueOf(b))
   214  	case *bool:
   215  		b, err := strconv.ParseBool(header)
   216  		if err != nil {
   217  			return err
   218  		}
   219  		v.Set(reflect.ValueOf(&b))
   220  	case *int64:
   221  		i, err := strconv.ParseInt(header, 10, 64)
   222  		if err != nil {
   223  			return err
   224  		}
   225  		v.Set(reflect.ValueOf(&i))
   226  	case *float64:
   227  		f, err := strconv.ParseFloat(header, 64)
   228  		if err != nil {
   229  			return err
   230  		}
   231  		v.Set(reflect.ValueOf(&f))
   232  	case *time.Time:
   233  		format := tag.Get("timestampFormat")
   234  		if len(format) == 0 {
   235  			format = protocol.RFC822TimeFormatName
   236  		}
   237  		t, err := protocol.ParseTime(format, header)
   238  		if err != nil {
   239  			return err
   240  		}
   241  		v.Set(reflect.ValueOf(&t))
   242  	case aws.JSONValue:
   243  		escaping := protocol.NoEscape
   244  		if tag.Get("location") == "header" {
   245  			escaping = protocol.Base64Escape
   246  		}
   247  		m, err := protocol.DecodeJSONValue(header, escaping)
   248  		if err != nil {
   249  			return err
   250  		}
   251  		v.Set(reflect.ValueOf(m))
   252  	default:
   253  		err := fmt.Errorf("Unsupported value for param %v (%s)", v.Interface(), v.Type())
   254  		return err
   255  	}
   256  	return nil
   257  }