github.com/treeverse/lakefs@v1.24.1-0.20240520134607-95648127bfb0/pkg/gateway/sig/v4.go (about)

     1  package sig
     2  
     3  import (
     4  	"bufio"
     5  	"crypto/hmac"
     6  	"crypto/sha256"
     7  	"encoding/hex"
     8  	"fmt"
     9  	"io"
    10  	"net/http"
    11  	"net/url"
    12  	"regexp"
    13  	"sort"
    14  	"strconv"
    15  	"strings"
    16  	"time"
    17  	"unicode"
    18  
    19  	"github.com/treeverse/lakefs/pkg/auth/model"
    20  	"github.com/treeverse/lakefs/pkg/gateway/errors"
    21  )
    22  
    // Exported AWS Signature Version 4 identifiers.
const (
	// V4authHeaderName is the request header that carries SigV4 credentials.
	V4authHeaderName = "Authorization"
	// V4authHeaderPrefix is the SigV4 algorithm name. It opens the Authorization
	// header value, is the expected X-Amz-Algorithm query value, and is the first
	// line of the string to sign.
	V4authHeaderPrefix = "AWS4-HMAC-SHA256"
	// AmzDecodedContentLength is the header holding the payload size of a
	// streaming (chunk-signed) upload.
	AmzDecodedContentLength = "X-Amz-Decoded-Content-Length"
)

// Internal SigV4 constants.
const (
	v4StreamingPayloadHash = "STREAMING-AWS4-HMAC-SHA256-PAYLOAD" // payload-hash value of a chunk-signed body
	v4UnsignedPayload      = "UNSIGNED-PAYLOAD"                   // payload-hash value of an unsigned body
	v4authHeaderPayload    = "x-amz-content-sha256"               // header carrying the payload hash
	v4scopeTerminator      = "aws4_request"                       // final credential-scope component
	v4timeFormat           = "20060102T150405Z"                   // layout of the full request timestamp
	v4shortTimeFormat      = "20060102"                           // layout of the credential-scope date
	v4SignatureHeader      = "X-Amz-Signature"                    // query parameter holding a presigned signature
)
    35  
    36  var (
    37  	V4AuthHeaderRegexp      = regexp.MustCompile(`AWS4-HMAC-SHA256 Credential=(?P<AccessKeyId>.{3,20})/(?P<Date>\d{8})/(?P<Region>[\w\-]+)/(?P<Service>[\w\-]+)/aws4_request,\s*SignedHeaders=(?P<SignatureHeaders>[\w\-\;]+),\s*Signature=(?P<Signature>[abcdef0123456789]{64})`)
    38  	V4CredentialScopeRegexp = regexp.MustCompile(`(?P<AccessKeyId>.{3,20})/(?P<Date>\d{8})/(?P<Region>[\w\-]+)/(?P<Service>[\w\-]+)/aws4_request`)
    39  )
    40  
    // V4Auth holds the AWS Signature Version 4 parameters of a request. They come
// either from the Authorization header or from presigned-URL query parameters.
type V4Auth struct {
	// Credential scope: access key ID, YYYYMMDD date, region and service.
	AccessKeyID, Date, Region, Service string
	// SignedHeaders is the signed-header list split on ';' and sorted.
	// SignedHeadersString is that list exactly as the client sent it.
	SignedHeaders       []string
	SignedHeadersString string
	// Signature is the client-supplied hex-encoded request signature.
	Signature string
}
    50  
    51  func (a V4Auth) GetAccessKeyID() string {
    52  	return a.AccessKeyID
    53  }
    54  
    // splitHeaders splits a ';'-separated signed-header list and returns the names
