github.com/aavshr/aws-sdk-go@v1.41.3/private/protocol/rest/build.go (about)

     1  // Package rest provides RESTful serialization of AWS requests and responses.
     2  package rest
     3  
     4  import (
     5  	"bytes"
     6  	"encoding/base64"
     7  	"fmt"
     8  	"io"
     9  	"net/http"
    10  	"net/url"
    11  	"path"
    12  	"reflect"
    13  	"strconv"
    14  	"strings"
    15  	"time"
    16  
    17  	"github.com/aavshr/aws-sdk-go/aws"
    18  	"github.com/aavshr/aws-sdk-go/aws/awserr"
    19  	"github.com/aavshr/aws-sdk-go/aws/request"
    20  	"github.com/aavshr/aws-sdk-go/private/protocol"
    21  )
    22  
    23  // Whether the byte value can be sent without escaping in AWS URLs
    24  var noEscape [256]bool
    25  
    26  var errValueNotSet = fmt.Errorf("value not set")
    27  
    28  var byteSliceType = reflect.TypeOf([]byte{})
    29  
    30  func init() {
    31  	for i := 0; i < len(noEscape); i++ {
    32  		// AWS expects every character except these to be escaped
    33  		noEscape[i] = (i >= 'A' && i <= 'Z') ||
    34  			(i >= 'a' && i <= 'z') ||
    35  			(i >= '0' && i <= '9') ||
    36  			i == '-' ||
    37  			i == '.' ||
    38  			i == '_' ||
    39  			i == '~'
    40  	}
    41  }
    42  
    43  // BuildHandler is a named request handler for building rest protocol requests
    44  var BuildHandler = request.NamedHandler{Name: "awssdk.rest.Build", Fn: Build}
    45  
    46  // Build builds the REST component of a service request.
    47  func Build(r *request.Request) {
    48  	if r.ParamsFilled() {
    49  		v := reflect.ValueOf(r.Params).Elem()
    50  		buildLocationElements(r, v, false)
    51  		buildBody(r, v)
    52  	}
    53  }
    54  
    55  // BuildAsGET builds the REST component of a service request with the ability to hoist
    56  // data from the body.
    57  func BuildAsGET(r *request.Request) {
    58  	if r.ParamsFilled() {
    59  		v := reflect.ValueOf(r.Params).Elem()
    60  		buildLocationElements(r, v, true)
    61  		buildBody(r, v)
    62  	}
    63  }
    64  
    65  func buildLocationElements(r *request.Request, v reflect.Value, buildGETQuery bool) {
    66  	query := r.HTTPRequest.URL.Query()
    67  
    68  	// Setup the raw path to match the base path pattern. This is needed
    69  	// so that when the path is mutated a custom escaped version can be
    70  	// stored in RawPath that will be used by the Go client.
    71  	r.HTTPRequest.URL.RawPath = r.HTTPRequest.URL.Path
    72  
    73  	for i := 0; i < v.NumField(); i++ {
    74  		m := v.Field(i)
    75  		if n := v.Type().Field(i).Name; n[0:1] == strings.ToLower(n[0:1]) {
    76  			continue
    77  		}
    78  
    79  		if m.IsValid() {
    80  			field := v.Type().Field(i)
    81  			name := field.Tag.Get("locationName")
    82  			if name == "" {
    83  				name = field.Name
    84  			}
    85  			if kind := m.Kind(); kind == reflect.Ptr {
    86  				m = m.Elem()
    87  			} else if kind == reflect.Interface {
    88  				if !m.Elem().IsValid() {
    89  					continue
    90  				}
    91  			}
    92  			if !m.IsValid() {
    93  				continue
    94  			}
    95  			if field.Tag.Get("ignore") != "" {
    96  				continue
    97  			}
    98  
    99  			// Support the ability to customize values to be marshaled as a
   100  			// blob even though they were modeled as a string. Required for S3
   101  			// API operations like SSECustomerKey is modeled as string but
   102  			// required to be base64 encoded in request.
   103  			if field.Tag.Get("marshal-as") == "blob" {
   104  				m = m.Convert(byteSliceType)
   105  			}
   106  
   107  			var err error
   108  			switch field.Tag.Get("location") {
   109  			case "headers": // header maps
   110  				err = buildHeaderMap(&r.HTTPRequest.Header, m, field.Tag)
   111  			case "header":
   112  				err = buildHeader(&r.HTTPRequest.Header, m, name, field.Tag)
   113  			case "uri":
   114  				err = buildURI(r.HTTPRequest.URL, m, name, field.Tag)
   115  			case "querystring":
   116  				err = buildQueryString(query, m, name, field.Tag)
   117  			default:
   118  				if buildGETQuery {
   119  					err = buildQueryString(query, m, name, field.Tag)
   120  				}
   121  			}
   122  			r.Error = err
   123  		}
   124  		if r.Error != nil {
   125  			return
   126  		}
   127  	}
   128  
   129  	r.HTTPRequest.URL.RawQuery = query.Encode()
   130  	if !aws.BoolValue(r.Config.DisableRestProtocolURICleaning) {
   131  		cleanPath(r.HTTPRequest.URL)
   132  	}
   133  }
   134  
   135  func buildBody(r *request.Request, v reflect.Value) {
   136  	if field, ok := v.Type().FieldByName("_"); ok {
   137  		if payloadName := field.Tag.Get("payload"); payloadName != "" {
   138  			pfield, _ := v.Type().FieldByName(payloadName)
   139  			if ptag := pfield.Tag.Get("type"); ptag != "" && ptag != "structure" {
   140  				payload := reflect.Indirect(v.FieldByName(payloadName))
   141  				if payload.IsValid() && payload.Interface() != nil {
   142  					switch reader := payload.Interface().(type) {
   143  					case io.ReadSeeker:
   144  						r.SetReaderBody(reader)
   145  					case []byte:
   146  						r.SetBufferBody(reader)
   147  					case string:
   148  						r.SetStringBody(reader)
   149  					default:
   150  						r.Error = awserr.New(request.ErrCodeSerialization,
   151  							"failed to encode REST request",
   152  							fmt.Errorf("unknown payload type %s", payload.Type()))
   153  					}
   154  				}
   155  			}
   156  		}
   157  	}
   158  }
   159  
   160  func buildHeader(header *http.Header, v reflect.Value, name string, tag reflect.StructTag) error {
   161  	str, err := convertType(v, tag)
   162  	if err == errValueNotSet {
   163  		return nil
   164  	} else if err != nil {
   165  		return awserr.New(request.ErrCodeSerialization, "failed to encode REST request", err)
   166  	}
   167  
   168  	name = strings.TrimSpace(name)
   169  	str = strings.TrimSpace(str)
   170  
   171  	header.Add(name, str)
   172  
   173  	return nil
   174  }
   175  
   176  func buildHeaderMap(header *http.Header, v reflect.Value, tag reflect.StructTag) error {
   177  	prefix := tag.Get("locationName")
   178  	for _, key := range v.MapKeys() {
   179  		str, err := convertType(v.MapIndex(key), tag)
   180  		if err == errValueNotSet {
   181  			continue
   182  		} else if err != nil {
   183  			return awserr.New(request.ErrCodeSerialization, "failed to encode REST request", err)
   184  
   185  		}
   186  		keyStr := strings.TrimSpace(key.String())
   187  		str = strings.TrimSpace(str)
   188  
   189  		header.Add(prefix+keyStr, str)
   190  	}
   191  	return nil
   192  }
   193  
   194  func buildURI(u *url.URL, v reflect.Value, name string, tag reflect.StructTag) error {
   195  	value, err := convertType(v, tag)
   196  	if err == errValueNotSet {
   197  		return nil
   198  	} else if err != nil {
   199  		return awserr.New(request.ErrCodeSerialization, "failed to encode REST request", err)
   200  	}
   201  
   202  	u.Path = strings.Replace(u.Path, "{"+name+"}", value, -1)
   203  	u.Path = strings.Replace(u.Path, "{"+name+"+}", value, -1)
   204  
   205  	u.RawPath = strings.Replace(u.RawPath, "{"+name+"}", EscapePath(value, true), -1)
   206  	u.RawPath = strings.Replace(u.RawPath, "{"+name+"+}", EscapePath(value, false), -1)
   207  
   208  	return nil
   209  }
   210  
   211  func buildQueryString(query url.Values, v reflect.Value, name string, tag reflect.StructTag) error {
   212  	switch value := v.Interface().(type) {
   213  	case []*string:
   214  		for _, item := range value {
   215  			query.Add(name, *item)
   216  		}
   217  	case map[string]*string:
   218  		for key, item := range value {
   219  			query.Add(key, *item)
   220  		}
   221  	case map[string][]*string:
   222  		for key, items := range value {
   223  			for _, item := range items {
   224  				query.Add(key, *item)
   225  			}
   226  		}
   227  	default:
   228  		str, err := convertType(v, tag)
   229  		if err == errValueNotSet {
   230  			return nil
   231  		} else if err != nil {
   232  			return awserr.New(request.ErrCodeSerialization, "failed to encode REST request", err)
   233  		}
   234  		query.Set(name, str)
   235  	}
   236  
   237  	return nil
   238  }
   239  
   240  func cleanPath(u *url.URL) {
   241  	hasSlash := strings.HasSuffix(u.Path, "/")
   242  
   243  	// clean up path, removing duplicate `/`
   244  	u.Path = path.Clean(u.Path)
   245  	u.RawPath = path.Clean(u.RawPath)
   246  
   247  	if hasSlash && !strings.HasSuffix(u.Path, "/") {
   248  		u.Path += "/"
   249  		u.RawPath += "/"
   250  	}
   251  }
   252  
   253  // EscapePath escapes part of a URL path in Amazon style
   254  func EscapePath(path string, encodeSep bool) string {
   255  	var buf bytes.Buffer
   256  	for i := 0; i < len(path); i++ {
   257  		c := path[i]
   258  		if noEscape[c] || (c == '/' && !encodeSep) {
   259  			buf.WriteByte(c)
   260  		} else {
   261  			fmt.Fprintf(&buf, "%%%02X", c)
   262  		}
   263  	}
   264  	return buf.String()
   265  }
   266  
   267  func convertType(v reflect.Value, tag reflect.StructTag) (str string, err error) {
   268  	v = reflect.Indirect(v)
   269  	if !v.IsValid() {
   270  		return "", errValueNotSet
   271  	}
   272  
   273  	switch value := v.Interface().(type) {
   274  	case string:
   275  		str = value
   276  	case []byte:
   277  		str = base64.StdEncoding.EncodeToString(value)
   278  	case bool:
   279  		str = strconv.FormatBool(value)
   280  	case int64:
   281  		str = strconv.FormatInt(value, 10)
   282  	case float64:
   283  		str = strconv.FormatFloat(value, 'f', -1, 64)
   284  	case time.Time:
   285  		format := tag.Get("timestampFormat")
   286  		if len(format) == 0 {
   287  			format = protocol.RFC822TimeFormatName
   288  			if tag.Get("location") == "querystring" {
   289  				format = protocol.ISO8601TimeFormatName
   290  			}
   291  		}
   292  		str = protocol.FormatTime(format, value)
   293  	case aws.JSONValue:
   294  		if len(value) == 0 {
   295  			return "", errValueNotSet
   296  		}
   297  		escaping := protocol.NoEscape
   298  		if tag.Get("location") == "header" {
   299  			escaping = protocol.Base64Escape
   300  		}
   301  		str, err = protocol.EncodeJSONValue(value, escaping)
   302  		if err != nil {
   303  			return "", fmt.Errorf("unable to encode JSONValue, %v", err)
   304  		}
   305  	default:
   306  		err := fmt.Errorf("unsupported value for param %v (%s)", v.Interface(), v.Type())
   307  		return "", err
   308  	}
   309  	return str, nil
   310  }