github.com/weaviate/weaviate@v1.24.6/adapters/handlers/rest/handlers_classification.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package rest
    13  
    14  import (
    15  	middleware "github.com/go-openapi/runtime/middleware"
    16  	"github.com/go-openapi/strfmt"
    17  	"github.com/sirupsen/logrus"
    18  	"github.com/weaviate/weaviate/adapters/handlers/rest/operations"
    19  	"github.com/weaviate/weaviate/adapters/handlers/rest/operations/classifications"
    20  	"github.com/weaviate/weaviate/entities/models"
    21  	"github.com/weaviate/weaviate/usecases/classification"
    22  	"github.com/weaviate/weaviate/usecases/monitoring"
    23  )
    24  
    25  func setupClassificationHandlers(api *operations.WeaviateAPI,
    26  	classifier *classification.Classifier, metrics *monitoring.PrometheusMetrics, logger logrus.FieldLogger,
    27  ) {
    28  	metricRequestsTotal := newClassificationRequestsTotal(metrics, logger)
    29  	api.ClassificationsClassificationsGetHandler = classifications.ClassificationsGetHandlerFunc(
    30  		func(params classifications.ClassificationsGetParams, principal *models.Principal) middleware.Responder {
    31  			res, err := classifier.Get(params.HTTPRequest.Context(), principal, strfmt.UUID(params.ID))
    32  			if err != nil {
    33  				metricRequestsTotal.logError("", err)
    34  				return classifications.NewClassificationsGetInternalServerError().WithPayload(errPayloadFromSingleErr(err))
    35  			}
    36  
    37  			if res == nil {
    38  				metricRequestsTotal.logUserError("")
    39  				return classifications.NewClassificationsGetNotFound()
    40  			}
    41  
    42  			metricRequestsTotal.logOk("")
    43  			return classifications.NewClassificationsGetOK().WithPayload(res)
    44  		},
    45  	)
    46  
    47  	api.ClassificationsClassificationsPostHandler = classifications.ClassificationsPostHandlerFunc(
    48  		func(params classifications.ClassificationsPostParams, principal *models.Principal) middleware.Responder {
    49  			res, err := classifier.Schedule(params.HTTPRequest.Context(), principal, *params.Params)
    50  			if err != nil {
    51  				metricRequestsTotal.logUserError("")
    52  				return classifications.NewClassificationsPostBadRequest().WithPayload(errPayloadFromSingleErr(err))
    53  			}
    54  
    55  			metricRequestsTotal.logOk("")
    56  			return classifications.NewClassificationsPostCreated().WithPayload(res)
    57  		},
    58  	)
    59  }
    60  
// classificationRequestsTotal is the requests-total metric helper used by the
// classification REST handlers. It embeds *restApiRequestsTotalImpl, whose
// methods are promoted onto this type; methods it calls but does not define
// itself, such as logServerError, are expected to come from that embedded
// value. It overrides logError to decide how failed classification requests
// are counted.
type classificationRequestsTotal struct {
	*restApiRequestsTotalImpl
}
    64  
    65  func newClassificationRequestsTotal(metrics *monitoring.PrometheusMetrics, logger logrus.FieldLogger) restApiRequestsTotal {
    66  	return &classificationRequestsTotal{
    67  		restApiRequestsTotalImpl: &restApiRequestsTotalImpl{newRequestsTotalMetric(metrics, "rest"), "rest", "classification", logger},
    68  	}
    69  }
    70  
    71  func (e *classificationRequestsTotal) logError(className string, err error) {
    72  	switch err.(type) {
    73  	default:
    74  		e.logServerError(className, err)
    75  	}
    76  }