github.com/viant/toolbox@v0.34.5/service_router.go (about)

     1  package toolbox
     2  
     3  import (
     4  	"bytes"
     5  	"errors"
     6  	"fmt"
     7  	"io/ioutil"
     8  	"net"
     9  	"net/http"
    10  	"reflect"
    11  	"strings"
    12  	"time"
    13  )
    14  
    15  const (
    16  	jsonContentType       = "application/json"
    17  	yamlContentTypeSuffix = "/yaml"
    18  	textPlainContentType  = "text/plain"
    19  	contentTypeHeader     = "Content-Type"
    20  )
    21  
    22  const (
    23  	//MethodGet HTTP GET meothd
    24  	MethodGet     = "GET"
    25  	MethodHead    = "HEAD"
    26  	MethodPost    = "POST"
    27  	MethodPut     = "PUT"
    28  	MethodPatch   = "PATCH" // RFC 5789
    29  	MethodDelete  = "DELETE"
    30  	MethodOptions = "OPTIONS"
    31  	MethodTrace   = "TRACE"
    32  )
    33  
    34  var httpMethods = map[string]bool{
    35  	MethodDelete:  true,
    36  	MethodGet:     true,
    37  	MethodPatch:   true,
    38  	MethodPost:    true,
    39  	MethodPut:     true,
    40  	MethodHead:    true,
    41  	MethodTrace:   true,
    42  	MethodOptions: true,
    43  }
    44  
    45  //HandlerInvoker method is responsible  of passing required parameters to router handler.
    46  type HandlerInvoker func(serviceRouting *ServiceRouting, request *http.Request, response http.ResponseWriter, parameters map[string]interface{}) error
    47  
    48  //DefaultEncoderFactory  - NewJSONEncoderFactory
    49  var DefaultEncoderFactory = NewJSONEncoderFactory()
    50  
    51  //DefaultDecoderFactory - NewJSONDecoderFactory
    52  var DefaultDecoderFactory = NewJSONDecoderFactory()
    53  
    54  //YamlDefaultEncoderFactory  - NewYamlEncoderFactory
    55  var YamlDefaultEncoderFactory = NewYamlEncoderFactory()
    56  
    57  //YamlDefaultDecoderFactory - NewYamlDecoderFactory
    58  var YamlDefaultDecoderFactory = NewFlexYamlDecoderFactory()
    59  
