github.com/rolandhe/saber@v0.0.4/nfour/duplex/trans.go (about)

// Network framework built on TCP, which is layer 4 (the transport layer) of the OSI model.
     2  // Copyright 2023 The saber Authors. All rights reserved.
     3  
     4  package duplex
     5  
     6  import (
     7  	"errors"
     8  	"github.com/rolandhe/saber/gocc"
     9  	"github.com/rolandhe/saber/nfour"
    10  	"github.com/rolandhe/saber/utils/bytutil"
    11  	"net"
    12  	"sync"
    13  	"sync/atomic"
    14  	"time"
    15  )
    16  
    17  var (
    18  	// ErrTaskTimeout 请求执行超时异常
    19  	ErrTaskTimeout = errors.New("task execute timeout")
    20  	// ErrTransShutdown Trans 客户端已经被关闭
    21  	ErrTransShutdown = errors.New("transport shut down")
    22  )
    23  
// TransConf is the configuration of a Trans client.
// Build it with NewTransConf: the unexported concurrent semaphore is set there, and NewTrans,
// SendPayload and asyncReader all dereference it.
type TransConf struct {
	// ReadTimeout is the network read timeout. asyncReader uses it as the deadline for reading a
	// response payload, and SendPayload uses it as the default time to wait for a response.
	ReadTimeout time.Duration
	// WriteTimeout is the network write timeout; SendPayload uses it as the default per-request
	// write timeout handed to the sender goroutine.
	WriteTimeout time.Duration

	// IdleTimeout is how long the connection may go without activity. asyncSender only logs when
	// no request has been queued for this long. asyncReader uses it as the read deadline while
	// waiting for the next response header.
	// NOTE(review): the original comment said that exceeding it only produces a log and never
	// interrupts the connection, but asyncReader calls Shutdown on any header-read error; whether
	// a deadline expiry counts as an error depends on nfour.InternalReadPayload — confirm.
	IdleTimeout time.Duration
	// concurrent bounds the number of in-flight requests. Every Trans created from the same
	// *TransConf shares this one semaphore.
	concurrent  gocc.Semaphore
}
    35  
// ReqTimeout holds per-request timeout overrides for SendPayload.
// A nil *ReqTimeout, or a non-positive ReadTimeout/WriteTimeout, falls back to the
// corresponding TransConf value.
type ReqTimeout struct {
	// ReadTimeout is how long SendPayload waits for the response.
	ReadTimeout time.Duration
	// WriteTimeout is the network write timeout handed to the sender for this request.
	WriteTimeout time.Duration
	// WaitConcurrent is how long to wait for a concurrency token when the maximum number of
	// in-flight requests has been reached; it is passed to the semaphore's AcquireTimeout.
	// NOTE(review): the meaning of zero (fail immediately vs. wait) is defined by gocc — confirm.
	WaitConcurrent time.Duration
}
    45  
    46  // NewTransConf 构建客户端的配置
    47  // rwTimeout 读写超时,这种情况下,读写超时是相同的
    48  func NewTransConf(rwTimeout time.Duration, concurrent uint) *TransConf {
    49  	return &TransConf{
    50  		ReadTimeout:  rwTimeout,
    51  		WriteTimeout: rwTimeout,
    52  		IdleTimeout:  time.Minute * 30,
    53  		concurrent:   gocc.NewDefaultSemaphore(concurrent),
    54  	}
    55  }
    56  
    57  // NewTrans 构建客户端 Trans
    58  // name 表示该 Trans的名称,该名称会被输出到日志中,方便发现问题
    59  func NewTrans(addr string, conf *TransConf, name string) (*Trans, error) {
    60  	conn, err := net.Dial("tcp", addr)
    61  	if err != nil {
    62  		// handle error
    63  		nfour.NFourLogger.InfoLn(err)
    64  		return nil, err
    65  	}
    66  
    67  	t := &Trans{
    68  		conn:     conn,
    69  		conf:     conf,
    70  		sendCh:   make(chan *sendingTask, conf.concurrent.TotalTokens()),
    71  		shutDown: make(chan struct{}),
    72  		name:     name,
    73  	}
    74  
    75  	go asyncSender(t)
    76  	go asyncReader(t)
    77  
    78  	return t, nil
    79  }
    80  
