github.com/cloudwego/kitex@v0.9.0/pkg/remote/trans/detection/server_handler.go (about) 1 /* 2 * Copyright 2021 CloudWeGo Authors 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 // Package detection protocol detection, it is used for a scenario that switching KitexProtobuf to gRPC. 18 // No matter KitexProtobuf or gRPC the server side can handle with this detection handler. 19 package detection 20 21 import ( 22 "context" 23 "net" 24 25 "github.com/cloudwego/kitex/pkg/endpoint" 26 "github.com/cloudwego/kitex/pkg/klog" 27 "github.com/cloudwego/kitex/pkg/remote" 28 ) 29 30 // DetectableServerTransHandler implements an additional method ProtocolMatch to help 31 // DetectionHandler to judge which serverHandler should handle the request data. 32 type DetectableServerTransHandler interface { 33 remote.ServerTransHandler 34 ProtocolMatch(ctx context.Context, conn net.Conn) (err error) 35 } 36 37 // NewSvrTransHandlerFactory detection factory construction. Each detectableHandlerFactory should return 38 // a ServerTransHandler which implements DetectableServerTransHandler after called NewTransHandler. 39 func NewSvrTransHandlerFactory(defaultHandlerFactory remote.ServerTransHandlerFactory, 40 detectableHandlerFactory ...remote.ServerTransHandlerFactory, 41 ) remote.ServerTransHandlerFactory { 42 return &svrTransHandlerFactory{ 43 defaultHandlerFactory: defaultHandlerFactory, 44 detectableHandlerFactory: detectableHandlerFactory, 45 } 46 } 47 48 type svrTransHandlerFactory struct { 49 defaultHandlerFactory remote.ServerTransHandlerFactory 50 detectableHandlerFactory []remote.ServerTransHandlerFactory 51 } 52 53 func (f *svrTransHandlerFactory) MuxEnabled() bool { 54 return false 55 } 56 57 func (f *svrTransHandlerFactory) NewTransHandler(opt *remote.ServerOption) (remote.ServerTransHandler, error) { 58 t := &svrTransHandler{} 59 var err error 60 for i := range f.detectableHandlerFactory { 61 h, err := f.detectableHandlerFactory[i].NewTransHandler(opt) 62 if err != nil { 63 return nil, err 64 } 65 handler, ok := h.(DetectableServerTransHandler) 66 if !ok { 67 klog.Errorf("KITEX: failed to append detection server trans handler: %T", h) 68 continue 69 } 70 t.registered = append(t.registered, handler) 71 } 72 if t.defaultHandler, err = f.defaultHandlerFactory.NewTransHandler(opt); err != nil { 73 return nil, err 74 } 75 return t, nil 76 } 77 78 type svrTransHandler struct { 79 defaultHandler remote.ServerTransHandler 80 registered []DetectableServerTransHandler 81 } 82 83 func (t *svrTransHandler) Write(ctx context.Context, conn net.Conn, send remote.Message) (nctx context.Context, err error) { 84 return t.which(ctx).Write(ctx, conn, send) 85 } 86 87 func (t *svrTransHandler) Read(ctx context.Context, conn net.Conn, msg remote.Message) (nctx context.Context, err error) { 88 return t.which(ctx).Read(ctx, conn, msg) 89 } 90 91 func (t *svrTransHandler) OnRead(ctx context.Context, conn net.Conn) (err error) { 92 // only need detect once when connection is reused 93 r := ctx.Value(handlerKey{}).(*handlerWrapper) 94 if r.handler != nil { 95 return r.handler.OnRead(r.ctx, conn) 96 } 97 // compare preface one by one 98 var which remote.ServerTransHandler 99 for i := range t.registered { 100 if t.registered[i].ProtocolMatch(ctx, conn) == nil { 101 which = t.registered[i] 102 break 103 } 104 } 105 if which != nil { 106 ctx, err = which.OnActive(ctx, conn) 107 if err != nil { 108 return err 109 } 110 } else { 111 which = t.defaultHandler 112 } 113 r.ctx, r.handler = ctx, which 114 return which.OnRead(ctx, conn) 115 } 116 117 func (t *svrTransHandler) OnInactive(ctx context.Context, conn net.Conn) { 118 // Should use the ctx returned by OnActive in r.ctx 119 if r, ok := ctx.Value(handlerKey{}).(*handlerWrapper); ok && r.ctx != nil { 120 ctx = r.ctx 121 } 122 t.which(ctx).OnInactive(ctx, conn) 123 } 124 125 func (t *svrTransHandler) OnError(ctx context.Context, err error, conn net.Conn) { 126 t.which(ctx).OnError(ctx, err, conn) 127 } 128 129 func (t *svrTransHandler) OnMessage(ctx context.Context, args, result remote.Message) (context.Context, error) { 130 return t.which(ctx).OnMessage(ctx, args, result) 131 } 132 133 func (t *svrTransHandler) which(ctx context.Context) remote.ServerTransHandler { 134 if r, ok := ctx.Value(handlerKey{}).(*handlerWrapper); ok && r.handler != nil { 135 return r.handler 136 } 137 // use noop transHandler 138 return noopHandler 139 } 140 141 func (t *svrTransHandler) SetPipeline(pipeline *remote.TransPipeline) { 142 for i := range t.registered { 143 t.registered[i].SetPipeline(pipeline) 144 } 145 t.defaultHandler.SetPipeline(pipeline) 146 } 147 148 func (t *svrTransHandler) SetInvokeHandleFunc(inkHdlFunc endpoint.Endpoint) { 149 for i := range t.registered { 150 if s, ok := t.registered[i].(remote.InvokeHandleFuncSetter); ok { 151 s.SetInvokeHandleFunc(inkHdlFunc) 152 } 153 } 154 if t, ok := t.defaultHandler.(remote.InvokeHandleFuncSetter); ok { 155 t.SetInvokeHandleFunc(inkHdlFunc) 156 } 157 } 158 159 func (t *svrTransHandler) OnActive(ctx context.Context, conn net.Conn) (context.Context, error) { 160 ctx, err := t.defaultHandler.OnActive(ctx, conn) 161 if err != nil { 162 return nil, err 163 } 164 // svrTransHandler wraps multi kinds of ServerTransHandler. 165 // We think that one connection only use one type, it doesn't need to do protocol detection for every request. 166 // And ctx is initialized with a new connection, so we put a handlerWrapper into ctx, which for recording 167 // the actual handler, then the later request don't need to do detection. 168 return context.WithValue(ctx, handlerKey{}, &handlerWrapper{}), nil 169 } 170 171 func (t *svrTransHandler) GracefulShutdown(ctx context.Context) error { 172 for i := range t.registered { 173 if g, ok := t.registered[i].(remote.GracefulShutdown); ok { 174 g.GracefulShutdown(ctx) 175 } 176 } 177 if g, ok := t.defaultHandler.(remote.GracefulShutdown); ok { 178 g.GracefulShutdown(ctx) 179 } 180 return nil 181 } 182 183 type handlerKey struct{} 184 185 type handlerWrapper struct { 186 ctx context.Context 187 handler remote.ServerTransHandler 188 }