// ServiceRouting represents a simple web service routing rule; it matches an
// http request whose method equals HTTPMethod and whose RequestURI matches the
// URI template.
type ServiceRouting struct {
	// URI is the matching URI template; its parameters are extracted from
	// request.RequestURI by name (see ExtractURIParameters).
	URI string
	// Handler has to be a func; unless HandlerInvoker is set it is called via
	// reflection with arguments built from Parameters.
	Handler interface{}
	// HTTPMethod must equal request.Method for the rule to match.
	HTTPMethod string
	// Parameters names the handler arguments, positionally aligned with the
	// handler signature. Each name is resolved from URI parameters, then form
	// values; "@httpRequest"/"@httpResponseWriter" bind the request and writer.
	// When the request has a positive ContentLength, the first unresolved name
	// is decoded from the body (a "prefix:key" name decodes the value under key).
	Parameters []string
	// ContentTypeEncoders optionally maps a response content type to its encoder factory.
	ContentTypeEncoders map[string]EncoderFactory
	// ContentTypeDecoders optionally maps a request content type to its decoder factory.
	ContentTypeDecoders map[string]DecoderFactory
	// HandlerInvoker is an optional function used instead of reflection to invoke the handler.
	HandlerInvoker HandlerInvoker
}
    70  
    71  func (sr ServiceRouting) getDecoderFactory(contentType string) DecoderFactory {
    72  	if sr.ContentTypeDecoders != nil {
    73  		if factory, found := sr.ContentTypeDecoders[contentType]; found {
    74  			return factory
    75  		}
    76  	}
    77  	if strings.HasSuffix(contentType, yamlContentTypeSuffix) {
    78  		return YamlDefaultDecoderFactory
    79  	}
    80  	return DefaultDecoderFactory
    81  }
    82  
    83  func (sr ServiceRouting) getEncoderFactory(contentType string) EncoderFactory {
    84  	if sr.ContentTypeDecoders != nil {
    85  		if factory, found := sr.ContentTypeEncoders[contentType]; found {
    86  			return factory
    87  		}
    88  	}
    89  	if strings.HasSuffix(contentType, yamlContentTypeSuffix) {
    90  		return YamlDefaultEncoderFactory
    91  	}
    92  	return DefaultEncoderFactory
    93  }
    94  
    95  func (sr ServiceRouting) extractParameterFromBody(parameterName string, targetType reflect.Type, request *http.Request) (interface{}, error) {
    96  	targetValuePointer := reflect.New(targetType)
    97  	contentType := getContentTypeOrJSONContentType(request.Header.Get(contentTypeHeader))
    98  	decoderFactory := sr.getDecoderFactory(contentType)
    99  	body, err := ioutil.ReadAll(request.Body)
   100  	if err != nil {
   101  		return nil, err
   102  	}
   103  
   104  	decoder := decoderFactory.Create(bytes.NewReader(body))
   105  
   106  	if !strings.Contains(parameterName, ":") {
   107  		err := decoder.Decode(targetValuePointer.Interface())
   108  		if err != nil {
   109  			return nil, fmt.Errorf("unable to extract %T due to: %v, body: !%s!", targetValuePointer.Interface(), err, body)
   110  		}
   111  	} else {
   112  		var valueMap = make(map[string]interface{})
   113  		pair := strings.SplitN(parameterName, ":", 2)
   114  		valueMap[pair[1]] = targetValuePointer.Interface()
   115  		err := decoder.Decode(&valueMap)
   116  		if err != nil {
   117  			return nil, fmt.Errorf("unable to extract %T due to %v", targetValuePointer.Interface(), err)
   118  		}
   119  	}
   120  	return targetValuePointer.Interface(), nil
   121  }
   122  
   123  func (sr ServiceRouting) extractParameters(request *http.Request, response http.ResponseWriter) (map[string]interface{}, error) {
   124  	var result = make(map[string]interface{})
   125  	_ = request.ParseForm()
   126  	functionSignature := GetFuncSignature(sr.Handler)
   127  	uriParameters, _ := ExtractURIParameters(sr.URI, request.RequestURI)
   128  	for _, name := range sr.Parameters {
   129  		value, found := uriParameters[name]
   130  		if found {
   131  			if strings.Contains(value, ",") {
   132  				result[name] = strings.Split(value, ",")
   133  			} else {
   134  				result[name] = value
   135  			}
   136  			continue
   137  		}
   138  
   139  		value = request.Form.Get(name)
   140  		if len(value) > 0 {
   141  			result[name] = value
   142  		} else {
   143  			continue
   144  		}
   145  	}
   146  	if HasSliceAnyElements(sr.Parameters, "@httpRequest") {
   147  		result["@httpRequest"] = request
   148  	}
   149  	if HasSliceAnyElements(sr.Parameters, "@httpResponseWriter") {
   150  		result["@httpResponseWriter"] = response
   151  	}
   152  
   153  	if request.ContentLength > 0 {
   154  		for i, parameter := range sr.Parameters {
   155  			if _, found := result[parameter]; !found {
   156  				value, err := sr.extractParameterFromBody(parameter, functionSignature[i], request)
   157  				if err != nil {
   158  					return nil, fmt.Errorf("failed to extract parameters for %v %v due to %v", sr.HTTPMethod, sr.URI, err)
   159  				}
   160  				result[parameter] = value
   161  				break
   162  			}
   163  		}
   164  	}
   165  	return result, nil
   166  }
   167  
