
     1  /*
     2   * Copyright 2023 Wang Min Xiang
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   *
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   *
    16   */
    18  package proxies
    20  import (
    21  	"bytes"
    22  	""
    23  	""
    24  	""
    25  	""
    26  	""
    27  	""
    28  	""
    29  	""
    30  	""
    31  	""
    32  	"net/textproto"
    33  	"strconv"
    34  )
    36  var (
    37  	slashBytes = []byte{'/'}
    38  )
    40  func NewProxyHandler(manager clusters.ClusterEndpointsManager, dialer transports.Dialer) transports.MuxHandler {
    41  	return &proxyHandler{
    42  		manager: manager,
    43  		dialer:  dialer,
    44  		group:   singleflight.Group{},
    45  	}
    46  }
// proxyHandler forwards requests addressed as "/{service}/{fn}" to the cluster
// endpoint that serves fn (see Handle).
type proxyHandler struct {
	// manager resolves the address of the endpoint that serves a fn.
	manager clusters.ClusterEndpointsManager
	// dialer opens a client for an endpoint address.
	dialer  transports.Dialer
	// group collapses concurrent identical requests into one upstream call.
	group   singleflight.Group
}
    54  func (handler *proxyHandler) Name() string {
    55  	return "proxy"
    56  }
    58  func (handler *proxyHandler) Construct(_ transports.MuxHandlerOptions) error {
    59  	return nil
    60  }
    62  func (handler *proxyHandler) Match(_ context.Context, method []byte, path []byte, header transports.Header) bool {
    63  	if bytes.Equal(method, transports.MethodPost) {
    64  		return len(bytes.Split(path, slashBytes)) == 3 &&
    65  			(bytes.Equal(header.Get(transports.ContentTypeHeaderName), transports.ContentTypeJsonHeaderValue) ||
    66  				bytes.Equal(header.Get(transports.ContentTypeHeaderName), transports.ContentTypeAvroHeaderValue))
    67  	}
    68  	if bytes.Equal(method, transports.MethodGet) {
    69  		return len(bytes.Split(path, slashBytes)) == 3
    70  	}
    71  	return false
    72  }
    74  func (handler *proxyHandler) Handle(w transports.ResponseWriter, r transports.Request) {
    75  	groupKeyBuf := bytebufferpool.Get()
    76  	// path
    77  	path := r.Path()
    78  	pathItems := bytes.Split(path, slashBytes)
    79  	if len(pathItems) != 3 {
    80  		bytebufferpool.Put(groupKeyBuf)
    81  		w.Failed(ErrInvalidPath.WithMeta("path", bytex.ToString(path)))
    82  		return
    83  	}
    84  	service := pathItems[1]
    85  	fn := pathItems[2]
    86  	_, _ = groupKeyBuf.Write(path)
    87  	// device id
    88  	deviceId := r.Header().Get(transports.DeviceIdHeaderName)
    89  	if len(deviceId) == 0 {
    90  		bytebufferpool.Put(groupKeyBuf)
    91  		w.Failed(ErrDeviceId.WithMeta("path", bytex.ToString(path)))
    92  		return
    93  	}
    94  	_, _ = groupKeyBuf.Write(deviceId)
    96  	// discovery
    97  	endpointGetOptions := make([]services.EndpointGetOption, 0, 1)
    98  	var intervals versions.Intervals
    99  	acceptedVersions := r.Header().Get(transports.RequestVersionsHeaderName)
   100  	if len(acceptedVersions) > 0 {
   101  		var intervalsErr error
   102  		intervals, intervalsErr = versions.ParseIntervals(acceptedVersions)
   103  		if intervalsErr != nil {
   104  			bytebufferpool.Put(groupKeyBuf)
   105  			w.Failed(ErrInvalidRequestVersions.WithMeta("path", bytex.ToString(path)).WithMeta("versions", bytex.ToString(acceptedVersions)).WithCause(intervalsErr))
   106  			return
   107  		}
   108  		endpointGetOptions = append(endpointGetOptions, services.EndpointVersions(intervals))
   109  		_, _ = groupKeyBuf.Write(acceptedVersions)
   110  	}
   112  	var queryParams transports.Params
   113  	var body []byte
   114  	method := r.Method()
   115  	if bytes.Equal(method, transports.MethodGet) {
   116  		queryParams = r.Params()
   117  		queryParamsBytes := queryParams.Encode()
   118  		path = append(path, '?')
   119  		path = append(path, queryParamsBytes...)
   120  		_, _ = groupKeyBuf.Write(queryParamsBytes)
   121  	} else {
   122  		var bodyErr error
   123  		body, bodyErr = r.Body()
   124  		if bodyErr != nil {
   125  			bytebufferpool.Put(groupKeyBuf)
   126  			w.Failed(errors.Warning("fns: read request body failed").WithCause(bodyErr).
   127  				WithMeta("endpoint", bytex.ToString(service)).
   128  				WithMeta("fn", bytex.ToString(fn)))
   129  			return
   130  		}
   131  		_, _ = groupKeyBuf.Write(body)
   132  	}
   134  	groupKey := strconv.FormatUint(mmhash.Sum64(groupKeyBuf.Bytes()), 16)
   135  	bytebufferpool.Put(groupKeyBuf)
   136  	v, err, _ :=, func() (v interface{}, err error) {
   137  		address, internal, has := handler.manager.FnAddress(r, service, fn, endpointGetOptions...)
   138  		if !has {
   139  			err = errors.NotFound("fns: endpoint was not found").
   140  				WithMeta("endpoint", bytex.ToString(service)).
   141  				WithMeta("fn", bytex.ToString(fn))
   142  			return
   143  		}
   144  		if internal {
   145  			err = errors.NotFound("fns: fn was internal").
   146  				WithMeta("endpoint", bytex.ToString(service)).
   147  				WithMeta("fn", bytex.ToString(fn))
   148  			return
   149  		}
   151  		client, clientErr := handler.dialer.Dial(bytex.FromString(address))
   152  		if clientErr != nil {
   153  			err = errors.Warning("fns: dial endpoint failed").WithCause(clientErr).
   154  				WithMeta("endpoint", bytex.ToString(service)).
   155  				WithMeta("fn", bytex.ToString(fn))
   156  			return
   157  		}
   159  		header := transports.AcquireHeader()
   160  		defer transports.ReleaseHeader(header)
   161  		r.Header().Foreach(func(key []byte, values [][]byte) {
   162  			for _, value := range values {
   163  				header.Add(key, value)
   164  			}
   165  		})
   166  		removeHopByHopHeaders(header)
   168  		status, respHeader, respBody, doErr := client.Do(r, method, path, header, body)
   169  		if doErr != nil {
   170  			err = errors.Warning("fns: send request to endpoint failed").WithCause(doErr).
   171  				WithMeta("endpoint", bytex.ToString(service)).
   172  				WithMeta("fn", bytex.ToString(fn))
   173  			return
   174  		}
   175  		v = Response{
   176  			Status: status,
   177  			Header: respHeader,
   178  			Value:  respBody,
   179  		}
   180  		return
   181  	})
   184  	if err != nil {
   185  		w.Failed(err)
   186  		return
   187  	}
   189  	response := v.(Response)
   190  	if response.Header.Len() > 0 {
   191  		response.Header.Foreach(func(key []byte, values [][]byte) {
   192  			for _, value := range values {
   193  				w.Header().Add(key, value)
   194  			}
   195  		})
   196  	}
   197  	w.SetStatus(response.Status)
   198  	_, _ = w.Write(response.Value)
   199  }
