github.com/cloudwego/kitex@v0.9.0/pkg/remote/trans/netpollmux/client_handler.go (about)

     1  /*
     2   * Copyright 2021 CloudWeGo Authors
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   *     http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   */
    16  
    17  package netpollmux
    18  
    19  import (
    20  	"context"
    21  	"errors"
    22  	"fmt"
    23  	"net"
    24  	"sync/atomic"
    25  
    26  	"github.com/cloudwego/netpoll"
    27  
    28  	"github.com/cloudwego/kitex/pkg/kerrors"
    29  	"github.com/cloudwego/kitex/pkg/klog"
    30  	"github.com/cloudwego/kitex/pkg/remote"
    31  	"github.com/cloudwego/kitex/pkg/remote/codec"
    32  	"github.com/cloudwego/kitex/pkg/remote/trans"
    33  	np "github.com/cloudwego/kitex/pkg/remote/trans/netpoll"
    34  	"github.com/cloudwego/kitex/pkg/rpcinfo"
    35  	"github.com/cloudwego/kitex/pkg/serviceinfo"
    36  	"github.com/cloudwego/kitex/pkg/stats"
    37  )
    38  
    39  type cliTransHandlerFactory struct{}
    40  
    41  // NewCliTransHandlerFactory creates a new netpollmux client transport handler factory.
    42  func NewCliTransHandlerFactory() remote.ClientTransHandlerFactory {
    43  	return &cliTransHandlerFactory{}
    44  }
    45  
    46  // NewTransHandler implements the remote.ClientTransHandlerFactory interface.
    47  func (f *cliTransHandlerFactory) NewTransHandler(opt *remote.ClientOption) (remote.ClientTransHandler, error) {
    48  	if _, ok := opt.ConnPool.(*MuxPool); !ok {
    49  		return nil, fmt.Errorf("ConnPool[%T] invalid, netpoll mux just support MuxPool", opt.ConnPool)
    50  	}
    51  	return newCliTransHandler(opt)
    52  }
    53  
    54  func newCliTransHandler(opt *remote.ClientOption) (*cliTransHandler, error) {
    55  	return &cliTransHandler{
    56  		opt:   opt,
    57  		codec: opt.Codec,
    58  	}, nil
    59  }
    60  
    61  var _ remote.ClientTransHandler = &cliTransHandler{}
    62  
    63  type cliTransHandler struct {
    64  	opt       *remote.ClientOption
    65  	codec     remote.Codec
    66  	transPipe *remote.TransPipeline
    67  }
    68  
    69  // Write implements the remote.ClientTransHandler interface.
    70  func (t *cliTransHandler) Write(ctx context.Context, conn net.Conn, sendMsg remote.Message) (nctx context.Context, err error) {
    71  	ri := sendMsg.RPCInfo()
    72  	rpcinfo.Record(ctx, ri, stats.WriteStart, nil)
    73  	buf := netpoll.NewLinkBuffer()
    74  	bufWriter := np.NewWriterByteBuffer(buf)
    75  	defer func() {
    76  		if err != nil {
    77  			buf.Close()
    78  			bufWriter.Release(err)
    79  		}
    80  		rpcinfo.Record(ctx, ri, stats.WriteFinish, nil)
    81  	}()
    82  
    83  	// Set header flag = 1
    84  	tags := sendMsg.Tags()
    85  	if tags == nil {
    86  		tags = make(map[string]interface{})
    87  	}
    88  	tags[codec.HeaderFlagsKey] = codec.HeaderFlagSupportOutOfOrder
    89  
    90  	// encode
    91  	sendMsg.SetPayloadCodec(t.opt.PayloadCodec)
    92  	err = t.codec.Encode(ctx, sendMsg, bufWriter)
    93  	if err != nil {
    94  		return ctx, err
    95  	}
    96  
    97  	mc, _ := conn.(*muxCliConn)
    98  
    99  	// if oneway
   100  	var methodInfo serviceinfo.MethodInfo
   101  	if methodInfo, err = trans.GetMethodInfo(ri, sendMsg.ServiceInfo()); err != nil {
   102  		return ctx, err
   103  	}
   104  	if methodInfo.OneWay() {
   105  		mc.Put(func() (_ netpoll.Writer, isNil bool) {
   106  			return buf, false
   107  		})
   108  		return ctx, nil
   109  	}
   110  
   111  	// add notify
   112  	seqID := ri.Invocation().SeqID()
   113  	callback := newAsyncCallback(buf, bufWriter)
   114  	mc.seqIDMap.store(seqID, callback)
   115  	mc.Put(callback.getter)
   116  	return ctx, err
   117  }
   118  
   119  // Read implements the remote.ClientTransHandler interface.
   120  func (t *cliTransHandler) Read(ctx context.Context, conn net.Conn, msg remote.Message) (nctx context.Context, err error) {
   121  	ri := msg.RPCInfo()
   122  	mc, _ := conn.(*muxCliConn)
   123  	seqID := ri.Invocation().SeqID()
   124  	// load & delete before return
   125  	event, _ := mc.seqIDMap.load(seqID)
   126  	defer mc.seqIDMap.delete(seqID)
   127  
   128  	callback, _ := event.(*asyncCallback)
   129  	defer callback.Close()
   130  
   131  	readTimeout := trans.GetReadTimeout(ri.Config())
   132  	if readTimeout > 0 {
   133  		var cancel context.CancelFunc
   134  		ctx, cancel = context.WithTimeout(ctx, readTimeout)
   135  		defer cancel()
   136  	}
   137  	select {
   138  	case <-ctx.Done():
   139  		// timeout
   140  		return ctx, fmt.Errorf("recv wait timeout %s, seqID=%d", readTimeout, seqID)
   141  	case bufReader := <-callback.notifyChan:
   142  		// recv
   143  		if bufReader == nil {
   144  			return ctx, ErrConnClosed
   145  		}
   146  		rpcinfo.Record(ctx, ri, stats.ReadStart, nil)
   147  		msg.SetPayloadCodec(t.opt.PayloadCodec)
   148  		err := t.codec.Decode(ctx, msg, bufReader)
   149  		if err != nil && errors.Is(err, netpoll.ErrReadTimeout) {
   150  			err = kerrors.ErrRPCTimeout.WithCause(err)
   151  		}
   152  		if l := bufReader.ReadableLen(); l > 0 {
   153  			bufReader.Skip(l)
   154  		}
   155  		bufReader.Release(nil)
   156  		rpcinfo.Record(ctx, ri, stats.ReadFinish, err)
   157  		return ctx, err
   158  	}
   159  }
   160  
   161  // OnMessage implements the remote.ClientTransHandler interface.
   162  func (t *cliTransHandler) OnMessage(ctx context.Context, args, result remote.Message) (context.Context, error) {
   163  	// do nothing
   164  	return ctx, nil
   165  }
   166  
   167  // OnInactive implements the remote.ClientTransHandler interface.
   168  func (t *cliTransHandler) OnInactive(ctx context.Context, conn net.Conn) {
   169  	// ineffective now and do nothing
   170  }
   171  
   172  // OnError implements the remote.ClientTransHandler interface.
   173  func (t *cliTransHandler) OnError(ctx context.Context, err error, conn net.Conn) {
   174  	if pe, ok := err.(*kerrors.DetailedError); ok {
   175  		klog.CtxErrorf(ctx, "KITEX: send request error, remote=%s, error=%s\nstack=%s", conn.RemoteAddr(), err.Error(), pe.Stack())
   176  	} else {
   177  		klog.CtxErrorf(ctx, "KITEX: send request error, remote=%s, error=%s", conn.RemoteAddr(), err.Error())
   178  	}
   179  }
   180  
   181  // SetPipeline implements the remote.ClientTransHandler interface.
   182  func (t *cliTransHandler) SetPipeline(p *remote.TransPipeline) {
   183  	t.transPipe = p
   184  }
   185  
   186  type asyncCallback struct {
   187  	wbuf       *netpoll.LinkBuffer
   188  	bufWriter  remote.ByteBuffer
   189  	notifyChan chan remote.ByteBuffer // notify recv reader
   190  	closed     int32                  // 1 is closed, 2 means wbuf has been flush
   191  }
   192  
   193  func newAsyncCallback(wbuf *netpoll.LinkBuffer, bufWriter remote.ByteBuffer) *asyncCallback {
   194  	return &asyncCallback{
   195  		wbuf:       wbuf,
   196  		bufWriter:  bufWriter,
   197  		notifyChan: make(chan remote.ByteBuffer, 1),
   198  	}
   199  }
   200  
   201  // Recv is called when receive a message.
   202  func (c *asyncCallback) Recv(bufReader remote.ByteBuffer, err error) error {
   203  	c.notify(bufReader)
   204  	return nil
   205  }
   206  
   207  // Close is used to close the mux connection.
   208  func (c *asyncCallback) Close() error {
   209  	if atomic.CompareAndSwapInt32(&c.closed, 0, 1) {
   210  		c.wbuf.Close()
   211  		c.bufWriter.Release(nil)
   212  	}
   213  	return nil
   214  }
   215  
   216  func (c *asyncCallback) notify(bufReader remote.ByteBuffer) {
   217  	select {
   218  	case c.notifyChan <- bufReader:
   219  	default:
   220  		if bufReader != nil {
   221  			bufReader.Release(nil)
   222  		}
   223  	}
   224  }
   225  
   226  func (c *asyncCallback) getter() (w netpoll.Writer, isNil bool) {
   227  	if atomic.CompareAndSwapInt32(&c.closed, 0, 2) {
   228  		return c.wbuf, false
   229  	}
   230  	return nil, true
   231  }