github.com/weaviate/weaviate@v1.24.6/adapters/handlers/rest/handlers_graphql.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package rest
    13  
    14  import (
    15  	"context"
    16  	"encoding/json"
    17  	"fmt"
    18  	"strconv"
    19  	"strings"
    20  	"sync"
    21  
    22  	"github.com/sirupsen/logrus"
    23  	"github.com/weaviate/weaviate/usecases/auth/authorization/errors"
    24  	"github.com/weaviate/weaviate/usecases/monitoring"
    25  	"github.com/weaviate/weaviate/usecases/schema"
    26  
    27  	middleware "github.com/go-openapi/runtime/middleware"
    28  	tailorincgraphql "github.com/tailor-inc/graphql"
    29  	"github.com/tailor-inc/graphql/gqlerrors"
    30  	libgraphql "github.com/weaviate/weaviate/adapters/handlers/graphql"
    31  	"github.com/weaviate/weaviate/adapters/handlers/rest/operations"
    32  	"github.com/weaviate/weaviate/adapters/handlers/rest/operations/graphql"
    33  	enterrors "github.com/weaviate/weaviate/entities/errors"
    34  	"github.com/weaviate/weaviate/entities/models"
    35  )
    36  
    37  const error422 string = "The request is well-formed but was unable to be followed due to semantic errors."
    38  
    39  type gqlUnbatchedRequestResponse struct {
    40  	RequestIndex int
    41  	Response     *models.GraphQLResponse
    42  }
    43  
    44  type graphQLProvider interface {
    45  	GetGraphQL() libgraphql.GraphQL
    46  }
    47  
    48  func setupGraphQLHandlers(
    49  	api *operations.WeaviateAPI,
    50  	gqlProvider graphQLProvider,
    51  	m *schema.Manager,
    52  	disabled bool,
    53  	metrics *monitoring.PrometheusMetrics,
    54  	logger logrus.FieldLogger,
    55  ) {
    56  	metricRequestsTotal := newGraphqlRequestsTotal(metrics, logger)
    57  	api.GraphqlGraphqlPostHandler = graphql.GraphqlPostHandlerFunc(func(params graphql.GraphqlPostParams, principal *models.Principal) middleware.Responder {
    58  		// All requests to the graphQL API need at least permissions to read the schema. Request might have further
    59  		// authorization requirements.
    60  
    61  		err := m.Authorizer.Authorize(principal, "list", "schema/*")
    62  		if err != nil {
    63  			metricRequestsTotal.logUserError()
    64  			switch err.(type) {
    65  			case errors.Forbidden:
    66  				return graphql.NewGraphqlPostForbidden().
    67  					WithPayload(errPayloadFromSingleErr(err))
    68  			default:
    69  				return graphql.NewGraphqlPostUnprocessableEntity().
    70  					WithPayload(errPayloadFromSingleErr(err))
    71  			}
    72  		}
    73  
    74  		if disabled {
    75  			metricRequestsTotal.logUserError()
    76  			err := fmt.Errorf("graphql api is disabled")
    77  			return graphql.NewGraphqlPostUnprocessableEntity().
    78  				WithPayload(errPayloadFromSingleErr(err))
    79  		}
    80  
    81  		errorResponse := &models.ErrorResponse{}
    82  
    83  		// Get all input from the body of the request, as it is a POST.
    84  		query := params.Body.Query
    85  		operationName := params.Body.OperationName
    86  
    87  		// If query is empty, the request is unprocessable
    88  		if query == "" {
    89  			metricRequestsTotal.logUserError()
    90  			errorResponse.Error = []*models.ErrorResponseErrorItems0{
    91  				{
    92  					Message: "query cannot be empty",
    93  				},
    94  			}
    95  			return graphql.NewGraphqlPostUnprocessableEntity().WithPayload(errorResponse)
    96  		}
    97  
    98  		// Only set variables if exists in request
    99  		var variables map[string]interface{}
   100  		if params.Body.Variables != nil {
   101  			variables = params.Body.Variables.(map[string]interface{})
   102  		}
   103  
   104  		graphQL := gqlProvider.GetGraphQL()
   105  		if graphQL == nil {
   106  			metricRequestsTotal.logUserError()
   107  			errorResponse.Error = []*models.ErrorResponseErrorItems0{
   108  				{
   109  					Message: "no graphql provider present, this is most likely because no schema is present. Import a schema first!",
   110  				},
   111  			}
   112  			return graphql.NewGraphqlPostUnprocessableEntity().WithPayload(errorResponse)
   113  		}
   114  
   115  		ctx := params.HTTPRequest.Context()
   116  		ctx = context.WithValue(ctx, "principal", principal)
   117  
   118  		result := graphQL.Resolve(ctx, query,
   119  			operationName, variables)
   120  
   121  		// Marshal the JSON
   122  		resultJSON, jsonErr := json.Marshal(result)
   123  		if jsonErr != nil {
   124  			metricRequestsTotal.logUserError()
   125  			errorResponse.Error = []*models.ErrorResponseErrorItems0{
   126  				{
   127  					Message: fmt.Sprintf("couldn't marshal json: %s", jsonErr),
   128  				},
   129  			}
   130  			return graphql.NewGraphqlPostUnprocessableEntity().WithPayload(errorResponse)
   131  		}
   132  
   133  		// Put the data in a response ready object
   134  		graphQLResponse := &models.GraphQLResponse{}
   135  		marshallErr := json.Unmarshal(resultJSON, graphQLResponse)
   136  
   137  		// If json gave error, return nothing.
   138  		if marshallErr != nil {
   139  			metricRequestsTotal.logUserError()
   140  			errorResponse.Error = []*models.ErrorResponseErrorItems0{
   141  				{
   142  					Message: fmt.Sprintf("couldn't unmarshal json: %s\noriginal result was %#v", marshallErr, result),
   143  				},
   144  			}
   145  			return graphql.NewGraphqlPostUnprocessableEntity().WithPayload(errorResponse)
   146  		}
   147  
   148  		metricRequestsTotal.log(result)
   149  		// Return the response
   150  		return graphql.NewGraphqlPostOK().WithPayload(graphQLResponse)
   151  	})
   152  
   153  	api.GraphqlGraphqlBatchHandler = graphql.GraphqlBatchHandlerFunc(func(params graphql.GraphqlBatchParams, principal *models.Principal) middleware.Responder {
   154  		amountOfBatchedRequests := len(params.Body)
   155  		errorResponse := &models.ErrorResponse{}
   156  
   157  		if amountOfBatchedRequests == 0 {
   158  			metricRequestsTotal.logUserError()
   159  			return graphql.NewGraphqlBatchUnprocessableEntity().WithPayload(errorResponse)
   160  		}
   161  		requestResults := make(chan gqlUnbatchedRequestResponse, amountOfBatchedRequests)
   162  
   163  		wg := new(sync.WaitGroup)
   164  
   165  		ctx := params.HTTPRequest.Context()
   166  		ctx = context.WithValue(ctx, "principal", principal)
   167  
   168  		graphQL := gqlProvider.GetGraphQL()
   169  		if graphQL == nil {
   170  			metricRequestsTotal.logUserError()
   171  			errRes := errPayloadFromSingleErr(fmt.Errorf("no graphql provider present, " +
   172  				"this is most likely because no schema is present. Import a schema first!"))
   173  			return graphql.NewGraphqlBatchUnprocessableEntity().WithPayload(errRes)
   174  		}
   175  
   176  		// Generate a goroutine for each separate request
   177  		for requestIndex, unbatchedRequest := range params.Body {
   178  			requestIndex, unbatchedRequest := requestIndex, unbatchedRequest
   179  			wg.Add(1)
   180  			enterrors.GoWrapper(func() {
   181  				handleUnbatchedGraphQLRequest(ctx, wg, graphQL, unbatchedRequest, requestIndex, &requestResults, metricRequestsTotal)
   182  			}, logger)
   183  		}
   184  
   185  		wg.Wait()
   186  
   187  		close(requestResults)
   188  
   189  		batchedRequestResponse := make([]*models.GraphQLResponse, amountOfBatchedRequests)
   190  
   191  		// Add the requests to the result array in the correct order
   192  		for unbatchedRequestResult := range requestResults {
   193  			batchedRequestResponse[unbatchedRequestResult.RequestIndex] = unbatchedRequestResult.Response
   194  		}
   195  
   196  		return graphql.NewGraphqlBatchOK().WithPayload(batchedRequestResponse)
   197  	})
   198  }
   199  
   200  // Handle a single unbatched GraphQL request, return a tuple containing the index of the request in the batch and either the response or an error
   201  func handleUnbatchedGraphQLRequest(ctx context.Context, wg *sync.WaitGroup, graphQL libgraphql.GraphQL, unbatchedRequest *models.GraphQLQuery, requestIndex int, requestResults *chan gqlUnbatchedRequestResponse, metricRequestsTotal *graphqlRequestsTotal) {
   202  	defer wg.Done()
   203  
   204  	// Get all input from the body of the request
   205  	query := unbatchedRequest.Query
   206  	operationName := unbatchedRequest.OperationName
   207  	graphQLResponse := &models.GraphQLResponse{}
   208  
   209  	// Return an unprocessable error if the query is empty
   210  	if query == "" {
   211  		metricRequestsTotal.logUserError()
   212  		// Regular error messages are returned as an error code in the request header, but that doesn't work for batched requests
   213  		errorCode := strconv.Itoa(graphql.GraphqlBatchUnprocessableEntityCode)
   214  		errorMessage := fmt.Sprintf("%s: %s", errorCode, error422)
   215  		errors := []*models.GraphQLError{{Message: errorMessage}}
   216  		graphQLResponse := models.GraphQLResponse{Data: nil, Errors: errors}
   217  		*requestResults <- gqlUnbatchedRequestResponse{
   218  			requestIndex,
   219  			&graphQLResponse,
   220  		}
   221  	} else {
   222  		// Extract any variables from the request
   223  		var variables map[string]interface{}
   224  		if unbatchedRequest.Variables != nil {
   225  			var ok bool
   226  			variables, ok = unbatchedRequest.Variables.(map[string]interface{})
   227  			if !ok {
   228  				errorCode := strconv.Itoa(graphql.GraphqlBatchUnprocessableEntityCode)
   229  				errorMessage := fmt.Sprintf("%s: %s", errorCode, fmt.Sprintf("expected map[string]interface{}, received %v", unbatchedRequest.Variables))
   230  
   231  				error := []*models.GraphQLError{{Message: errorMessage}}
   232  				graphQLResponse := models.GraphQLResponse{Data: nil, Errors: error}
   233  				*requestResults <- gqlUnbatchedRequestResponse{
   234  					requestIndex,
   235  					&graphQLResponse,
   236  				}
   237  				return
   238  			}
   239  		}
   240  
   241  		result := graphQL.Resolve(ctx, query, operationName, variables)
   242  
   243  		// Marshal the JSON
   244  		resultJSON, jsonErr := json.Marshal(result)
   245  
   246  		// Return an unprocessable error if marshalling the result to JSON failed
   247  		if jsonErr != nil {
   248  			metricRequestsTotal.logUserError()
   249  			// Regular error messages are returned as an error code in the request header, but that doesn't work for batched requests
   250  			errorCode := strconv.Itoa(graphql.GraphqlBatchUnprocessableEntityCode)
   251  			errorMessage := fmt.Sprintf("%s: %s", errorCode, error422)
   252  			errors := []*models.GraphQLError{{Message: errorMessage}}
   253  			graphQLResponse := models.GraphQLResponse{Data: nil, Errors: errors}
   254  			*requestResults <- gqlUnbatchedRequestResponse{
   255  				requestIndex,
   256  				&graphQLResponse,
   257  			}
   258  		} else {
   259  			// Put the result data in a response ready object
   260  			marshallErr := json.Unmarshal(resultJSON, graphQLResponse)
   261  
   262  			// Return an unprocessable error if unmarshalling the result to JSON failed
   263  			if marshallErr != nil {
   264  				metricRequestsTotal.logUserError()
   265  				// Regular error messages are returned as an error code in the request header, but that doesn't work for batched requests
   266  				errorCode := strconv.Itoa(graphql.GraphqlBatchUnprocessableEntityCode)
   267  				errorMessage := fmt.Sprintf("%s: %s", errorCode, error422)
   268  				errors := []*models.GraphQLError{{Message: errorMessage}}
   269  				graphQLResponse := models.GraphQLResponse{Data: nil, Errors: errors}
   270  				*requestResults <- gqlUnbatchedRequestResponse{
   271  					requestIndex,
   272  					&graphQLResponse,
   273  				}
   274  			} else {
   275  				metricRequestsTotal.log(result)
   276  				// Return the GraphQL response
   277  				*requestResults <- gqlUnbatchedRequestResponse{
   278  					requestIndex,
   279  					graphQLResponse,
   280  				}
   281  			}
   282  		}
   283  	}
   284  }
   285  
   286  type graphqlRequestsTotal struct {
   287  	metrics *requestsTotalMetric
   288  	logger  logrus.FieldLogger
   289  }
   290  
   291  func newGraphqlRequestsTotal(metrics *monitoring.PrometheusMetrics, logger logrus.FieldLogger) *graphqlRequestsTotal {
   292  	return &graphqlRequestsTotal{newRequestsTotalMetric(metrics, "graphql"), logger}
   293  }
   294  
   295  func (e *graphqlRequestsTotal) getQueryType(path []interface{}) string {
   296  	if len(path) > 0 {
   297  		return fmt.Sprintf("%v", path[0])
   298  	}
   299  	return ""
   300  }
   301  
   302  func (e *graphqlRequestsTotal) getClassName(path []interface{}) string {
   303  	if len(path) > 1 {
   304  		return fmt.Sprintf("%v", path[1])
   305  	}
   306  	return ""
   307  }
   308  
   309  func (e *graphqlRequestsTotal) getErrGraphQLUser(gqlError gqlerrors.FormattedError) (bool, *enterrors.ErrGraphQLUser) {
   310  	if gqlError.OriginalError() != nil {
   311  		if gqlOriginalErr, ok := gqlError.OriginalError().(*gqlerrors.Error); ok {
   312  			if gqlOriginalErr.OriginalError != nil {
   313  				switch err := gqlOriginalErr.OriginalError.(type) {
   314  				case enterrors.ErrGraphQLUser:
   315  					return e.getError(err)
   316  				default:
   317  					if gqlFormatted, ok := gqlOriginalErr.OriginalError.(gqlerrors.FormattedError); ok {
   318  						if gqlFormatted.OriginalError() != nil {
   319  							return e.getError(gqlFormatted.OriginalError())
   320  						}
   321  					}
   322  				}
   323  			}
   324  		}
   325  	}
   326  	return false, nil
   327  }
   328  
   329  func (e *graphqlRequestsTotal) isSyntaxRelatedError(gqlError gqlerrors.FormattedError) bool {
   330  	for _, prefix := range []string{"Syntax Error ", "Cannot query field"} {
   331  		if strings.HasPrefix(gqlError.Message, prefix) {
   332  			return true
   333  		}
   334  	}
   335  	return false
   336  }
   337  
   338  func (e *graphqlRequestsTotal) getError(err error) (bool, *enterrors.ErrGraphQLUser) {
   339  	switch e := err.(type) {
   340  	case enterrors.ErrGraphQLUser:
   341  		return true, &e
   342  	default:
   343  		return false, nil
   344  	}
   345  }
   346  
   347  func (e *graphqlRequestsTotal) log(result *tailorincgraphql.Result) {
   348  	if len(result.Errors) > 0 {
   349  		for _, gqlErr := range result.Errors {
   350  			if isUserError, err := e.getErrGraphQLUser(gqlErr); isUserError {
   351  				if e.metrics != nil {
   352  					e.metrics.RequestsTotalInc(UserError, err.ClassName(), err.QueryType())
   353  				}
   354  			} else if e.isSyntaxRelatedError(gqlErr) {
   355  				if e.metrics != nil {
   356  					e.metrics.RequestsTotalInc(UserError, "", "")
   357  				}
   358  			} else {
   359  				e.logServerError(gqlErr, e.getClassName(gqlErr.Path), e.getQueryType(gqlErr.Path))
   360  			}
   361  		}
   362  	} else if result.Data != nil {
   363  		e.logOk(result.Data)
   364  	}
   365  }
   366  
   367  func (e *graphqlRequestsTotal) logServerError(err error, className, queryType string) {
   368  	e.logger.WithFields(logrus.Fields{
   369  		"action":     "requests_total",
   370  		"api":        "graphql",
   371  		"query_type": queryType,
   372  		"class_name": className,
   373  	}).WithError(err).Error("unexpected error")
   374  	if e.metrics != nil {
   375  		e.metrics.RequestsTotalInc(ServerError, className, queryType)
   376  	}
   377  }
   378  
   379  func (e *graphqlRequestsTotal) logUserError() {
   380  	if e.metrics != nil {
   381  		e.metrics.RequestsTotalInc(UserError, "", "")
   382  	}
   383  }
   384  
   385  func (e *graphqlRequestsTotal) logOk(data interface{}) {
   386  	if e.metrics != nil {
   387  		className, queryType := e.getClassNameAndQueryType(data)
   388  		e.metrics.RequestsTotalInc(Ok, className, queryType)
   389  	}
   390  }
   391  
   392  func (e *graphqlRequestsTotal) getClassNameAndQueryType(data interface{}) (className, queryType string) {
   393  	dataMap, ok := data.(map[string]interface{})
   394  	if ok {
   395  		for query, value := range dataMap {
   396  			queryType = query
   397  			if queryType == "Explore" {
   398  				// Explore queries are cross class queries, we won't get a className in this case
   399  				// there's no sense in further value investigation
   400  				return
   401  			}
   402  			if value != nil {
   403  				if valueMap, ok := value.(map[string]interface{}); ok {
   404  					for class := range valueMap {
   405  						className = class
   406  						return
   407  					}
   408  				}
   409  			}
   410  		}
   411  	}
   412  	return
   413  }