// Response is the upstream result produced inside Handle's singleflight call;
// every caller collapsed into that call receives the same value.
type Response struct {
	// Status is the status returned by the endpoint; Handle passes it to w.SetStatus.
	Status int
	// Header holds the endpoint's response headers, copied onto w by Handle.
	Header transports.Header
	// Value is the endpoint's response body, written to w by Handle.
	Value  []byte
}
   207  var hopHeaders = [][]byte{
   208  	[]byte("Connection"),
   209  	[]byte("Proxy-Connection"),
   210  	[]byte("Keep-Alive"),
   211  	[]byte("Proxy-Authenticate"),
   212  	[]byte("Proxy-Authorization"),
   213  	[]byte("Te"),
   214  	[]byte("Trailer"),
   215  	[]byte("Transfer-Encoding"),
   216  	[]byte("Upgrade"),
   217  	[]byte("Origin"),
   218  }
   220  var (
   221  	comma = []byte{','}
   222  )
   224  func removeHopByHopHeaders(h transports.Header) {
   225  	// RFC 7230, section 6.1: Remove headers listed in the "Connection" header.
   226  	for _, f := range h.Values(transports.ConnectionHeaderName) {
   227  		for _, sf := range bytes.Split(f, comma) {
   228  			if sf = bytex.FromString(textproto.TrimString(bytex.ToString(sf))); len(sf) > 0 {
   229  				h.Del(sf)
   230  			}
   231  		}
   232  	}
   233  	// RFC 2616, section 13.5.1: Remove a set of known hop-by-hop headers.
   234  	// This behavior is superseded by the RFC 7230 Connection header, but
   235  	// preserve it for backwards compatibility.
   236  	for _, f := range hopHeaders {
   237  		h.Del(f)
   238  	}
   239  }