// Trans is a multiplexing client. It owns one TCP connection and runs two goroutines:
// asyncSender writes requests and asyncReader reads responses. Concurrent SendPayload calls
// therefore share the connection, and responses are matched to requests by seqId.
// Callers send requests to the server through Trans and receive the responses.
type Trans struct {
	// conn is the underlying connection; asyncSender is the goroutine that closes it.
	conn     net.Conn
	// conf supplies timeouts and the concurrency semaphore (possibly shared with other Trans).
	conf     *TransConf
	// sendCh queues requests for asyncSender; its capacity is conf.concurrent.TotalTokens().
	sendCh   chan *sendingTask
	// shutDown is closed exactly once, by Shutdown, to signal both goroutines.
	shutDown chan struct{}
	// status is 0 while running and 1 after Shutdown; accessed only atomically.
	status   int32
	// cache maps a request seqId (uint64) to the *future awaiting its response.
	cache    sync.Map
	// idGen generates request seqIds; the first one issued is 1.
	idGen    atomic.Uint64
	// name identifies this Trans in log messages.
	name     string
}
    93  
    94  // Shutdown 关闭Trans
    95  // source 发起Shutdown的场景,用于日志记录
    96  func (t *Trans) Shutdown(source string) {
    97  	if atomic.CompareAndSwapInt32(&t.status, 0, 1) {
    98  		nfour.NFourLogger.Info("%s trigger %s shutdown\n", source, t.name)
    99  		close(t.shutDown)
   100  	}
   101  }
   102  
   103  // IsShutdown Trans是否已经被关闭,如果已经被关闭,将不能接收新的发送请求
   104  func (t *Trans) IsShutdown() bool {
   105  	return atomic.LoadInt32(&t.status) == 1
   106  }
   107  
   108  // SendPayload 发送二进制请求
   109  // reqTimeout 本次请求的超时时间
   110  func (t *Trans) SendPayload(req []byte, reqTimeout *ReqTimeout) ([]byte, error) {
   111  	if t.IsShutdown() {
   112  		return nil, ErrTransShutdown
   113  	}
   114  	if reqTimeout == nil {
   115  		reqTimeout = &ReqTimeout{}
   116  	}
   117  	if !t.conf.concurrent.AcquireTimeout(reqTimeout.WaitConcurrent) {
   118  		return nil, nfour.ExceedConcurrentError
   119  	}
   120  	if reqTimeout.WriteTimeout <= 0 {
   121  		reqTimeout.WriteTimeout = t.conf.WriteTimeout
   122  	}
   123  	if reqTimeout.ReadTimeout <= 0 {
   124  		reqTimeout.ReadTimeout = t.conf.ReadTimeout
   125  	}
   126  	if t.IsShutdown() {
   127  		return nil, ErrTransShutdown
   128  	}
   129  	seqId := t.idGen.Add(1)
   130  	fu := &future{
   131  		seqId:    seqId,
   132  		notifier: make(chan struct{}),
   133  	}
   134  	t.cache.Store(seqId, fu)
   135  	t.sendCh <- &sendingTask{
   136  		seqId:   seqId,
   137  		payload: req,
   138  		timeout: reqTimeout.WriteTimeout,
   139  		f:       fu,
   140  	}
   141  	return fu.get(reqTimeout.ReadTimeout)
   142  }
   143  
   144  // asyncSender/asyncReader以及外部都可以调用Shutdown发送关闭指令
   145  // 但由sender 最终来关闭连接
   146  // asyncSender识别到连接关闭指令后消除等待结果的任务
   147  func asyncSender(trans *Trans) {
   148  	releaseWait := false
   149  	for {
   150  		select {
   151  		case task := <-trans.sendCh:
   152  			if !writeCore(task.payload, task.seqId, trans.conn, task.timeout) {
   153  				nfour.NFourLogger.Info("%s write err,will shutdown\n", trans.name)
   154  				trans.Shutdown("sender")
   155  				releaseWait = true
   156  				break
   157  			}
   158  			nfour.NFourLogger.Debug("%s send success\n", trans.name)
   159  		case <-trans.shutDown:
   160  			trans.conn.Close()
   161  			releaseWait = true
   162  			nfour.NFourLogger.Info("%s get shut down event,shut down\n", trans.name)
   163  			break
   164  		case <-time.After(trans.conf.IdleTimeout):
   165  			nfour.NFourLogger.Info("%s wait send task timeout\n", trans.name)
   166  		}
   167  		if releaseWait {
   168  			break
   169  		}
   170  	}
   171  	if releaseWait {
   172  		nfour.NFourLogger.Info("%s send release not sent task\n", trans.name)
   173  		releaseCount := 0
   174  		for {
   175  			select {
   176  			case task := <-trans.sendCh:
   177  				task.f.accept(nil, ErrTransShutdown)
   178  				releaseCount++
   179  			default:
   180  				nfour.NFourLogger.Info("%s send release not sent task:%d\n", trans.name, releaseCount)
   181  				return
   182  			}
   183  		}
   184  
   185  	}
   186  }
   187  
   188  func asyncReader(trans *Trans) {
   189  	fullHeaderLength := nfour.PayLoadLenBufLength + seqIdHeaderLength
   190  	header := make([]byte, fullHeaderLength)
   191  	for {
   192  		if trans.IsShutdown() {
   193  			break
   194  		}
   195  		trans.conn.SetReadDeadline(time.Now().Add(trans.conf.IdleTimeout))
   196  		if err := nfour.InternalReadPayload(trans.conn, header, fullHeaderLength, true); err != nil {
   197  			nfour.NFourLogger.Info("%s read header error:%v\n", trans.name, err)
   198  			trans.Shutdown("reader")
   199  			break
   200  		}
   201  		l, _ := bytutil.ToInt32(header[:nfour.PayLoadLenBufLength])
   202  		bodyBuff := make([]byte, l, l)
   203  		seqId, err := bytutil.ToUint64(header[nfour.PayLoadLenBufLength:])
   204  		trans.conn.SetReadDeadline(time.Now().Add(trans.conf.ReadTimeout))
   205  		err = nfour.InternalReadPayload(trans.conn, bodyBuff, int(l), false)
   206  		if err != nil {
   207  			nfour.NFourLogger.Info("%s read payload error:%v,need %d bytes\n", trans.name, err, l)
   208  			trans.Shutdown("reader")
   209  			break
   210  		}
   211  		f, ok := trans.cache.Load(seqId)
   212  		if !ok {
   213  			nfour.NFourLogger.Info("warning: %s lost seqId:%d with read result\n", trans.name, seqId)
   214  			continue
   215  		}
   216  		if trans.IsShutdown() {
   217  			break
   218  		}
   219  		trans.cache.Delete(seqId)
   220  		fu := f.(*future)
   221  		fu.accept(bodyBuff, err)
   222  		trans.conf.concurrent.Release()
   223  	}
   224  	nfour.NFourLogger.Info("%s async reader release futures\n", trans.name)
   225  	releasedCount := 0
   226  
   227  	trans.cache.Range(func(key, value any) bool {
   228  		fu := value.(*future)
   229  		fu.accept(nil, ErrTransShutdown)
   230  		releasedCount++
   231  		return true
   232  	})
   233  	nfour.NFourLogger.Info("%s async reader release futures:%d\n", trans.name, releasedCount)
   234  }
   235  