// ServiceRouter holds service routing rules and dispatches http requests to
// them (see Route).
type ServiceRouter struct {
	// serviceRouting holds the rules in registration order; match scans them in this order.
	serviceRouting []*ServiceRouting
}
   172  
   173  func (r *ServiceRouter) match(request *http.Request) []*ServiceRouting {
   174  	var result = make([]*ServiceRouting, 0)
   175  	for _, candidate := range r.serviceRouting {
   176  		if candidate.HTTPMethod == request.Method {
   177  			_, matched := ExtractURIParameters(candidate.URI, request.RequestURI)
   178  			if matched {
   179  				result = append(result, candidate)
   180  			}
   181  		}
   182  	}
   183  	return result
   184  }
   185  
   186  func getContentTypeOrJSONContentType(contentType string) string {
   187  	if strings.Contains(contentType, textPlainContentType) || strings.Contains(contentType, jsonContentType) || contentType == "" {
   188  		return jsonContentType
   189  	}
   190  	return contentType
   191  }
   192  
   193  //Route matches  service routing by http method , and number of parameters, then it call routing method, and sent back its response.
   194  func (r *ServiceRouter) Route(response http.ResponseWriter, request *http.Request) error {
   195  	candidates := r.match(request)
   196  	if len(candidates) == 0 {
   197  		var uriTemplates = make([]string, 0)
   198  		for _, routing := range r.serviceRouting {
   199  			uriTemplates = append(uriTemplates, routing.URI)
   200  		}
   201  		return fmt.Errorf("failed to route request - unable to match %v with one of %v", request.RequestURI, strings.Join(uriTemplates, ","))
   202  	}
   203  	var finalError error
   204  
   205  	for _, serviceRouting := range candidates {
   206  
   207  		parameterValues, err := serviceRouting.extractParameters(request, response)
   208  		if err != nil {
   209  			finalError = fmt.Errorf("unable to extract parameters due to %v", err)
   210  			continue
   211  		}
   212  
   213  		if serviceRouting.HandlerInvoker != nil {
   214  			err := serviceRouting.HandlerInvoker(serviceRouting, request, response, parameterValues)
   215  			if err != nil {
   216  				finalError = fmt.Errorf("unable to extract parameters due to %v", err)
   217  			}
   218  			continue
   219  		}
   220  
   221  		functionParameters, err := BuildFunctionParameters(serviceRouting.Handler, serviceRouting.Parameters, parameterValues)
   222  		if err != nil {
   223  			finalError = fmt.Errorf("unable to build function parameters %T due to %v", serviceRouting.Handler, err)
   224  			continue
   225  		}
   226  
   227  		result := CallFunction(serviceRouting.Handler, functionParameters...)
   228  		if len(result) > 0 {
   229  			err = WriteServiceRoutingResponse(response, request, serviceRouting, result[0])
   230  			if err != nil {
   231  				return fmt.Errorf("failed to write response response %v, due to %v", result[0], err)
   232  			}
   233  			return nil
   234  		}
   235  		response.Header().Set(contentTypeHeader, textPlainContentType)
   236  	}
   237  	if finalError != nil {
   238  		return fmt.Errorf("failed to route request - %v", finalError)
   239  	}
   240  	return nil
   241  }
   242  
   243  //WriteServiceRoutingResponse writes service router response
   244  func WriteServiceRoutingResponse(response http.ResponseWriter, request *http.Request, serviceRouting *ServiceRouting, result interface{}) error {
   245  	if result == nil {
   246  		result = struct{}{}
   247  	}
   248  	statusCodeAccessor, ok := result.(StatucCodeAccessor)
   249  	if ok {
   250  		statusCode := statusCodeAccessor.GetStatusCode()
   251  		if statusCode > 0 && statusCode != http.StatusOK {
   252  			response.WriteHeader(statusCode)
   253  			return nil
   254  		}
   255  	}
   256  	contentTypeAccessor, ok := result.(ContentTypeAccessor)
   257  	var responseContentType string
   258  	if ok {
   259  		responseContentType = contentTypeAccessor.GetContentType()
   260  	}
   261  	if responseContentType == "" {
   262  		requestContentType := request.Header.Get(contentTypeHeader)
   263  		responseContentType = getContentTypeOrJSONContentType(requestContentType)
   264  	}
   265  	encoderFactory := serviceRouting.getEncoderFactory(responseContentType)
   266  	encoder := encoderFactory.Create(response)
   267  	response.Header().Set(contentTypeHeader, responseContentType)
   268  	err := encoder.Encode(result)
   269  	if err != nil {
   270  		return fmt.Errorf("failed to encode response %v, due to %v", response, err)
   271  	}
   272  	return nil
   273  }
   274  
   275  //WriteResponse writes response to response writer, it used encoder factory to encode passed in response to the writer, it sets back request contenttype to response.
   276  func (r *ServiceRouter) WriteResponse(encoderFactory EncoderFactory, response interface{}, request *http.Request, responseWriter http.ResponseWriter) error {
   277  	requestContentType := request.Header.Get(contentTypeHeader)
   278  	responseContentType := getContentTypeOrJSONContentType(requestContentType)
   279  	encoder := encoderFactory.Create(responseWriter)
   280  	responseWriter.Header().Set(contentTypeHeader, responseContentType)
   281  	err := encoder.Encode(response)
   282  	if err != nil {
   283  		return fmt.Errorf("failed to encode response %v, due to %v", response, err)
   284  	}
   285  	return nil
   286  }
   287  
   288  //NewServiceRouter creates a new service router, is takes list of service routing as arguments
   289  func NewServiceRouter(serviceRouting ...ServiceRouting) *ServiceRouter {
   290  	var routings = make([]*ServiceRouting, 0)
   291  	for i := range serviceRouting {
   292  		routings = append(routings, &serviceRouting[i])
   293  	}
   294  	return &ServiceRouter{routings}
   295  }
   296  
   297  //RouteToService calls web service url, with passed in json request, and encodes http json response into passed response
   298  func RouteToService(method, url string, request, response interface{}, options ...*HttpOptions) (err error) {
   299  	client, err := NewToolboxHTTPClient(options...)
   300  	if err != nil {
   301  		return err
   302  	}
   303  	return client.Request(method, url, request, response, NewJSONEncoderFactory(), NewJSONDecoderFactory())
   304  }
   305  
