github.com/weaviate/weaviate@v1.24.6/adapters/handlers/rest/middlewares.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package rest
    13  
    14  import (
    15  	"context"
    16  	"fmt"
    17  	"net/http"
    18  	"strings"
    19  	"time"
    20  
    21  	"github.com/prometheus/client_golang/prometheus"
    22  	"github.com/rs/cors"
    23  	"github.com/sirupsen/logrus"
    24  	"github.com/weaviate/weaviate/adapters/handlers/rest/state"
    25  	"github.com/weaviate/weaviate/adapters/handlers/rest/swagger_middleware"
    26  	"github.com/weaviate/weaviate/usecases/config"
    27  	"github.com/weaviate/weaviate/usecases/modules"
    28  	"github.com/weaviate/weaviate/usecases/monitoring"
    29  )
    30  
    31  // The middleware configuration is for the handler executors. These do not apply to the swagger.json document.
    32  // The middleware executes after routing but before authentication, binding and validation
    33  //
    34  // we are setting the middlewares from within configureAPI, as we need access
    35  // to some resources which are not exposed
    36  func makeSetupMiddlewares(appState *state.State) func(http.Handler) http.Handler {
    37  	return func(handler http.Handler) http.Handler {
    38  		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    39  			if r.URL.String() == "/v1/.well-known/openid-configuration" || r.URL.String() == "/v1" {
    40  				handler.ServeHTTP(w, r)
    41  				return
    42  			}
    43  			appState.AnonymousAccess.Middleware(handler).ServeHTTP(w, r)
    44  		})
    45  	}
    46  }
    47  
    48  func addHandleRoot(next http.Handler) http.Handler {
    49  	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    50  		if r.URL.String() == "/" {
    51  			w.Header().Add("Location", "/v1")
    52  			w.WriteHeader(http.StatusMovedPermanently)
    53  			w.Write([]byte(`{"links":{"href":"/v1","name":"api v1","documentationHref":` +
    54  				`"https://weaviate.io/developers/weaviate/current/"}}`))
    55  			return
    56  		}
    57  
    58  		next.ServeHTTP(w, r)
    59  	})
    60  }
    61  
    62  func makeAddModuleHandlers(modules *modules.Provider) func(http.Handler) http.Handler {
    63  	return func(next http.Handler) http.Handler {
    64  		mux := http.NewServeMux()
    65  
    66  		for _, mod := range modules.GetAll() {
    67  			prefix := fmt.Sprintf("/v1/modules/%s", mod.Name())
    68  			mux.Handle(fmt.Sprintf("%s/", prefix),
    69  				http.StripPrefix(prefix, mod.RootHandler()))
    70  		}
    71  
    72  		prefix := "/v1/modules"
    73  		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    74  			if url := r.URL.String(); len(url) > len(prefix) && url[:len(prefix)] == prefix {
    75  				mux.ServeHTTP(w, r)
    76  				return
    77  			}
    78  
    79  			next.ServeHTTP(w, r)
    80  		})
    81  	}
    82  }
    83  
    84  // The middleware configuration happens before anything, this middleware also applies to serving the swagger.json document.
    85  // So this is a good place to plug in a panic handling middleware, logging and metrics
    86  // Contains "x-api-key", "x-api-token" for legacy reasons, older interfaces might need these headers.
    87  func makeSetupGlobalMiddleware(appState *state.State) func(http.Handler) http.Handler {
    88  	return func(handler http.Handler) http.Handler {
    89  		handleCORS := cors.New(cors.Options{
    90  			OptionsPassthrough: true,
    91  			AllowedMethods:     strings.Split(appState.ServerConfig.Config.CORS.AllowMethods, ","),
    92  			AllowedHeaders:     strings.Split(appState.ServerConfig.Config.CORS.AllowHeaders, ","),
    93  			AllowedOrigins:     strings.Split(appState.ServerConfig.Config.CORS.AllowOrigin, ","),
    94  		}).Handler
    95  		handler = handleCORS(handler)
    96  		handler = swagger_middleware.AddMiddleware([]byte(SwaggerJSON), handler)
    97  		handler = makeAddLogging(appState.Logger)(handler)
    98  		if appState.ServerConfig.Config.Monitoring.Enabled {
    99  			handler = makeAddMonitoring(appState.Metrics)(handler)
   100  		}
   101  		handler = addPreflight(handler, appState.ServerConfig.Config.CORS)
   102  		handler = addLiveAndReadyness(appState, handler)
   103  		handler = addHandleRoot(handler)
   104  		handler = makeAddModuleHandlers(appState.Modules)(handler)
   105  		handler = addInjectHeadersIntoContext(handler)
   106  		handler = makeCatchPanics(appState.Logger,
   107  			newPanicsRequestsTotal(appState.Metrics, appState.Logger))(handler)
   108  
   109  		return handler
   110  	}
   111  }
   112  
   113  func makeAddLogging(logger logrus.FieldLogger) func(http.Handler) http.Handler {
   114  	return func(next http.Handler) http.Handler {
   115  		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   116  			logger.
   117  				WithField("action", "restapi_request").
   118  				WithField("method", r.Method).
   119  				WithField("url", r.URL).
   120  				Debug("received HTTP request")
   121  			next.ServeHTTP(w, r)
   122  		})
   123  	}
   124  }
   125  
   126  func makeAddMonitoring(metrics *monitoring.PrometheusMetrics) func(http.Handler) http.Handler {
   127  	return func(next http.Handler) http.Handler {
   128  		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   129  			before := time.Now()
   130  			method := r.Method
   131  			path := r.URL.Path
   132  			next.ServeHTTP(w, r)
   133  
   134  			if strings.HasPrefix(path, "/v1/batch/objects") && method == http.MethodPost {
   135  				metrics.BatchTime.With(prometheus.Labels{
   136  					"operation":  "total_api_level",
   137  					"class_name": "n/a",
   138  					"shard_name": "n/a",
   139  				}).
   140  					Observe(float64(time.Since(before) / time.Millisecond))
   141  			}
   142  		})
   143  	}
   144  }
   145  
   146  func addPreflight(next http.Handler, cfg config.CORS) http.Handler {
   147  	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   148  		w.Header().Set("Access-Control-Allow-Origin", cfg.AllowOrigin)
   149  		w.Header().Set("Access-Control-Allow-Methods", cfg.AllowMethods)
   150  		w.Header().Set("Access-Control-Allow-Headers", cfg.AllowHeaders)
   151  
   152  		if r.Method == "OPTIONS" {
   153  			return
   154  		}
   155  
   156  		next.ServeHTTP(w, r)
   157  	})
   158  }
   159  
   160  func addInjectHeadersIntoContext(next http.Handler) http.Handler {
   161  	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   162  		ctx := r.Context()
   163  		changed := false
   164  		for k, v := range r.Header {
   165  			if strings.HasPrefix(k, "X-") {
   166  				ctx = context.WithValue(ctx, k, v)
   167  				changed = true
   168  			}
   169  		}
   170  
   171  		if changed {
   172  			next.ServeHTTP(w, r.Clone(ctx))
   173  		} else {
   174  			next.ServeHTTP(w, r)
   175  		}
   176  	})
   177  }
   178  
   179  func addLiveAndReadyness(state *state.State, next http.Handler) http.Handler {
   180  	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   181  		if r.URL.String() == "/v1/.well-known/live" {
   182  			w.WriteHeader(http.StatusOK)
   183  			return
   184  		}
   185  
   186  		if r.URL.String() == "/v1/.well-known/ready" {
   187  			code := http.StatusServiceUnavailable
   188  			if state.DB.StartupComplete() && state.Cluster.ClusterHealthScore() == 0 {
   189  				code = http.StatusOK
   190  			}
   191  			w.WriteHeader(code)
   192  			return
   193  		}
   194  
   195  		next.ServeHTTP(w, r)
   196  	})
   197  }