// sendingTask is one queued request, handed from SendPayload to asyncSender through Trans.sendCh.
type sendingTask struct {
	// seqId identifies the request; the response is matched back to it through Trans.cache.
	seqId   uint64
	// payload is the request body passed to writeCore.
	payload []byte
	// timeout is the write timeout passed to writeCore.
	timeout time.Duration
	// f is the future the caller waits on. asyncSender completes it with ErrTransShutdown if the
	// task is still queued when the sender stops.
	f       *future
}
   242  
// future is a one-shot result slot for a single request.
type future struct {
	// seqId is the sequence id of the request this future belongs to.
	seqId    uint64
	// notifier is closed by accept once value and err have been set.
	notifier chan struct{}
	// value and err hold the result; they are valid only after notifier is closed.
	value    []byte
	err      error
	// flag ensures only the first accept call takes effect.
	flag     atomic.Bool
}
   250  
   251  func (f *future) get(timeout time.Duration) ([]byte, error) {
   252  	select {
   253  	case <-f.notifier:
   254  		return f.value, f.err
   255  	case <-time.After(timeout):
   256  		return nil, ErrTaskTimeout
   257  	}
   258  }
   259  
   260  func (f *future) accept(v []byte, err error) {
   261  	if f.flag.CompareAndSwap(false, true) {
   262  		f.value = v
   263  		f.err = err
   264  		close(f.notifier)
   265  	}
   266  }