// HttpOptions is a single key/value option for NewHttpClient. Recognized keys:
// RequestTimeoutMs, TimeoutMs, KeepAliveTimeMs, TLSHandshakeTimeoutMs and
// ResponseHeaderTimeoutMs (millisecond values); MaxIdleConns and
// MaxIdleConnsPerHost (integers); DualStack and FollowRedirects (booleans).
type HttpOptions struct {
	// Key names the option; an unrecognized key makes NewHttpClient return an error.
	Key string
	// Value is converted with AsInt or AsBoolean depending on Key.
	Value interface{}
}
   310  
   311  func NewHttpClient(options ...*HttpOptions) (*http.Client, error) {
   312  	if len(options) == 0 {
   313  		return http.DefaultClient, nil
   314  	}
   315  
   316  	var (
   317  		// Default values matching DefaultHttpClient
   318  		RequestTimeoutMs        = 30 * time.Second
   319  		KeepAliveTimeMs         = 30 * time.Second
   320  		TLSHandshakeTimeoutMs   = 10 * time.Second
   321  		ExpectContinueTimeout   = 1 * time.Second
   322  		IdleConnTimeout         = 90 * time.Second
   323  		DualStack               = true
   324  		MaxIdleConnsPerHost     = http.DefaultMaxIdleConnsPerHost
   325  		MaxIdleConns            = 100
   326  		FollowRedirects         = true
   327  		ResponseHeaderTimeoutMs time.Duration
   328  		TimeoutMs               time.Duration
   329  	)
   330  
   331  	for _, option := range options {
   332  		switch option.Key {
   333  		case "RequestTimeoutMs":
   334  			RequestTimeoutMs = time.Duration(AsInt(option.Value)) * time.Millisecond
   335  		case "TimeoutMs":
   336  			TimeoutMs = time.Duration(AsInt(option.Value)) * time.Millisecond
   337  		case "KeepAliveTimeMs":
   338  			KeepAliveTimeMs = time.Duration(AsInt(option.Value)) * time.Millisecond
   339  		case "TLSHandshakeTimeoutMs":
   340  			KeepAliveTimeMs = time.Duration(AsInt(option.Value)) * time.Millisecond
   341  		case "ResponseHeaderTimeoutMs":
   342  			ResponseHeaderTimeoutMs = time.Duration(AsInt(option.Value)) * time.Millisecond
   343  		case "MaxIdleConns":
   344  			MaxIdleConns = AsInt(option.Value)
   345  		case "MaxIdleConnsPerHost":
   346  			MaxIdleConnsPerHost = AsInt(option.Value)
   347  		case "DualStack":
   348  			DualStack = AsBoolean(option.Value)
   349  		case "FollowRedirects":
   350  			FollowRedirects = AsBoolean(option.Value)
   351  		default:
   352  			return nil, fmt.Errorf("Invalid option: %v", option.Key)
   353  
   354  		}
   355  	}
   356  	roundTripper := http.Transport{
   357  		Proxy: http.ProxyFromEnvironment,
   358  		DialContext: (&net.Dialer{
   359  			Timeout:   RequestTimeoutMs,
   360  			KeepAlive: KeepAliveTimeMs,
   361  			DualStack: DualStack,
   362  		}).DialContext,
   363  		MaxIdleConns:          MaxIdleConns,
   364  		ExpectContinueTimeout: ExpectContinueTimeout,
   365  		IdleConnTimeout:       IdleConnTimeout,
   366  		TLSHandshakeTimeout:   TLSHandshakeTimeoutMs,
   367  		MaxIdleConnsPerHost:   MaxIdleConnsPerHost,
   368  		ResponseHeaderTimeout: ResponseHeaderTimeoutMs,
   369  	}
   370  
   371  	client := &http.Client{
   372  		Transport: &roundTripper,
   373  		Timeout:   TimeoutMs,
   374  	}
   375  
   376  	if !FollowRedirects {
   377  		client.CheckRedirect = func(req *http.Request, via []*http.Request) error {
   378  			return http.ErrUseLastResponse
   379  		}
   380  	}
   381  	return client, nil
   382  }
   383  
