
     1  /*
     2   * Copyright 2023 Wang Min Xiang
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
 *     http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   *
    16   */
    18  package clusters
    20  import (
    21  	"bytes"
    22  	""
    23  	""
    24  	""
    25  	""
    26  	""
    27  	""
    28  	""
    29  	""
    30  	""
    31  	""
    32  )
    34  var (
    35  	slashBytes                = []byte{'/'}
    36  	internalContentTypeHeader = []byte("application/avro+fns")
    37  	spanKey                   = []byte("span")
    38  )
// Entry is a key/value pair of raw bytes. It is serialized under the field
// names "key" and "value" for both JSON and avro.
type Entry struct {
	// Key identifies the entry.
	Key   []byte `json:"key" avro:"key"`
	// Value is the raw payload stored under Key.
	Value []byte `json:"value" avro:"value"`
}
// RequestBody is the payload that InternalHandler.Handle decodes (with avro)
// from the body of an internal request.
type RequestBody struct {
	// ContextUserValues are copied onto the incoming request as user values
	// before the endpoint is invoked.
	ContextUserValues []Entry `json:"contextUserValues" avro:"contextUserValues"`
	// Params is the raw, still-encoded argument forwarded to the endpoint.
	Params            []byte  `json:"params" avro:"params"`
}
// ResponseBody is the payload that InternalHandler.Handle encodes (with avro)
// and writes back for an internal request.
type ResponseBody struct {
	// Succeed reports whether the endpoint call succeeded and its result
	// could be encoded.
	Succeed     bool    `json:"succeed" avro:"succeed"`
	// Data is the avro-encoded endpoint result when Succeed is true, or the
	// avro-encoded error otherwise.
	Data        []byte  `json:"data" avro:"data"`
	// Attachments carries optional extras; Handle stores the tracing span
	// here under the "span" key.
	Attachments []Entry `json:"attachments" avro:"attachments"`
}
    56  func (rsp ResponseBody) GetSpan() (v *tracings.Span, has bool) {
    57  	for _, attachment := range rsp.Attachments {
    58  		if bytes.Equal(attachment.Key, spanKey) {
    59  			if len(attachment.Value) == 0 {
    60  				return
    61  			}
    62  			v = new(tracings.Span)
    63  			err := avro.Unmarshal(attachment.Value, v)
    64  			if err != nil {
    65  				return
    66  			}
    67  			has = true
    68  			return
    69  		}
    70  	}
    71  	return
    72  }
    74  func NewInternalHandler(local services.Endpoints, signature signatures.Signature) transports.MuxHandler {
    75  	return &InternalHandler{
    76  		signature: signature,
    77  		endpoints: local,
    78  	}
    79  }
// InternalHandler is the mux handler returned by NewInternalHandler. It
// verifies signed internal requests and forwards them to local endpoints.
type InternalHandler struct {
	// signature verifies the request body against the signature header.
	signature signatures.Signature
	// endpoints receives the decoded service/function call.
	endpoints services.Endpoints
}
    86  func (handler *InternalHandler) Name() string {
    87  	return "internal"
    88  }