// sorted in ascending order. An empty input yields a single empty name.
func splitHeaders(headers string) []string {
	names := sort.StringSlice(strings.Split(headers, ";"))
	names.Sort()
	return []string(names)
}
    60  
    61  func ParseV4AuthContext(r *http.Request) (V4Auth, error) {
    62  	var ctx V4Auth
    63  
    64  	// start by trying to extract the data from the Authorization header
    65  	headerValue := r.Header.Get(V4authHeaderName)
    66  	if len(headerValue) > 0 {
    67  		match := V4AuthHeaderRegexp.FindStringSubmatch(headerValue)
    68  		if len(match) == 0 {
    69  			return ctx, ErrHeaderMalformed
    70  		}
    71  		result := make(map[string]string)
    72  		for i, name := range V4AuthHeaderRegexp.SubexpNames() {
    73  			if i != 0 && name != "" {
    74  				result[name] = match[i]
    75  			}
    76  		}
    77  		ctx.AccessKeyID = result["AccessKeyId"]
    78  		ctx.Date = result["Date"]
    79  		ctx.Region = result["Region"]
    80  		ctx.Service = result["Service"]
    81  		ctx.Signature = result["Signature"]
    82  
    83  		signatureHeaders := result["SignatureHeaders"]
    84  		ctx.SignedHeaders = splitHeaders(signatureHeaders)
    85  		ctx.SignedHeadersString = signatureHeaders
    86  		return ctx, nil
    87  	}
    88  
    89  	// otherwise, see if we have all the required query parameters
    90  	query := r.URL.Query()
    91  	algorithm := query.Get("X-Amz-Algorithm")
    92  	if len(algorithm) == 0 || !strings.EqualFold(algorithm, V4authHeaderPrefix) {
    93  		return ctx, errors.ErrInvalidQuerySignatureAlgo
    94  	}
    95  	credentialScope := query.Get("X-Amz-Credential")
    96  	if len(credentialScope) == 0 {
    97  		return ctx, errors.ErrMissingCredTag
    98  	}
    99  	credsMatch := V4CredentialScopeRegexp.FindStringSubmatch(credentialScope)
   100  	if len(credsMatch) == 0 {
   101  		return ctx, errors.ErrCredMalformed
   102  	}
   103  	credsResult := make(map[string]string)
   104  	for i, name := range V4CredentialScopeRegexp.SubexpNames() {
   105  		if i != 0 && name != "" {
   106  			credsResult[name] = credsMatch[i]
   107  		}
   108  	}
   109  	ctx.AccessKeyID = credsResult["AccessKeyId"]
   110  	ctx.Date = credsResult["Date"]
   111  	ctx.Region = credsResult["Region"]
   112  	ctx.Service = credsResult["Service"]
   113  
   114  	ctx.SignedHeadersString = query.Get("X-Amz-SignedHeaders")
   115  	headers := splitHeaders(ctx.SignedHeadersString)
   116  	ctx.SignedHeaders = headers
   117  	ctx.Signature = query.Get(v4SignatureHeader)
   118  	return ctx, nil
   119  }
   120  
   // V4Verify checks r against the SigV4 parameters in auth.
//
// It rebuilds the canonical request and the string to sign, and derives the signing
// key from credentials.SecretAccessKey and the credential scope. It then compares
// the resulting hex signature with auth.Signature using Equal. If they match, r.Body
// is replaced by the reader chosen for the declared payload mode, and
// r.ContentLength is set to the length reported by contentLength (the decoded
// length for streaming uploads).
//
// NOTE(review): no request-time skew or X-Amz-Expires check is visible here.
// Confirm that presigned-URL expiry is enforced elsewhere.
func V4Verify(auth V4Auth, credentials *model.Credential, r *http.Request) error {
	vc := &verificationCtx{
		Request:   r,
		Query:     r.URL.Query(),
		AuthValue: auth,
	}

	stringToSign, err := vc.buildSignedString(vc.buildCanonicalRequest())
	if err != nil {
		return err
	}

	// Derive the scoped signing key, sign, and compare with the client's signature.
	key := createSignature(credentials.SecretAccessKey, auth.Date, auth.Region, auth.Service)
	expected := hex.EncodeToString(sign(key, stringToSign))
	if !Equal([]byte(expected), []byte(auth.Signature)) {
		return errors.ErrSignatureDoesNotMatch
	}

	// The signature matches; now wrap the body so the payload is checked as it is read.
	body, err := vc.reader(r.Body, credentials)
	if err != nil {
		return err
	}
	r.Body = body

	size, err := vc.contentLength()
	if err != nil {
		return err
	}
	r.ContentLength = size
	return nil
}
   157  
   // verificationCtx carries the state shared by the SigV4 verification helpers.
   158  type verificationCtx struct {
   159  	Request   *http.Request // request being verified; V4Verify may replace its Body
   160  	Query     url.Values    // query parameters; V4Verify fills this from Request.URL.Query()
   161  	AuthValue V4Auth        // signature parameters, as produced by ParseV4AuthContext
   162  }
   163  
   // queryEscape URI-encodes str for the canonical query string, encoding a space as