// ToolboxHTTPClient wraps a preconfigured *http.Client (see NewHttpClient) and
// sends requests through it with caller-supplied encoder/decoder factories.
type ToolboxHTTPClient struct {
	httpClient *http.Client
}
   388  
   389  // NewToolboxHTTPClient instantiate new client with provided options
   390  func NewToolboxHTTPClient(options ...*HttpOptions) (*ToolboxHTTPClient, error) {
   391  	client, err := NewHttpClient(options...)
   392  	if err != nil {
   393  		return nil, err
   394  	}
   395  	return &ToolboxHTTPClient{client}, nil
   396  }
   397  
   398  // Request sends http request using the existing client
   399  func (c *ToolboxHTTPClient) Request(method, url string, request, response interface{}, encoderFactory EncoderFactory, decoderFactory DecoderFactory) (err error) {
   400  	if _, found := httpMethods[strings.ToUpper(method)]; !found {
   401  		return errors.New("unsupported method:" + method)
   402  	}
   403  	var buffer *bytes.Buffer
   404  
   405  	if request != nil {
   406  		buffer = new(bytes.Buffer)
   407  		if IsString(request) {
   408  			buffer.Write([]byte(AsString(request)))
   409  		} else {
   410  			err := encoderFactory.Create(buffer).Encode(&request)
   411  			if err != nil {
   412  				return fmt.Errorf("failed to encode request: %v due to ", err)
   413  			}
   414  		}
   415  	}
   416  	var serverResponse *http.Response
   417  	var httpRequest *http.Request
   418  	httpMethod := strings.ToUpper(method)
   419  
   420  	if request != nil {
   421  		httpRequest, err = http.NewRequest(httpMethod, url, buffer)
   422  		if err != nil {
   423  			return err
   424  		}
   425  		httpRequest.Header.Set(contentTypeHeader, jsonContentType)
   426  	} else {
   427  		httpRequest, err = http.NewRequest(httpMethod, url, nil)
   428  		if err != nil {
   429  			return err
   430  		}
   431  	}
   432  	serverResponse, err = c.httpClient.Do(httpRequest)
   433  	if serverResponse != nil {
   434  		// must close we have serverResponse to avoid fd leak
   435  		defer serverResponse.Body.Close()
   436  	}
   437  	if err != nil && serverResponse != nil {
   438  		return fmt.Errorf("failed to get response %v %v", err, serverResponse.Header.Get("error"))
   439  	}
   440  
   441  	if response != nil {
   442  		updateResponse(serverResponse, response)
   443  		if serverResponse == nil {
   444  			return fmt.Errorf("failed to receive response %v", err)
   445  		}
   446  		var errorPrefix = fmt.Sprintf("failed to process response: %v, ", serverResponse.StatusCode)
   447  		body, err := ioutil.ReadAll(serverResponse.Body)
   448  		if err != nil {
   449  			return fmt.Errorf("%v unable read body %v", errorPrefix, err)
   450  		}
   451  		if len(body) == 0 {
   452  			return fmt.Errorf("%v response body was empty", errorPrefix)
   453  		}
   454  
   455  		if serverResponse.StatusCode == http.StatusNotFound {
   456  			updateResponse(serverResponse, response)
   457  			return nil
   458  		}
   459  
   460  		if int(serverResponse.StatusCode/100)*100 == http.StatusInternalServerError {
   461  			return errors.New(string(body))
   462  		}
   463  		err = decoderFactory.Create(strings.NewReader(string(body))).Decode(response)
   464  		if err != nil {
   465  			return fmt.Errorf("%v. unable decode response as %T: body: %v: %v", errorPrefix, response, string(body), err)
   466  		}
   467  		updateResponse(serverResponse, response)
   468  	}
   469  	return nil
   470  }
   471  