// Construct accepts the mux handler options and ignores them; the handler is
// fully configured by NewInternalHandler, so this always returns nil.
func (handler *InternalHandler) Construct(_ transports.MuxHandlerOptions) error {
	return nil
}
    94  func (handler *InternalHandler) Match(_ context.Context, method []byte, path []byte, header transports.Header) bool {
    95  	matched := bytes.Equal(method, transports.MethodPost) &&
    96  		len(bytes.Split(path, slashBytes)) == 3 &&
    97  		len(header.Get(transports.SignatureHeaderName)) != 0 &&
    98  		bytes.Equal(header.Get(transports.ContentTypeHeaderName), internalContentTypeHeader)
    99  	return matched
   100  }
   102  func (handler *InternalHandler) Handle(w transports.ResponseWriter, r transports.Request) {
   103  	// path
   104  	path := r.Path()
   105  	pathItems := bytes.Split(path, slashBytes)
   106  	if len(pathItems) != 3 {
   107  		w.Failed(ErrInvalidPath.WithMeta("path", bytex.ToString(path)))
   108  		return
   109  	}
   110  	service := pathItems[1]
   111  	fn := pathItems[2]
   113  	// sign
   114  	sign := r.Header().Get(transports.SignatureHeaderName)
   115  	if len(sign) == 0 {
   116  		w.Failed(ErrSignatureLost.WithMeta("path", bytex.ToString(path)))
   117  		return
   118  	}
   119  	// body
   120  	body, bodyErr := r.Body()
   121  	if bodyErr != nil {
   122  		w.Failed(ErrInvalidBody.WithMeta("path", bytex.ToString(path)))
   123  		return
   124  	}
   126  	if !handler.signature.Verify(body, sign) {
   127  		w.Failed(ErrSignatureUnverified.WithMeta("path", bytex.ToString(path)))
   128  		return
   129  	}
   131  	rb := RequestBody{}
   132  	decodeErr := avro.Unmarshal(body, &rb)
   133  	if decodeErr != nil {
   134  		w.Failed(ErrInvalidBody.WithMeta("path", bytex.ToString(path)).WithCause(decodeErr))
   135  		return
   136  	}
   137  	// user values
   138  	for _, userValue := range rb.ContextUserValues {
   139  		r.SetUserValue(userValue.Key, userValue.Value)
   140  	}
   142  	// header >>>
   143  	options := make([]services.RequestOption, 0, 1)
   144  	// internal
   145  	options = append(options, services.WithInternalRequest())
   146  	// endpoint id
   147  	endpointId := r.Header().Get(transports.EndpointIdHeaderName)
   148  	if len(endpointId) > 0 {
   149  		options = append(options, services.WithEndpointId(endpointId))
   150  	}
   151  	// device id
   152  	deviceId := r.Header().Get(transports.DeviceIdHeaderName)
   153  	if len(deviceId) == 0 {
   154  		w.Failed(ErrDeviceId.WithMeta("path", bytex.ToString(path)))
   155  		return
   156  	}
   157  	options = append(options, services.WithDeviceId(deviceId))
   158  	// device ip
   159  	deviceIp := r.Header().Get(transports.DeviceIpHeaderName)
   160  	if len(deviceIp) > 0 {
   161  		options = append(options, services.WithDeviceIp(deviceIp))
   162  	}
   163  	// request id
   164  	requestId := r.Header().Get(transports.RequestIdHeaderName)
   165  	hasRequestId := len(requestId) > 0
   166  	if hasRequestId {
   167  		options = append(options, services.WithRequestId(requestId))
   168  	}
   169  	// request version
   170  	acceptedVersions := r.Header().Get(transports.RequestVersionsHeaderName)
   171  	if len(acceptedVersions) > 0 {
   172  		intervals, intervalsErr := versions.ParseIntervals(acceptedVersions)
   173  		if intervalsErr != nil {
   174  			w.Failed(ErrInvalidRequestVersions.WithMeta("path", bytex.ToString(path)).WithMeta("versions", bytex.ToString(acceptedVersions)).WithCause(intervalsErr))
   175  			return
   176  		}
   177  		options = append(options, services.WithRequestVersions(intervals))
   178  	}
   179  	// authorization
   180  	authorization := r.Header().Get(transports.AuthorizationHeaderName)
   181  	if len(authorization) > 0 {
   182  		options = append(options, services.WithToken(authorization))
   183  	}
   184  	// header <<<
   186  	// param
   187  	param := avros.RawMessage(rb.Params)
   189  	var ctx context.Context = r
   191  	// handle
   192  	response, err := handler.endpoints.Request(
   193  		ctx, service, fn,
   194  		param,
   195  		options...,
   196  	)
   197  	succeed := err == nil
   198  	var data []byte
   199  	var dataErr error
   200  	var span *tracings.Span
   201  	if succeed {
   202  		if response.Valid() {
   203  			responseValue := response.Value()
   204  			data, dataErr = avro.Marshal(responseValue)
   205  		}
   206  	} else {
   207  		data, _ = avro.Marshal(errors.Wrap(err))
   208  	}
   209  	if dataErr != nil {
   210  		succeed = false
   211  		data, _ = avro.Marshal(errors.Warning("fns: encode endpoint response failed").WithMeta("path", bytex.ToString(path)).WithCause(dataErr))
   212  	}
   214  	if hasRequestId {
   215  		trace, hasTrace := tracings.Load(ctx)
   216  		if hasTrace {
   217  			span = trace.Span
   218  		}
   219  	}
   221  	rsb := ResponseBody{
   222  		Succeed:     succeed,
   223  		Data:        data,
   224  		Attachments: make([]Entry, 0, 1),
   225  	}
   226  	if span != nil {
   227  		spanBytes, _ := avro.Marshal(span)
   228  		rsb.Attachments = append(rsb.Attachments, Entry{
   229  			Key:   spanKey,
   230  			Value: spanBytes,
   231  		})
   232  	}
   234  	p, encodeErr := avro.Marshal(rsb)
   235  	if encodeErr != nil {
   236  		w.Failed(errors.Warning("fns: proto marshal failed").WithCause(encodeErr))
   237  		return
   238  	}
   239  	_, _ = w.Write(p)
   240  }