github.com/aacfactory/fns@v1.2.86-0.20240310083819-80d667fc0a17/proxies/handler.go (about) 1 /* 2 * Copyright 2023 Wang Min Xiang 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 * 16 */ 17 18 package proxies 19 20 import ( 21 "bytes" 22 "github.com/aacfactory/errors" 23 "github.com/aacfactory/fns/clusters" 24 "github.com/aacfactory/fns/commons/bytex" 25 "github.com/aacfactory/fns/commons/mmhash" 26 "github.com/aacfactory/fns/commons/versions" 27 "github.com/aacfactory/fns/context" 28 "github.com/aacfactory/fns/services" 29 "github.com/aacfactory/fns/transports" 30 "github.com/valyala/bytebufferpool" 31 "golang.org/x/sync/singleflight" 32 "net/textproto" 33 "strconv" 34 ) 35 36 var ( 37 slashBytes = []byte{'/'} 38 ) 39 40 func NewProxyHandler(manager clusters.ClusterEndpointsManager, dialer transports.Dialer) transports.MuxHandler { 41 return &proxyHandler{ 42 manager: manager, 43 dialer: dialer, 44 group: singleflight.Group{}, 45 } 46 } 47 48 type proxyHandler struct { 49 manager clusters.ClusterEndpointsManager 50 dialer transports.Dialer 51 group singleflight.Group 52 } 53 54 func (handler *proxyHandler) Name() string { 55 return "proxy" 56 } 57 58 func (handler *proxyHandler) Construct(_ transports.MuxHandlerOptions) error { 59 return nil 60 } 61 62 func (handler *proxyHandler) Match(_ context.Context, method []byte, path []byte, header transports.Header) bool { 63 if bytes.Equal(method, transports.MethodPost) { 64 return len(bytes.Split(path, slashBytes)) == 3 && 65 (bytes.Equal(header.Get(transports.ContentTypeHeaderName), transports.ContentTypeJsonHeaderValue) || 66 bytes.Equal(header.Get(transports.ContentTypeHeaderName), transports.ContentTypeAvroHeaderValue)) 67 } 68 if bytes.Equal(method, transports.MethodGet) { 69 return len(bytes.Split(path, slashBytes)) == 3 70 } 71 return false 72 } 73 74 func (handler *proxyHandler) Handle(w transports.ResponseWriter, r transports.Request) { 75 groupKeyBuf := bytebufferpool.Get() 76 // path 77 path := r.Path() 78 pathItems := bytes.Split(path, slashBytes) 79 if len(pathItems) != 3 { 80 bytebufferpool.Put(groupKeyBuf) 81 w.Failed(ErrInvalidPath.WithMeta("path", bytex.ToString(path))) 82 return 83 } 84 service := pathItems[1] 85 fn := pathItems[2] 86 _, _ = groupKeyBuf.Write(path) 87 // device id 88 deviceId := r.Header().Get(transports.DeviceIdHeaderName) 89 if len(deviceId) == 0 { 90 bytebufferpool.Put(groupKeyBuf) 91 w.Failed(ErrDeviceId.WithMeta("path", bytex.ToString(path))) 92 return 93 } 94 _, _ = groupKeyBuf.Write(deviceId) 95 96 // discovery 97 endpointGetOptions := make([]services.EndpointGetOption, 0, 1) 98 var intervals versions.Intervals 99 acceptedVersions := r.Header().Get(transports.RequestVersionsHeaderName) 100 if len(acceptedVersions) > 0 { 101 var intervalsErr error 102 intervals, intervalsErr = versions.ParseIntervals(acceptedVersions) 103 if intervalsErr != nil { 104 bytebufferpool.Put(groupKeyBuf) 105 w.Failed(ErrInvalidRequestVersions.WithMeta("path", bytex.ToString(path)).WithMeta("versions", bytex.ToString(acceptedVersions)).WithCause(intervalsErr)) 106 return 107 } 108 endpointGetOptions = append(endpointGetOptions, services.EndpointVersions(intervals)) 109 _, _ = groupKeyBuf.Write(acceptedVersions) 110 } 111 112 var queryParams transports.Params 113 var body []byte 114 method := r.Method() 115 if bytes.Equal(method, transports.MethodGet) { 116 queryParams = r.Params() 117 queryParamsBytes := queryParams.Encode() 118 path = append(path, '?') 119 path = append(path, queryParamsBytes...) 120 _, _ = groupKeyBuf.Write(queryParamsBytes) 121 } else { 122 var bodyErr error 123 body, bodyErr = r.Body() 124 if bodyErr != nil { 125 bytebufferpool.Put(groupKeyBuf) 126 w.Failed(errors.Warning("fns: read request body failed").WithCause(bodyErr). 127 WithMeta("endpoint", bytex.ToString(service)). 128 WithMeta("fn", bytex.ToString(fn))) 129 return 130 } 131 _, _ = groupKeyBuf.Write(body) 132 } 133 134 groupKey := strconv.FormatUint(mmhash.Sum64(groupKeyBuf.Bytes()), 16) 135 bytebufferpool.Put(groupKeyBuf) 136 v, err, _ := handler.group.Do(groupKey, func() (v interface{}, err error) { 137 address, internal, has := handler.manager.FnAddress(r, service, fn, endpointGetOptions...) 138 if !has { 139 err = errors.NotFound("fns: endpoint was not found"). 140 WithMeta("endpoint", bytex.ToString(service)). 141 WithMeta("fn", bytex.ToString(fn)) 142 return 143 } 144 if internal { 145 err = errors.NotFound("fns: fn was internal"). 146 WithMeta("endpoint", bytex.ToString(service)). 147 WithMeta("fn", bytex.ToString(fn)) 148 return 149 } 150 151 client, clientErr := handler.dialer.Dial(bytex.FromString(address)) 152 if clientErr != nil { 153 err = errors.Warning("fns: dial endpoint failed").WithCause(clientErr). 154 WithMeta("endpoint", bytex.ToString(service)). 155 WithMeta("fn", bytex.ToString(fn)) 156 return 157 } 158 159 header := transports.AcquireHeader() 160 defer transports.ReleaseHeader(header) 161 r.Header().Foreach(func(key []byte, values [][]byte) { 162 for _, value := range values { 163 header.Add(key, value) 164 } 165 }) 166 removeHopByHopHeaders(header) 167 168 status, respHeader, respBody, doErr := client.Do(r, method, path, header, body) 169 if doErr != nil { 170 err = errors.Warning("fns: send request to endpoint failed").WithCause(doErr). 171 WithMeta("endpoint", bytex.ToString(service)). 172 WithMeta("fn", bytex.ToString(fn)) 173 return 174 } 175 v = Response{ 176 Status: status, 177 Header: respHeader, 178 Value: respBody, 179 } 180 return 181 }) 182 handler.group.Forget(groupKey) 183 184 if err != nil { 185 w.Failed(err) 186 return 187 } 188 189 response := v.(Response) 190 if response.Header.Len() > 0 { 191 response.Header.Foreach(func(key []byte, values [][]byte) { 192 for _, value := range values { 193 w.Header().Add(key, value) 194 } 195 }) 196 } 197 w.SetStatus(response.Status) 198 _, _ = w.Write(response.Value) 199 } 200 201 type Response struct { 202 Status int 203 Header transports.Header 204 Value []byte 205 } 206 207 var hopHeaders = [][]byte{ 208 []byte("Connection"), 209 []byte("Proxy-Connection"), 210 []byte("Keep-Alive"), 211 []byte("Proxy-Authenticate"), 212 []byte("Proxy-Authorization"), 213 []byte("Te"), 214 []byte("Trailer"), 215 []byte("Transfer-Encoding"), 216 []byte("Upgrade"), 217 []byte("Origin"), 218 } 219 220 var ( 221 comma = []byte{','} 222 ) 223 224 func removeHopByHopHeaders(h transports.Header) { 225 // RFC 7230, section 6.1: Remove headers listed in the "Connection" header. 226 for _, f := range h.Values(transports.ConnectionHeaderName) { 227 for _, sf := range bytes.Split(f, comma) { 228 if sf = bytex.FromString(textproto.TrimString(bytex.ToString(sf))); len(sf) > 0 { 229 h.Del(sf) 230 } 231 } 232 } 233 // RFC 2616, section 13.5.1: Remove a set of known hop-by-hop headers. 234 // This behavior is superseded by the RFC 7230 Connection header, but 235 // preserve it for backwards compatibility. 236 for _, f := range hopHeaders { 237 h.Del(f) 238 } 239 }