github.com/kyma-incubator/compass/components/director@v0.0.0-20230623144113-d764f56ff805/pkg/correlation/middleware.go (about)

     1  /*
     2   * Copyright 2020 The Compass Authors
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   *     http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   */
    16  
    17  package correlation
    18  
    19  import (
    20  	"context"
    21  	"net/http"
    22  
    23  	"github.com/sirupsen/logrus"
    24  
    25  	"github.com/google/uuid"
    26  )
    27  
    28  type contextKey string
    29  
    30  // HeadersContextKey missing godoc
    31  const HeadersContextKey contextKey = "CorrelationHeaders"
    32  
    33  // RequestIDHeaderKey missing godoc
    34  const RequestIDHeaderKey = "x-request-id"
    35  
    36  // headerKeys are the expected headers that are used for distributed tracing.
    37  var headerKeys = []string{"x-request-id", "x-b3-traceid", "x-b3-spanid", "x-b3-parentspanid", "x-b3-sampled", "x-b3-flags", "b3"}
    38  
    39  // Headers missing godoc
    40  type Headers map[string]string
    41  
    42  // CorrelationIDForRequest returns the correlation ID for the current request
    43  func CorrelationIDForRequest(request *http.Request) string {
    44  	return HeadersForRequest(request)[RequestIDHeaderKey]
    45  }
    46  
    47  // AttachCorrelationIDToContext returns middleware that attaches all headers used for tracing in the current request.
    48  func AttachCorrelationIDToContext() func(next http.Handler) http.Handler {
    49  	return func(next http.Handler) http.Handler {
    50  		return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
    51  			ctx := r.Context()
    52  			if correlationHeaders := HeadersForRequest(r); len(correlationHeaders) != 0 {
    53  				ctx = SaveToContext(ctx, correlationHeaders)
    54  				r = r.WithContext(ctx)
    55  			}
    56  
    57  			next.ServeHTTP(rw, r)
    58  		})
    59  	}
    60  }
    61  
    62  // HeadersForRequest returns all http headers used for tracing of the passed request.
    63  // If the request headers are not set, but are part of the context, they're set as headers as well.
    64  // If the x-request-id header does not exists a new one is generated, and set as a header.
    65  func HeadersForRequest(request *http.Request) Headers {
    66  	reqHeaders := make(map[string]string)
    67  	headersFromCtx := HeadersFromContext(request.Context())
    68  
    69  	for _, headerKey := range headerKeys {
    70  		headerValue := request.Header.Get(headerKey)
    71  		if headerValue != "" {
    72  			reqHeaders[headerKey] = headerValue
    73  			continue
    74  		}
    75  
    76  		if headerValue, ok := headersFromCtx[headerKey]; ok {
    77  			request.Header.Set(headerKey, headerValue)
    78  			reqHeaders[headerKey] = headerValue
    79  		}
    80  	}
    81  
    82  	// Context might have been enriched with additional headers (outside of those among the well known header keys array)
    83  	// which should be attached as well
    84  	for headerKey, headerValue := range headersFromCtx {
    85  		if _, ok := reqHeaders[headerKey]; !ok {
    86  			request.Header.Set(headerKey, headerValue)
    87  			reqHeaders[headerKey] = headerValue
    88  		}
    89  	}
    90  
    91  	if _, ok := reqHeaders[RequestIDHeaderKey]; !ok {
    92  		newRequestID := uuid.New().String()
    93  		reqHeaders[RequestIDHeaderKey] = newRequestID
    94  		request.Header.Set(RequestIDHeaderKey, newRequestID)
    95  	}
    96  
    97  	return reqHeaders
    98  }
    99  
   100  // HeadersFromContext returns the headers for the provided context
   101  func HeadersFromContext(ctx context.Context) Headers {
   102  	var headersFromCtx map[string]string
   103  	if ctx.Value(HeadersContextKey) != nil {
   104  		var ok bool
   105  		headersFromCtx, ok = ctx.Value(HeadersContextKey).(Headers)
   106  		if !ok {
   107  			logrus.Errorf("unexpected type of %s: %T, should be %T", HeadersContextKey, headersFromCtx, Headers{})
   108  		}
   109  	}
   110  
   111  	return headersFromCtx
   112  }
   113  
   114  // CorrelationIDFromContext returns correlation id from the given context
   115  func CorrelationIDFromContext(ctx context.Context) string {
   116  	return HeadersFromContext(ctx)[RequestIDHeaderKey]
   117  }
   118  
   119  // SaveToContext saves the provided headers as correlation ID headers in the specified context
   120  func SaveToContext(ctx context.Context, headers Headers) context.Context {
   121  	return context.WithValue(ctx, HeadersContextKey, headers)
   122  }
   123  
   124  // SaveCorrelationIDHeaderToContext saves the header provided key/value pair as a correlation ID header in the specified context
   125  func SaveCorrelationIDHeaderToContext(ctx context.Context, key, value *string) context.Context {
   126  	if key == nil || value == nil {
   127  		return ctx
   128  	}
   129  
   130  	headers := HeadersFromContext(ctx)
   131  	if headers == nil {
   132  		headers = make(map[string]string)
   133  	}
   134  
   135  	headers[*key] = *value
   136  
   137  	return SaveToContext(ctx, headers)
   138  }