// %20 rather than '+'.
func (ctx *verificationCtx) queryEscape(str string) string {
	// url.QueryEscape has already turned any literal '+' into %2B, so every '+'
	// left in its output stands for a space.
	escaped := url.QueryEscape(str)
	return strings.Replace(escaped, "+", "%20", -1)
}
   167  
   // canonicalizeQueryString builds the CanonicalQueryString part of the SigV4
// canonical request from ctx.Query.
//
// X-Amz-Signature is left out, because in presigned URLs it carries the signature
// being verified. Parameter names are sorted. Every value of a repeated parameter
// is included, sorted by value, as the SigV4 specification requires. Before this
// fix only the first value was signed, so clients sending repeated parameters
// could never verify. Names and values are URI-encoded with queryEscape. A name
// present with no values is rendered as "name=". ctx.Query is not modified.
func (ctx *verificationCtx) canonicalizeQueryString() string {
	keys := make([]string, 0, len(ctx.Query))
	for k := range ctx.Query {
		if k == v4SignatureHeader {
			continue
		}
		keys = append(keys, k)
	}
	sort.Strings(keys)

	pairs := make([]string, 0, len(keys))
	for _, key := range keys {
		// Copy before sorting so the caller's url.Values is left untouched.
		values := append([]string(nil), ctx.Query[key]...)
		if len(values) == 0 {
			// Same as the previous Query.Get(key) behavior for a valueless key.
			values = []string{""}
		}
		sort.Strings(values)
		escapedKey := ctx.queryEscape(key)
		for _, v := range values {
			pairs = append(pairs, escapedKey+"="+ctx.queryEscape(v))
		}
	}
	return strings.Join(pairs, "&")
}
   183  
   // canonicalizeHeaders builds the CanonicalHeaders part of the SigV4 canonical
// request. It writes one "name:value\n" line for each entry of headers, in the
// given order.
//
// The host header is read from Request.Host. Other names are matched
// case-insensitively against Request.Header. A header sent on several lines
// contributes all of its values: each is normalized with trimAll and the results
// are joined with ',', as the SigV4 specification requires. Previously only the
// first value was used. A signed header absent from the request (or stored with an
// empty value slice) yields an empty value instead of panicking.
func (ctx *verificationCtx) canonicalizeHeaders(headers []string) string {
	var buf strings.Builder
	for _, header := range headers {
		var values []string
		if strings.EqualFold(header, "host") {
			// in Go, Host is removed from the headers and is promoted to request.Host
			values = []string{ctx.Request.Host}
		} else {
			for key, vs := range ctx.Request.Header {
				if strings.EqualFold(key, header) {
					values = append(values, vs...)
				}
			}
		}
		buf.WriteString(header)
		buf.WriteByte(':')
		for i, v := range values {
			if i > 0 {
				buf.WriteByte(',')
			}
			buf.WriteString(ctx.trimAll(v))
		}
		buf.WriteByte('\n')
	}
	return buf.String()
}
   201  
   202  func (ctx *verificationCtx) trimAll(str string) string {
   203  	str = strings.TrimSpace(str)
   204  	inSpace := false
   205  	var buf strings.Builder
   206  	for _, ch := range str {
   207  		if unicode.IsSpace(ch) {
   208  			if !inSpace {
   209  				// first space to appear
   210  				buf.WriteRune(ch)
   211  				inSpace = true
   212  			}
   213  		} else {
   214  			// not a space
   215  			buf.WriteRune(ch)
   216  			inSpace = false
   217  		}
   218  	}
   219  	return buf.String()
   220  }
   221  
   // getInsensitiveHeader returns the first value of the header named headerName,
