github.com/cloudwego/kitex@v0.9.0/pkg/remote/trans/netpollmux/mux_conn.go (about)

     1  /*
     2   * Copyright 2021 CloudWeGo Authors
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   *     http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   */
    16  
    17  package netpollmux
    18  
    19  import (
    20  	"context"
    21  	"errors"
    22  	"fmt"
    23  	"io"
    24  	"net"
    25  	"sync"
    26  
    27  	"github.com/cloudwego/netpoll"
    28  	"github.com/cloudwego/netpoll/mux"
    29  
    30  	"github.com/cloudwego/kitex/pkg/klog"
    31  	"github.com/cloudwego/kitex/pkg/remote"
    32  	"github.com/cloudwego/kitex/pkg/remote/codec"
    33  	np "github.com/cloudwego/kitex/pkg/remote/trans/netpoll"
    34  	"github.com/cloudwego/kitex/pkg/remote/transmeta"
    35  	"github.com/cloudwego/kitex/pkg/rpcinfo"
    36  )
    37  
    38  // ErrConnClosed .
    39  var ErrConnClosed = errors.New("conn closed")
    40  
    41  var defaultCodec = codec.NewDefaultCodec()
    42  
    43  func newMuxCliConn(connection netpoll.Connection) *muxCliConn {
    44  	c := &muxCliConn{
    45  		muxConn:  newMuxConn(connection),
    46  		seqIDMap: newShardMap(mux.ShardSize),
    47  	}
    48  	connection.SetOnRequest(c.OnRequest)
    49  	connection.AddCloseCallback(func(connection netpoll.Connection) error {
    50  		return c.forceClose()
    51  	})
    52  	return c
    53  }
    54  
    55  type muxCliConn struct {
    56  	muxConn
    57  	closing  bool      // whether the server is going to close this connection
    58  	seqIDMap *shardMap // (k,v) is (sequenceID, notify)
    59  }
    60  
    61  func (c *muxCliConn) IsActive() bool {
    62  	return !c.closing && c.muxConn.IsActive()
    63  }
    64  
    65  // OnRequest is called when the connection creates.
    66  func (c *muxCliConn) OnRequest(ctx context.Context, connection netpoll.Connection) (err error) {
    67  	// check protocol header
    68  	length, seqID, err := parseHeader(connection.Reader())
    69  	if err != nil {
    70  		err = fmt.Errorf("%w: addr(%s)", err, connection.RemoteAddr())
    71  		return c.onError(ctx, err, connection)
    72  	}
    73  	// reader is nil if return error
    74  	reader, err := connection.Reader().Slice(length)
    75  	if err != nil {
    76  		err = fmt.Errorf("mux read package slice failed: addr(%s), %w", connection.RemoteAddr(), err)
    77  		return c.onError(ctx, err, connection)
    78  	}
    79  	// seqId == 0 means a control frame.
    80  	if seqID == 0 {
    81  		iv := rpcinfo.NewInvocation("none", "none")
    82  		iv.SetSeqID(0)
    83  		ri := rpcinfo.NewRPCInfo(nil, nil, iv, nil, nil)
    84  		ctl := NewControlFrame()
    85  		msg := remote.NewMessage(ctl, nil, ri, remote.Reply, remote.Client)
    86  
    87  		bufReader := np.NewReaderByteBuffer(reader)
    88  		if err = defaultCodec.Decode(ctx, msg, bufReader); err != nil {
    89  			return
    90  		}
    91  
    92  		crrst := msg.TransInfo().TransStrInfo()[transmeta.HeaderConnectionReadyToReset]
    93  		if len(crrst) > 0 {
    94  			// the server is closing this connection
    95  			// in this case, let server close the real connection
    96  			// mux cli conn will mark itself is closing but will not close connection.
    97  			c.closing = true
    98  			reader.(io.Closer).Close()
    99  		}
   100  		return
   101  	}
   102  
   103  	// notify asyncCallback
   104  	callback, ok := c.seqIDMap.load(seqID)
   105  	if !ok {
   106  		reader.(io.Closer).Close()
   107  		return
   108  	}
   109  	bufReader := np.NewReaderByteBuffer(reader)
   110  	callback.Recv(bufReader, nil)
   111  	return nil
   112  }
   113  
   114  // Close does nothing.
   115  func (c *muxCliConn) Close() error {
   116  	return nil
   117  }
   118  
   119  func (c *muxCliConn) forceClose() error {
   120  	c.shardQueue.Close()
   121  	c.Connection.Close()
   122  	c.seqIDMap.rangeMap(func(seqID int32, msg EventHandler) {
   123  		msg.Recv(nil, ErrConnClosed)
   124  	})
   125  	return nil
   126  }
   127  
   128  func (c *muxCliConn) close() error {
   129  	if !c.closing {
   130  		return c.forceClose()
   131  	}
   132  	// if closing, let server close the connection
   133  	return nil
   134  }
   135  
   136  func (c *muxCliConn) onError(ctx context.Context, err error, connection netpoll.Connection) error {
   137  	klog.CtxErrorf(ctx, "KITEX: error=%s", err.Error())
   138  	connection.Close()
   139  	return err
   140  }
   141  
   142  func newMuxSvrConn(connection netpoll.Connection, pool *sync.Pool) *muxSvrConn {
   143  	c := &muxSvrConn{
   144  		muxConn: newMuxConn(connection),
   145  		pool:    pool,
   146  	}
   147  	return c
   148  }
   149  
   150  type muxSvrConn struct {
   151  	muxConn
   152  	pool *sync.Pool // pool of rpcInfo
   153  }
   154  
   155  func newMuxConn(connection netpoll.Connection) muxConn {
   156  	c := muxConn{}
   157  	c.Connection = connection
   158  	c.shardQueue = mux.NewShardQueue(mux.ShardSize, connection)
   159  	return c
   160  }
   161  
   162  var (
   163  	_ net.Conn           = &muxConn{}
   164  	_ netpoll.Connection = &muxConn{}
   165  )
   166  
   167  type muxConn struct {
   168  	netpoll.Connection                 // raw conn
   169  	shardQueue         *mux.ShardQueue // use for write
   170  }
   171  
   172  // Put puts the buffer getter back to the queue.
   173  func (c *muxConn) Put(gt mux.WriterGetter) {
   174  	c.shardQueue.Add(gt)
   175  }
   176  
   177  func (c *muxConn) GracefulShutdown() {
   178  	c.shardQueue.Close()
   179  }