github.com/weaviate/weaviate@v1.24.6/adapters/handlers/rest/handlers_graphql.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 package rest 13 14 import ( 15 "context" 16 "encoding/json" 17 "fmt" 18 "strconv" 19 "strings" 20 "sync" 21 22 "github.com/sirupsen/logrus" 23 "github.com/weaviate/weaviate/usecases/auth/authorization/errors" 24 "github.com/weaviate/weaviate/usecases/monitoring" 25 "github.com/weaviate/weaviate/usecases/schema" 26 27 middleware "github.com/go-openapi/runtime/middleware" 28 tailorincgraphql "github.com/tailor-inc/graphql" 29 "github.com/tailor-inc/graphql/gqlerrors" 30 libgraphql "github.com/weaviate/weaviate/adapters/handlers/graphql" 31 "github.com/weaviate/weaviate/adapters/handlers/rest/operations" 32 "github.com/weaviate/weaviate/adapters/handlers/rest/operations/graphql" 33 enterrors "github.com/weaviate/weaviate/entities/errors" 34 "github.com/weaviate/weaviate/entities/models" 35 ) 36 37 const error422 string = "The request is well-formed but was unable to be followed due to semantic errors." 38 39 type gqlUnbatchedRequestResponse struct { 40 RequestIndex int 41 Response *models.GraphQLResponse 42 } 43 44 type graphQLProvider interface { 45 GetGraphQL() libgraphql.GraphQL 46 } 47 48 func setupGraphQLHandlers( 49 api *operations.WeaviateAPI, 50 gqlProvider graphQLProvider, 51 m *schema.Manager, 52 disabled bool, 53 metrics *monitoring.PrometheusMetrics, 54 logger logrus.FieldLogger, 55 ) { 56 metricRequestsTotal := newGraphqlRequestsTotal(metrics, logger) 57 api.GraphqlGraphqlPostHandler = graphql.GraphqlPostHandlerFunc(func(params graphql.GraphqlPostParams, principal *models.Principal) middleware.Responder { 58 // All requests to the graphQL API need at least permissions to read the schema. Request might have further 59 // authorization requirements. 60 61 err := m.Authorizer.Authorize(principal, "list", "schema/*") 62 if err != nil { 63 metricRequestsTotal.logUserError() 64 switch err.(type) { 65 case errors.Forbidden: 66 return graphql.NewGraphqlPostForbidden(). 67 WithPayload(errPayloadFromSingleErr(err)) 68 default: 69 return graphql.NewGraphqlPostUnprocessableEntity(). 70 WithPayload(errPayloadFromSingleErr(err)) 71 } 72 } 73 74 if disabled { 75 metricRequestsTotal.logUserError() 76 err := fmt.Errorf("graphql api is disabled") 77 return graphql.NewGraphqlPostUnprocessableEntity(). 78 WithPayload(errPayloadFromSingleErr(err)) 79 } 80 81 errorResponse := &models.ErrorResponse{} 82 83 // Get all input from the body of the request, as it is a POST. 84 query := params.Body.Query 85 operationName := params.Body.OperationName 86 87 // If query is empty, the request is unprocessable 88 if query == "" { 89 metricRequestsTotal.logUserError() 90 errorResponse.Error = []*models.ErrorResponseErrorItems0{ 91 { 92 Message: "query cannot be empty", 93 }, 94 } 95 return graphql.NewGraphqlPostUnprocessableEntity().WithPayload(errorResponse) 96 } 97 98 // Only set variables if exists in request 99 var variables map[string]interface{} 100 if params.Body.Variables != nil { 101 variables = params.Body.Variables.(map[string]interface{}) 102 } 103 104 graphQL := gqlProvider.GetGraphQL() 105 if graphQL == nil { 106 metricRequestsTotal.logUserError() 107 errorResponse.Error = []*models.ErrorResponseErrorItems0{ 108 { 109 Message: "no graphql provider present, this is most likely because no schema is present. Import a schema first!", 110 }, 111 } 112 return graphql.NewGraphqlPostUnprocessableEntity().WithPayload(errorResponse) 113 } 114 115 ctx := params.HTTPRequest.Context() 116 ctx = context.WithValue(ctx, "principal", principal) 117 118 result := graphQL.Resolve(ctx, query, 119 operationName, variables) 120 121 // Marshal the JSON 122 resultJSON, jsonErr := json.Marshal(result) 123 if jsonErr != nil { 124 metricRequestsTotal.logUserError() 125 errorResponse.Error = []*models.ErrorResponseErrorItems0{ 126 { 127 Message: fmt.Sprintf("couldn't marshal json: %s", jsonErr), 128 }, 129 } 130 return graphql.NewGraphqlPostUnprocessableEntity().WithPayload(errorResponse) 131 } 132 133 // Put the data in a response ready object 134 graphQLResponse := &models.GraphQLResponse{} 135 marshallErr := json.Unmarshal(resultJSON, graphQLResponse) 136 137 // If json gave error, return nothing. 138 if marshallErr != nil { 139 metricRequestsTotal.logUserError() 140 errorResponse.Error = []*models.ErrorResponseErrorItems0{ 141 { 142 Message: fmt.Sprintf("couldn't unmarshal json: %s\noriginal result was %#v", marshallErr, result), 143 }, 144 } 145 return graphql.NewGraphqlPostUnprocessableEntity().WithPayload(errorResponse) 146 } 147 148 metricRequestsTotal.log(result) 149 // Return the response 150 return graphql.NewGraphqlPostOK().WithPayload(graphQLResponse) 151 }) 152 153 api.GraphqlGraphqlBatchHandler = graphql.GraphqlBatchHandlerFunc(func(params graphql.GraphqlBatchParams, principal *models.Principal) middleware.Responder { 154 amountOfBatchedRequests := len(params.Body) 155 errorResponse := &models.ErrorResponse{} 156 157 if amountOfBatchedRequests == 0 { 158 metricRequestsTotal.logUserError() 159 return graphql.NewGraphqlBatchUnprocessableEntity().WithPayload(errorResponse) 160 } 161 requestResults := make(chan gqlUnbatchedRequestResponse, amountOfBatchedRequests) 162 163 wg := new(sync.WaitGroup) 164 165 ctx := params.HTTPRequest.Context() 166 ctx = context.WithValue(ctx, "principal", principal) 167 168 graphQL := gqlProvider.GetGraphQL() 169 if graphQL == nil { 170 metricRequestsTotal.logUserError() 171 errRes := errPayloadFromSingleErr(fmt.Errorf("no graphql provider present, " + 172 "this is most likely because no schema is present. Import a schema first!")) 173 return graphql.NewGraphqlBatchUnprocessableEntity().WithPayload(errRes) 174 } 175 176 // Generate a goroutine for each separate request 177 for requestIndex, unbatchedRequest := range params.Body { 178 requestIndex, unbatchedRequest := requestIndex, unbatchedRequest 179 wg.Add(1) 180 enterrors.GoWrapper(func() { 181 handleUnbatchedGraphQLRequest(ctx, wg, graphQL, unbatchedRequest, requestIndex, &requestResults, metricRequestsTotal) 182 }, logger) 183 } 184 185 wg.Wait() 186 187 close(requestResults) 188 189 batchedRequestResponse := make([]*models.GraphQLResponse, amountOfBatchedRequests) 190 191 // Add the requests to the result array in the correct order 192 for unbatchedRequestResult := range requestResults { 193 batchedRequestResponse[unbatchedRequestResult.RequestIndex] = unbatchedRequestResult.Response 194 } 195 196 return graphql.NewGraphqlBatchOK().WithPayload(batchedRequestResponse) 197 }) 198 } 199 200 // Handle a single unbatched GraphQL request, return a tuple containing the index of the request in the batch and either the response or an error 201 func handleUnbatchedGraphQLRequest(ctx context.Context, wg *sync.WaitGroup, graphQL libgraphql.GraphQL, unbatchedRequest *models.GraphQLQuery, requestIndex int, requestResults *chan gqlUnbatchedRequestResponse, metricRequestsTotal *graphqlRequestsTotal) { 202 defer wg.Done() 203 204 // Get all input from the body of the request 205 query := unbatchedRequest.Query 206 operationName := unbatchedRequest.OperationName 207 graphQLResponse := &models.GraphQLResponse{} 208 209 // Return an unprocessable error if the query is empty 210 if query == "" { 211 metricRequestsTotal.logUserError() 212 // Regular error messages are returned as an error code in the request header, but that doesn't work for batched requests 213 errorCode := strconv.Itoa(graphql.GraphqlBatchUnprocessableEntityCode) 214 errorMessage := fmt.Sprintf("%s: %s", errorCode, error422) 215 errors := []*models.GraphQLError{{Message: errorMessage}} 216 graphQLResponse := models.GraphQLResponse{Data: nil, Errors: errors} 217 *requestResults <- gqlUnbatchedRequestResponse{ 218 requestIndex, 219 &graphQLResponse, 220 } 221 } else { 222 // Extract any variables from the request 223 var variables map[string]interface{} 224 if unbatchedRequest.Variables != nil { 225 var ok bool 226 variables, ok = unbatchedRequest.Variables.(map[string]interface{}) 227 if !ok { 228 errorCode := strconv.Itoa(graphql.GraphqlBatchUnprocessableEntityCode) 229 errorMessage := fmt.Sprintf("%s: %s", errorCode, fmt.Sprintf("expected map[string]interface{}, received %v", unbatchedRequest.Variables)) 230 231 error := []*models.GraphQLError{{Message: errorMessage}} 232 graphQLResponse := models.GraphQLResponse{Data: nil, Errors: error} 233 *requestResults <- gqlUnbatchedRequestResponse{ 234 requestIndex, 235 &graphQLResponse, 236 } 237 return 238 } 239 } 240 241 result := graphQL.Resolve(ctx, query, operationName, variables) 242 243 // Marshal the JSON 244 resultJSON, jsonErr := json.Marshal(result) 245 246 // Return an unprocessable error if marshalling the result to JSON failed 247 if jsonErr != nil { 248 metricRequestsTotal.logUserError() 249 // Regular error messages are returned as an error code in the request header, but that doesn't work for batched requests 250 errorCode := strconv.Itoa(graphql.GraphqlBatchUnprocessableEntityCode) 251 errorMessage := fmt.Sprintf("%s: %s", errorCode, error422) 252 errors := []*models.GraphQLError{{Message: errorMessage}} 253 graphQLResponse := models.GraphQLResponse{Data: nil, Errors: errors} 254 *requestResults <- gqlUnbatchedRequestResponse{ 255 requestIndex, 256 &graphQLResponse, 257 } 258 } else { 259 // Put the result data in a response ready object 260 marshallErr := json.Unmarshal(resultJSON, graphQLResponse) 261 262 // Return an unprocessable error if unmarshalling the result to JSON failed 263 if marshallErr != nil { 264 metricRequestsTotal.logUserError() 265 // Regular error messages are returned as an error code in the request header, but that doesn't work for batched requests 266 errorCode := strconv.Itoa(graphql.GraphqlBatchUnprocessableEntityCode) 267 errorMessage := fmt.Sprintf("%s: %s", errorCode, error422) 268 errors := []*models.GraphQLError{{Message: errorMessage}} 269 graphQLResponse := models.GraphQLResponse{Data: nil, Errors: errors} 270 *requestResults <- gqlUnbatchedRequestResponse{ 271 requestIndex, 272 &graphQLResponse, 273 } 274 } else { 275 metricRequestsTotal.log(result) 276 // Return the GraphQL response 277 *requestResults <- gqlUnbatchedRequestResponse{ 278 requestIndex, 279 graphQLResponse, 280 } 281 } 282 } 283 } 284 } 285 286 type graphqlRequestsTotal struct { 287 metrics *requestsTotalMetric 288 logger logrus.FieldLogger 289 } 290 291 func newGraphqlRequestsTotal(metrics *monitoring.PrometheusMetrics, logger logrus.FieldLogger) *graphqlRequestsTotal { 292 return &graphqlRequestsTotal{newRequestsTotalMetric(metrics, "graphql"), logger} 293 } 294 295 func (e *graphqlRequestsTotal) getQueryType(path []interface{}) string { 296 if len(path) > 0 { 297 return fmt.Sprintf("%v", path[0]) 298 } 299 return "" 300 } 301 302 func (e *graphqlRequestsTotal) getClassName(path []interface{}) string { 303 if len(path) > 1 { 304 return fmt.Sprintf("%v", path[1]) 305 } 306 return "" 307 } 308 309 func (e *graphqlRequestsTotal) getErrGraphQLUser(gqlError gqlerrors.FormattedError) (bool, *enterrors.ErrGraphQLUser) { 310 if gqlError.OriginalError() != nil { 311 if gqlOriginalErr, ok := gqlError.OriginalError().(*gqlerrors.Error); ok { 312 if gqlOriginalErr.OriginalError != nil { 313 switch err := gqlOriginalErr.OriginalError.(type) { 314 case enterrors.ErrGraphQLUser: 315 return e.getError(err) 316 default: 317 if gqlFormatted, ok := gqlOriginalErr.OriginalError.(gqlerrors.FormattedError); ok { 318 if gqlFormatted.OriginalError() != nil { 319 return e.getError(gqlFormatted.OriginalError()) 320 } 321 } 322 } 323 } 324 } 325 } 326 return false, nil 327 } 328 329 func (e *graphqlRequestsTotal) isSyntaxRelatedError(gqlError gqlerrors.FormattedError) bool { 330 for _, prefix := range []string{"Syntax Error ", "Cannot query field"} { 331 if strings.HasPrefix(gqlError.Message, prefix) { 332 return true 333 } 334 } 335 return false 336 } 337 338 func (e *graphqlRequestsTotal) getError(err error) (bool, *enterrors.ErrGraphQLUser) { 339 switch e := err.(type) { 340 case enterrors.ErrGraphQLUser: 341 return true, &e 342 default: 343 return false, nil 344 } 345 } 346 347 func (e *graphqlRequestsTotal) log(result *tailorincgraphql.Result) { 348 if len(result.Errors) > 0 { 349 for _, gqlErr := range result.Errors { 350 if isUserError, err := e.getErrGraphQLUser(gqlErr); isUserError { 351 if e.metrics != nil { 352 e.metrics.RequestsTotalInc(UserError, err.ClassName(), err.QueryType()) 353 } 354 } else if e.isSyntaxRelatedError(gqlErr) { 355 if e.metrics != nil { 356 e.metrics.RequestsTotalInc(UserError, "", "") 357 } 358 } else { 359 e.logServerError(gqlErr, e.getClassName(gqlErr.Path), e.getQueryType(gqlErr.Path)) 360 } 361 } 362 } else if result.Data != nil { 363 e.logOk(result.Data) 364 } 365 } 366 367 func (e *graphqlRequestsTotal) logServerError(err error, className, queryType string) { 368 e.logger.WithFields(logrus.Fields{ 369 "action": "requests_total", 370 "api": "graphql", 371 "query_type": queryType, 372 "class_name": className, 373 }).WithError(err).Error("unexpected error") 374 if e.metrics != nil { 375 e.metrics.RequestsTotalInc(ServerError, className, queryType) 376 } 377 } 378 379 func (e *graphqlRequestsTotal) logUserError() { 380 if e.metrics != nil { 381 e.metrics.RequestsTotalInc(UserError, "", "") 382 } 383 } 384 385 func (e *graphqlRequestsTotal) logOk(data interface{}) { 386 if e.metrics != nil { 387 className, queryType := e.getClassNameAndQueryType(data) 388 e.metrics.RequestsTotalInc(Ok, className, queryType) 389 } 390 } 391 392 func (e *graphqlRequestsTotal) getClassNameAndQueryType(data interface{}) (className, queryType string) { 393 dataMap, ok := data.(map[string]interface{}) 394 if ok { 395 for query, value := range dataMap { 396 queryType = query 397 if queryType == "Explore" { 398 // Explore queries are cross class queries, we won't get a className in this case 399 // there's no sense in further value investigation 400 return 401 } 402 if value != nil { 403 if valueMap, ok := value.(map[string]interface{}); ok { 404 for class := range valueMap { 405 className = class 406 return 407 } 408 } 409 } 410 } 411 } 412 return 413 }