// matching the name case-insensitively. It returns "" when the header is absent or
// has no values.
//
// The canonical-key lookup runs first, so the result is deterministic for headers
// parsed by net/http. The map scan is a fallback for values stored under
// non-canonical keys (for example, assigned directly to r.Header). Keys whose value
// slice is empty are skipped; indexing v[0] on them used to panic.
func getInsensitiveHeader(r *http.Request, headerName string) string {
	if values := r.Header.Values(headerName); len(values) > 0 {
		return values[0]
	}
	for k, v := range r.Header {
		if len(v) > 0 && strings.EqualFold(k, headerName) {
			return v[0]
		}
	}
	return ""
}
   230  
   231  func (ctx *verificationCtx) payloadHash() string {
   232  	payloadHash := getInsensitiveHeader(ctx.Request, v4authHeaderPayload)
   233  	if payloadHash == "" {
   234  		return v4UnsignedPayload
   235  	}
   236  	return payloadHash
   237  }
   238  
   // buildCanonicalRequest assembles the SigV4 canonical request. It is the
// newline-joined sequence of HTTP method, encoded URI path, canonical query
// string, canonical headers, signed-header list and payload hash. The canonical
// headers already end in '\n', so a blank line separates them from the signed-header
// list.
func (ctx *verificationCtx) buildCanonicalRequest() string {
	parts := []string{
		ctx.Request.Method,
		EncodePath(ctx.Request.URL.Path),
		ctx.canonicalizeQueryString(),
		ctx.canonicalizeHeaders(ctx.AuthValue.SignedHeaders),
		ctx.AuthValue.SignedHeadersString,
		ctx.payloadHash(),
	}
	return strings.Join(parts, "\n")
}
   257  
   // getAmzDate returns the request timestamp that goes into the string to sign.
//
// The value is taken from the first non-empty of: the X-Amz-Date query parameter,
// the x-amz-date header, the Date header (ErrMissingDateHeader if none). It must
// parse with v4timeFormat (ErrMalformedDate). The credential-scope date must parse
// with v4shortTimeFormat and fall on the same calendar day
// (ErrMalformedCredentialDate).
//
// NOTE(review): a Date header in the usual RFC 1123 HTTP format always fails with
// ErrMalformedDate, since only v4timeFormat is accepted. No clock-skew window is
// enforced here either. Confirm both are intended.
func (ctx *verificationCtx) getAmzDate() (string, error) {
	// https://docs.aws.amazon.com/general/latest/gr/sigv4-date-handling.html
	amzDate := ctx.Request.URL.Query().Get("X-Amz-Date")
	for _, header := range []string{"x-amz-date", "date"} {
		if amzDate != "" {
			break
		}
		amzDate = ctx.Request.Header.Get(header)
	}
	if amzDate == "" {
		return "", errors.ErrMissingDateHeader
	}

	requestTime, err := time.Parse(v4timeFormat, amzDate)
	if err != nil {
		return "", errors.ErrMalformedDate
	}
	scopeTime, err := time.Parse(v4shortTimeFormat, ctx.AuthValue.Date)
	if err != nil {
		return "", errors.ErrMalformedCredentialDate
	}

	ry, rm, rd := requestTime.Date()
	sy, sm, sd := scopeTime.Date()
	if ry != sy || rm != sm || rd != sd {
		return "", errors.ErrMalformedCredentialDate
	}
	return amzDate, nil
}
   290  
   // sign returns HMAC-SHA256(key, msg).
func sign(key []byte, msg string) []byte {
	mac := hmac.New(sha256.New, key)
	// Writes to a hash.Hash never fail, so the result is deliberately ignored.
	_, _ = io.WriteString(mac, msg)
	return mac.Sum(nil)
}
   296  
   // createSignature derives the SigV4 signing key. It chains HMAC-SHA256 starting
// from "AWS4"+secret over the credential-scope parts: date, region, service and
// "aws4_request".
func createSignature(key, dateStamp, region, service string) []byte {
	signingKey := []byte(fmt.Sprintf("AWS4%s", key))
	for _, part := range []string{dateStamp, region, service, v4scopeTerminator} {
		signingKey = sign(signingKey, part)
	}
	return signingKey
}
   304  
   // buildSignedString returns the SigV4 string to sign. Its four newline-separated