// StatucCodeMutator is an optional interface for client-side response values.
// updateResponse stores the HTTP status code of the reply through it. The
// "Statuc" misspelling is part of the exported API and is kept for compatibility.
type StatucCodeMutator interface {
	SetStatusCode(code int)
}

// StatucCodeAccessor is an optional interface for server-side handler results.
// WriteServiceRoutingResponse writes a positive, non-200 GetStatusCode value
// as the response status instead of encoding a body.
type StatucCodeAccessor interface {
	GetStatusCode() int
}

// ContentTypeMutator is an optional interface for client-side response values.
// updateResponse stores the reply's Content-Type header through it.
type ContentTypeMutator interface {
	SetContentType(contentType string)
}

// ContentTypeAccessor is an optional interface for server-side handler results.
// WriteServiceRoutingResponse uses a non-empty GetContentType value as the
// response content type.
type ContentTypeAccessor interface {
	GetContentType() string
}
   491  
   492  //updateResponse update response with content type and status code if applicable
   493  func updateResponse(httpResponse *http.Response, response interface{}) {
   494  	if response == nil {
   495  		return
   496  	}
   497  	statusCodeMutator, ok := response.(StatucCodeMutator)
   498  	if ok {
   499  		statusCodeMutator.SetStatusCode(httpResponse.StatusCode)
   500  	}
   501  	contentTypeMutator, ok := response.(ContentTypeMutator)
   502  	if ok {
   503  		contentTypeMutator.SetContentType(httpResponse.Header.Get(contentTypeHeader))
   504  	}
   505  }