github.com/aavshr/aws-sdk-go@v1.41.3/private/protocol/rest/unmarshal.go (about) 1 package rest 2 3 import ( 4 "bytes" 5 "encoding/base64" 6 "fmt" 7 "io" 8 "io/ioutil" 9 "net/http" 10 "reflect" 11 "strconv" 12 "strings" 13 "time" 14 15 "github.com/aavshr/aws-sdk-go/aws" 16 "github.com/aavshr/aws-sdk-go/aws/awserr" 17 "github.com/aavshr/aws-sdk-go/aws/request" 18 awsStrings "github.com/aavshr/aws-sdk-go/internal/strings" 19 "github.com/aavshr/aws-sdk-go/private/protocol" 20 ) 21 22 // UnmarshalHandler is a named request handler for unmarshaling rest protocol requests 23 var UnmarshalHandler = request.NamedHandler{Name: "awssdk.rest.Unmarshal", Fn: Unmarshal} 24 25 // UnmarshalMetaHandler is a named request handler for unmarshaling rest protocol request metadata 26 var UnmarshalMetaHandler = request.NamedHandler{Name: "awssdk.rest.UnmarshalMeta", Fn: UnmarshalMeta} 27 28 // Unmarshal unmarshals the REST component of a response in a REST service. 29 func Unmarshal(r *request.Request) { 30 if r.DataFilled() { 31 v := reflect.Indirect(reflect.ValueOf(r.Data)) 32 if err := unmarshalBody(r, v); err != nil { 33 r.Error = err 34 } 35 } 36 } 37 38 // UnmarshalMeta unmarshals the REST metadata of a response in a REST service 39 func UnmarshalMeta(r *request.Request) { 40 r.RequestID = r.HTTPResponse.Header.Get("X-Amzn-Requestid") 41 if r.RequestID == "" { 42 // Alternative version of request id in the header 43 r.RequestID = r.HTTPResponse.Header.Get("X-Amz-Request-Id") 44 } 45 if r.DataFilled() { 46 if err := UnmarshalResponse(r.HTTPResponse, r.Data, aws.BoolValue(r.Config.LowerCaseHeaderMaps)); err != nil { 47 r.Error = err 48 } 49 } 50 } 51 52 // UnmarshalResponse attempts to unmarshal the REST response headers to 53 // the data type passed in. The type must be a pointer. An error is returned 54 // with any error unmarshaling the response into the target datatype. 55 func UnmarshalResponse(resp *http.Response, data interface{}, lowerCaseHeaderMaps bool) error { 56 v := reflect.Indirect(reflect.ValueOf(data)) 57 return unmarshalLocationElements(resp, v, lowerCaseHeaderMaps) 58 } 59 60 func unmarshalBody(r *request.Request, v reflect.Value) error { 61 if field, ok := v.Type().FieldByName("_"); ok { 62 if payloadName := field.Tag.Get("payload"); payloadName != "" { 63 pfield, _ := v.Type().FieldByName(payloadName) 64 if ptag := pfield.Tag.Get("type"); ptag != "" && ptag != "structure" { 65 payload := v.FieldByName(payloadName) 66 if payload.IsValid() { 67 switch payload.Interface().(type) { 68 case []byte: 69 defer r.HTTPResponse.Body.Close() 70 b, err := ioutil.ReadAll(r.HTTPResponse.Body) 71 if err != nil { 72 return awserr.New(request.ErrCodeSerialization, "failed to decode REST response", err) 73 } 74 75 payload.Set(reflect.ValueOf(b)) 76 77 case *string: 78 defer r.HTTPResponse.Body.Close() 79 b, err := ioutil.ReadAll(r.HTTPResponse.Body) 80 if err != nil { 81 return awserr.New(request.ErrCodeSerialization, "failed to decode REST response", err) 82 } 83 84 str := string(b) 85 payload.Set(reflect.ValueOf(&str)) 86 87 default: 88 switch payload.Type().String() { 89 case "io.ReadCloser": 90 payload.Set(reflect.ValueOf(r.HTTPResponse.Body)) 91 92 case "io.ReadSeeker": 93 b, err := ioutil.ReadAll(r.HTTPResponse.Body) 94 if err != nil { 95 return awserr.New(request.ErrCodeSerialization, 96 "failed to read response body", err) 97 } 98 payload.Set(reflect.ValueOf(ioutil.NopCloser(bytes.NewReader(b)))) 99 100 default: 101 io.Copy(ioutil.Discard, r.HTTPResponse.Body) 102 r.HTTPResponse.Body.Close() 103 return awserr.New(request.ErrCodeSerialization, 104 "failed to decode REST response", 105 fmt.Errorf("unknown payload type %s", payload.Type())) 106 } 107 } 108 } 109 } 110 } 111 } 112 113 return nil 114 } 115 116 func unmarshalLocationElements(resp *http.Response, v reflect.Value, lowerCaseHeaderMaps bool) error { 117 for i := 0; i < v.NumField(); i++ { 118 m, field := v.Field(i), v.Type().Field(i) 119 if n := field.Name; n[0:1] == strings.ToLower(n[0:1]) { 120 continue 121 } 122 123 if m.IsValid() { 124 name := field.Tag.Get("locationName") 125 if name == "" { 126 name = field.Name 127 } 128 129 switch field.Tag.Get("location") { 130 case "statusCode": 131 unmarshalStatusCode(m, resp.StatusCode) 132 133 case "header": 134 err := unmarshalHeader(m, resp.Header.Get(name), field.Tag) 135 if err != nil { 136 return awserr.New(request.ErrCodeSerialization, "failed to decode REST response", err) 137 } 138 139 case "headers": 140 prefix := field.Tag.Get("locationName") 141 err := unmarshalHeaderMap(m, resp.Header, prefix, lowerCaseHeaderMaps) 142 if err != nil { 143 awserr.New(request.ErrCodeSerialization, "failed to decode REST response", err) 144 } 145 } 146 } 147 } 148 149 return nil 150 } 151 152 func unmarshalStatusCode(v reflect.Value, statusCode int) { 153 if !v.IsValid() { 154 return 155 } 156 157 switch v.Interface().(type) { 158 case *int64: 159 s := int64(statusCode) 160 v.Set(reflect.ValueOf(&s)) 161 } 162 } 163 164 func unmarshalHeaderMap(r reflect.Value, headers http.Header, prefix string, normalize bool) error { 165 if len(headers) == 0 { 166 return nil 167 } 168 switch r.Interface().(type) { 169 case map[string]*string: // we only support string map value types 170 out := map[string]*string{} 171 for k, v := range headers { 172 if awsStrings.HasPrefixFold(k, prefix) { 173 if normalize == true { 174 k = strings.ToLower(k) 175 } else { 176 k = http.CanonicalHeaderKey(k) 177 } 178 out[k[len(prefix):]] = &v[0] 179 } 180 } 181 if len(out) != 0 { 182 r.Set(reflect.ValueOf(out)) 183 } 184 185 } 186 return nil 187 } 188 189 func unmarshalHeader(v reflect.Value, header string, tag reflect.StructTag) error { 190 switch tag.Get("type") { 191 case "jsonvalue": 192 if len(header) == 0 { 193 return nil 194 } 195 case "blob": 196 if len(header) == 0 { 197 return nil 198 } 199 default: 200 if !v.IsValid() || (header == "" && v.Elem().Kind() != reflect.String) { 201 return nil 202 } 203 } 204 205 switch v.Interface().(type) { 206 case *string: 207 v.Set(reflect.ValueOf(&header)) 208 case []byte: 209 b, err := base64.StdEncoding.DecodeString(header) 210 if err != nil { 211 return err 212 } 213 v.Set(reflect.ValueOf(b)) 214 case *bool: 215 b, err := strconv.ParseBool(header) 216 if err != nil { 217 return err 218 } 219 v.Set(reflect.ValueOf(&b)) 220 case *int64: 221 i, err := strconv.ParseInt(header, 10, 64) 222 if err != nil { 223 return err 224 } 225 v.Set(reflect.ValueOf(&i)) 226 case *float64: 227 f, err := strconv.ParseFloat(header, 64) 228 if err != nil { 229 return err 230 } 231 v.Set(reflect.ValueOf(&f)) 232 case *time.Time: 233 format := tag.Get("timestampFormat") 234 if len(format) == 0 { 235 format = protocol.RFC822TimeFormatName 236 } 237 t, err := protocol.ParseTime(format, header) 238 if err != nil { 239 return err 240 } 241 v.Set(reflect.ValueOf(&t)) 242 case aws.JSONValue: 243 escaping := protocol.NoEscape 244 if tag.Get("location") == "header" { 245 escaping = protocol.Base64Escape 246 } 247 m, err := protocol.DecodeJSONValue(header, escaping) 248 if err != nil { 249 return err 250 } 251 v.Set(reflect.ValueOf(m)) 252 default: 253 err := fmt.Errorf("Unsupported value for param %v (%s)", v.Interface(), v.Type()) 254 return err 255 } 256 return nil 257 }