// lines are: the algorithm name, the request timestamp from getAmzDate, the
// credential scope "<date>/<region>/<service>/aws4_request", and the hex SHA-256 of
// canonicalRequest. Errors from getAmzDate are returned unchanged.
func (ctx *verificationCtx) buildSignedString(canonicalRequest string) (string, error) {
	scope := ctx.AuthValue.Date + "/" + ctx.AuthValue.Region + "/" + ctx.AuthValue.Service + "/" + v4scopeTerminator
	amzDate, err := ctx.getAmzDate()
	if err != nil {
		return "", err
	}
	digest := sha256.Sum256([]byte(canonicalRequest))
	return V4authHeaderPrefix + "\n" + amzDate + "\n" + scope + "\n" + hex.EncodeToString(digest[:]), nil
}
   328  
   // isStreaming reports whether the payload hash declares a chunk-signed streaming
// upload (compared case-insensitively).
func (ctx *verificationCtx) isStreaming() bool {
	return strings.EqualFold(ctx.payloadHash(), v4StreamingPayloadHash)
}
   333  
   // isUnsigned reports whether the payload hash is UNSIGNED-PAYLOAD (compared
// case-insensitively), i.e. the body is not covered by the signature.
func (ctx *verificationCtx) isUnsigned() bool {
	hash := ctx.payloadHash()
	return strings.EqualFold(hash, v4UnsignedPayload)
}
   337  
   338  func (ctx *verificationCtx) contentLength() (int64, error) {
   339  	size := ctx.Request.ContentLength
   340  	if ctx.isStreaming() {
   341  		if sizeStr, ok := ctx.Request.Header[AmzDecodedContentLength]; ok {
   342  			if sizeStr[0] == "" {
   343  				return 0, errors.ErrMissingContentLength
   344  			}
   345  			var err error
   346  			size, err = strconv.ParseInt(sizeStr[0], 10, 64) //nolint: mnd
   347  			if err != nil {
   348  				return 0, err
   349  			}
   350  		}
   351  	}
   352  	return size, nil
   353  }
   354  
   // reader wraps the request body according to the payload mode declared in
// x-amz-content-sha256:
//   - streaming uploads go through newSignV4ChunkedReader, fed the request
//     timestamp, the signature parameters and creds;
//   - UNSIGNED-PAYLOAD bodies are returned unchanged;
//   - any other value is passed to NewSha265Reader as the expected payload hash
//     (presumably checked against the body's SHA-256 while reading).
func (ctx *verificationCtx) reader(reader io.ReadCloser, creds *model.Credential) (io.ReadCloser, error) {
	switch {
	case ctx.isStreaming():
		amzDate, err := ctx.getAmzDate()
		if err != nil {
			return nil, err
		}
		chunked, err := newSignV4ChunkedReader(bufio.NewReader(reader), amzDate, ctx.AuthValue, creds)
		if err != nil {
			// Return a literal nil so callers never see a non-nil interface wrapping a nil pointer.
			return nil, err
		}
		return chunked, nil
	case ctx.isUnsigned():
		return reader, nil
	default:
		return NewSha265Reader(reader, ctx.payloadHash())
	}
}
   373  
   // V4Authenticator authenticates AWS Signature Version 4 requests. Parse extracts
   // the signature parameters from the request, and Verify checks them against a credential.
   374  type V4Authenticator struct {
   375  	request *http.Request // request being authenticated
   376  	sigCtx  V4Auth        // set by a successful Parse, consumed by Verify
   377  }
   378  
   379  func (a *V4Authenticator) Parse() (SigContext, error) {
   380  	sigCtx, err := ParseV4AuthContext(a.request)
   381  	if err != nil {
   382  		return nil, err
   383  	}
   384  	a.sigCtx = sigCtx
   385  	return a.sigCtx, nil
   386  }
   387  
   388  func (a *V4Authenticator) String() string {
   389  	return "sigv4"
   390  }
   391  
   392  func (a *V4Authenticator) Verify(creds *model.Credential) error {
   393  	return V4Verify(a.sigCtx, creds, a.request)
   394  }
   395  
   // NewV4Authenticator returns a SigV4 authenticator for r. Its stored signature
// parameters start as the zero V4Auth until Parse fills them in.
func NewV4Authenticator(r *http.Request) *V4Authenticator {
	return &V4Authenticator{request: r}
}