github.com/cloudwego/kitex@v0.9.0/pkg/remote/trans/detection/server_handler.go (about)

     1  /*
     2   * Copyright 2021 CloudWeGo Authors
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   *     http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   */
    16  
// Package detection implements server-side protocol detection, intended for scenarios such as migrating from KitexProtobuf to gRPC.
// With this detection handler, the server side can handle both KitexProtobuf and gRPC requests.
    19  package detection
    20  
import (
	"context"
	"errors"
	"net"

	"github.com/cloudwego/kitex/pkg/endpoint"
	"github.com/cloudwego/kitex/pkg/klog"
	"github.com/cloudwego/kitex/pkg/remote"
)
    29  
    30  // DetectableServerTransHandler implements an additional method ProtocolMatch to help
    31  // DetectionHandler to judge which serverHandler should handle the request data.
    32  type DetectableServerTransHandler interface {
    33  	remote.ServerTransHandler
    34  	ProtocolMatch(ctx context.Context, conn net.Conn) (err error)
    35  }
    36  
    37  // NewSvrTransHandlerFactory detection factory construction. Each detectableHandlerFactory should return
    38  // a ServerTransHandler which implements DetectableServerTransHandler after called NewTransHandler.
    39  func NewSvrTransHandlerFactory(defaultHandlerFactory remote.ServerTransHandlerFactory,
    40  	detectableHandlerFactory ...remote.ServerTransHandlerFactory,
    41  ) remote.ServerTransHandlerFactory {
    42  	return &svrTransHandlerFactory{
    43  		defaultHandlerFactory:    defaultHandlerFactory,
    44  		detectableHandlerFactory: detectableHandlerFactory,
    45  	}
    46  }
    47  
    48  type svrTransHandlerFactory struct {
    49  	defaultHandlerFactory    remote.ServerTransHandlerFactory
    50  	detectableHandlerFactory []remote.ServerTransHandlerFactory
    51  }
    52  
    53  func (f *svrTransHandlerFactory) MuxEnabled() bool {
    54  	return false
    55  }
    56  
    57  func (f *svrTransHandlerFactory) NewTransHandler(opt *remote.ServerOption) (remote.ServerTransHandler, error) {
    58  	t := &svrTransHandler{}
    59  	var err error
    60  	for i := range f.detectableHandlerFactory {
    61  		h, err := f.detectableHandlerFactory[i].NewTransHandler(opt)
    62  		if err != nil {
    63  			return nil, err
    64  		}
    65  		handler, ok := h.(DetectableServerTransHandler)
    66  		if !ok {
    67  			klog.Errorf("KITEX: failed to append detection server trans handler: %T", h)
    68  			continue
    69  		}
    70  		t.registered = append(t.registered, handler)
    71  	}
    72  	if t.defaultHandler, err = f.defaultHandlerFactory.NewTransHandler(opt); err != nil {
    73  		return nil, err
    74  	}
    75  	return t, nil
    76  }
    77  
    78  type svrTransHandler struct {
    79  	defaultHandler remote.ServerTransHandler
    80  	registered     []DetectableServerTransHandler
    81  }
    82  
    83  func (t *svrTransHandler) Write(ctx context.Context, conn net.Conn, send remote.Message) (nctx context.Context, err error) {
    84  	return t.which(ctx).Write(ctx, conn, send)
    85  }
    86  
    87  func (t *svrTransHandler) Read(ctx context.Context, conn net.Conn, msg remote.Message) (nctx context.Context, err error) {
    88  	return t.which(ctx).Read(ctx, conn, msg)
    89  }
    90  
    91  func (t *svrTransHandler) OnRead(ctx context.Context, conn net.Conn) (err error) {
    92  	// only need detect once when connection is reused
    93  	r := ctx.Value(handlerKey{}).(*handlerWrapper)
    94  	if r.handler != nil {
    95  		return r.handler.OnRead(r.ctx, conn)
    96  	}
    97  	// compare preface one by one
    98  	var which remote.ServerTransHandler
    99  	for i := range t.registered {
   100  		if t.registered[i].ProtocolMatch(ctx, conn) == nil {
   101  			which = t.registered[i]
   102  			break
   103  		}
   104  	}
   105  	if which != nil {
   106  		ctx, err = which.OnActive(ctx, conn)
   107  		if err != nil {
   108  			return err
   109  		}
   110  	} else {
   111  		which = t.defaultHandler
   112  	}
   113  	r.ctx, r.handler = ctx, which
   114  	return which.OnRead(ctx, conn)
   115  }
   116  
   117  func (t *svrTransHandler) OnInactive(ctx context.Context, conn net.Conn) {
   118  	// Should use the ctx returned by OnActive in r.ctx
   119  	if r, ok := ctx.Value(handlerKey{}).(*handlerWrapper); ok && r.ctx != nil {
   120  		ctx = r.ctx
   121  	}
   122  	t.which(ctx).OnInactive(ctx, conn)
   123  }
   124  
   125  func (t *svrTransHandler) OnError(ctx context.Context, err error, conn net.Conn) {
   126  	t.which(ctx).OnError(ctx, err, conn)
   127  }
   128  
   129  func (t *svrTransHandler) OnMessage(ctx context.Context, args, result remote.Message) (context.Context, error) {
   130  	return t.which(ctx).OnMessage(ctx, args, result)
   131  }
   132  
   133  func (t *svrTransHandler) which(ctx context.Context) remote.ServerTransHandler {
   134  	if r, ok := ctx.Value(handlerKey{}).(*handlerWrapper); ok && r.handler != nil {
   135  		return r.handler
   136  	}
   137  	// use noop transHandler
   138  	return noopHandler
   139  }
   140  
   141  func (t *svrTransHandler) SetPipeline(pipeline *remote.TransPipeline) {
   142  	for i := range t.registered {
   143  		t.registered[i].SetPipeline(pipeline)
   144  	}
   145  	t.defaultHandler.SetPipeline(pipeline)
   146  }
   147  
   148  func (t *svrTransHandler) SetInvokeHandleFunc(inkHdlFunc endpoint.Endpoint) {
   149  	for i := range t.registered {
   150  		if s, ok := t.registered[i].(remote.InvokeHandleFuncSetter); ok {
   151  			s.SetInvokeHandleFunc(inkHdlFunc)
   152  		}
   153  	}
   154  	if t, ok := t.defaultHandler.(remote.InvokeHandleFuncSetter); ok {
   155  		t.SetInvokeHandleFunc(inkHdlFunc)
   156  	}
   157  }
   158  
   159  func (t *svrTransHandler) OnActive(ctx context.Context, conn net.Conn) (context.Context, error) {
   160  	ctx, err := t.defaultHandler.OnActive(ctx, conn)
   161  	if err != nil {
   162  		return nil, err
   163  	}
   164  	// svrTransHandler wraps multi kinds of ServerTransHandler.
   165  	// We think that one connection only use one type, it doesn't need to do protocol detection for every request.
   166  	// And ctx is initialized with a new connection, so we put a handlerWrapper into ctx, which for recording
   167  	// the actual handler, then the later request don't need to do detection.
   168  	return context.WithValue(ctx, handlerKey{}, &handlerWrapper{}), nil
   169  }
   170  
   171  func (t *svrTransHandler) GracefulShutdown(ctx context.Context) error {
   172  	for i := range t.registered {
   173  		if g, ok := t.registered[i].(remote.GracefulShutdown); ok {
   174  			g.GracefulShutdown(ctx)
   175  		}
   176  	}
   177  	if g, ok := t.defaultHandler.(remote.GracefulShutdown); ok {
   178  		g.GracefulShutdown(ctx)
   179  	}
   180  	return nil
   181  }
   182  
   183  type handlerKey struct{}
   184  
   185  type handlerWrapper struct {
   186  	ctx     context.Context
   187  	handler remote.ServerTransHandler
   188  }