github.com/songzhibin97/gkit@v1.2.13/net/tcp/conn.go (about)

     1  package tcp
     2  
     3  import (
     4  	"bufio"
     5  	"crypto/tls"
     6  	"errors"
     7  	"io"
     8  	"net"
     9  	"time"
    10  
    11  	"github.com/songzhibin97/gkit/cache/buffer"
    12  )
    13  
    14  var defaultRetry Retry
    15  
    16  // Conn 封装原始 net.conn 对象
    17  type Conn struct {
    18  	// net.conn: 原始的conn对象
    19  	net.Conn
    20  
    21  	// reader: 用于读取conn缓冲区
    22  	reader *bufio.Reader
    23  
    24  	// sendTimeout: 发送超时时间
    25  	sendTimeout time.Time
    26  
    27  	// recvTimeout: 接受超时时间
    28  	recvTimeout time.Time
    29  
    30  	// recvBufferInterval: 读取缓存间隔时间
    31  	recvBufferInterval time.Duration
    32  }
    33  
    34  // Retry 重试配置
    35  type Retry struct {
    36  	// Count: 重试次数,每重试一次就会 -1, 如果==0默认不重试
    37  	Count uint
    38  
    39  	// Interval: 重试间隔
    40  	Interval time.Duration
    41  }
    42  
    43  // Send 发送数据至对端,有重试机制
    44  func (c *Conn) Send(data []byte, retry *Retry) error {
    45  	if retry == nil {
    46  		retry = &defaultRetry
    47  	}
    48  	for {
    49  		_, err := c.Write(data)
    50  		switch {
    51  		case err != nil && errors.Is(err, io.EOF):
    52  			// EOF 处理
    53  			return nil
    54  		case err != nil && retry.Count > 0:
    55  			// 触发重试
    56  			retry.Count--
    57  			if retry.Interval == 0 {
    58  				retry.Interval = DefaultRetryInterval
    59  			}
    60  			time.Sleep(retry.Interval)
    61  		default:
    62  			return err
    63  		}
    64  	}
    65  }
    66  
    67  // Recv 接受数据
    68  // length == 0 从 Conn一次读取立即返回
    69  // length < 0 从 Conn 接收所有数据,并将其返回,直到没有数据
    70  // length > 0 从 Conn 接收到对应的数据返回
    71  func (c *Conn) Recv(length int, retry *Retry) ([]byte, error) {
    72  	if retry == nil {
    73  		retry = &defaultRetry
    74  	}
    75  	var (
    76  		// err: error
    77  		err error
    78  
    79  		// size: 返回一次读取的大小
    80  		size int
    81  
    82  		// index: 目前指向的索引的位置
    83  		index int
    84  
    85  		// bf: 读取后的缓冲区
    86  		bf []byte
    87  
    88  		// flag: 判断是否循环读
    89  		flag bool
    90  	)
    91  	if length > 0 {
    92  		// 读取指定的长度
    93  		bf = *buffer.GetBytes(length)
    94  	} else {
    95  		// 需要 eof 返回
    96  		bf = *buffer.GetBytes(DefaultReadBuffer)
    97  	}
    98  cycle:
    99  	for {
   100  		if length < 0 && index > 0 {
   101  			// length < 0 要接受所有的数据,直至EOF
   102  			flag = true
   103  			if err = c.SetReadDeadline(time.Now().Add(c.recvBufferInterval)); err != nil {
   104  				return nil, err
   105  			}
   106  		}
   107  		size, err = c.reader.Read(bf[index:])
   108  		if size > 0 {
   109  			index += size
   110  			if length > 0 {
   111  				if index == length {
   112  					break cycle
   113  				}
   114  			} else {
   115  				if index >= DefaultReadBuffer {
   116  					bf = append(bf, make([]byte, DefaultReadBuffer)...)
   117  				} else if !flag {
   118  					break cycle
   119  				}
   120  			}
   121  		}
   122  		if err != nil {
   123  			switch {
   124  			case errors.Is(err, io.EOF):
   125  				break cycle
   126  			case flag && isTimeout(err):
   127  				if err = c.SetReadDeadline(time.Now().Add(c.recvBufferInterval)); err != nil {
   128  					return nil, err
   129  				}
   130  				break cycle
   131  			case retry.Count > 0:
   132  				// 触发重试
   133  				retry.Count--
   134  				if retry.Interval == 0 {
   135  					retry.Interval = DefaultRetryInterval
   136  				}
   137  				time.Sleep(retry.Interval)
   138  				goto cycle
   139  			default:
   140  				return nil, err
   141  			}
   142  		}
   143  		if length == 0 {
   144  			break cycle
   145  		}
   146  	}
   147  	return bf[:index], nil
   148  }
   149  
   150  // RecvLine 读取一行 '\n'
   151  func (c *Conn) RecvLine(retry *Retry) ([]byte, error) {
   152  	var (
   153  		// err
   154  		err error
   155  
   156  		data []byte
   157  
   158  		index int
   159  
   160  		bf = (*buffer.GetBytes(1024))[:0]
   161  	)
   162  	for {
   163  		data, err = c.Recv(1, retry)
   164  		if err != nil || data[0] == '\n' {
   165  			break
   166  		}
   167  		index++
   168  		bf = append(bf, data...)
   169  	}
   170  	return bf[:index], err
   171  }
   172  
   173  // RecvWithTimeout 读取已经超时的链接
   174  func (c *Conn) RecvWithTimeout(length int, timeout time.Duration, retry *Retry) ([]byte, error) {
   175  	if err := c.SetRecvDeadline(time.Now().Add(timeout)); err != nil {
   176  		return nil, err
   177  	}
   178  	defer c.SetRecvDeadline(time.Time{})
   179  	return c.Recv(length, retry)
   180  }
   181  
   182  // SendWithTimeout 写入数据给已经超时的链接
   183  func (c *Conn) SendWithTimeout(data []byte, timeout time.Duration, retry *Retry) error {
   184  	if err := c.SetSendDeadline(time.Now().Add(timeout)); err != nil {
   185  		return err
   186  	}
   187  	defer c.SetSendDeadline(time.Time{})
   188  	return c.Send(data, retry)
   189  }
   190  
   191  // SendRecv 写入数据并读取返回
   192  func (c *Conn) SendRecv(data []byte, length int, retry *Retry) ([]byte, error) {
   193  	if err := c.Send(data, retry); err != nil {
   194  		return nil, err
   195  	}
   196  	return c.Recv(length, retry)
   197  }
   198  
   199  // SendRecvWithTimeout 将数据写入并读出已经超时的链接
   200  func (c *Conn) SendRecvWithTimeout(data []byte, timeout time.Duration, length int, retry *Retry) ([]byte, error) {
   201  	if err := c.Send(data, retry); err != nil {
   202  		return nil, err
   203  	}
   204  	return c.RecvWithTimeout(length, timeout, retry)
   205  }
   206  
   207  func (c *Conn) SetDeadline(t time.Time) error {
   208  	err := c.Conn.SetDeadline(t)
   209  	if err == nil {
   210  		c.recvTimeout = t
   211  		c.sendTimeout = t
   212  	}
   213  	return err
   214  }
   215  
   216  func (c *Conn) SetRecvDeadline(t time.Time) error {
   217  	err := c.SetReadDeadline(t)
   218  	if err == nil {
   219  		c.sendTimeout = t
   220  	}
   221  	return err
   222  }
   223  
   224  func (c *Conn) SetSendDeadline(t time.Time) error {
   225  	err := c.SetWriteDeadline(t)
   226  	if err == nil {
   227  		c.sendTimeout = t
   228  	}
   229  	return err
   230  }
   231  
   232  // isTimeout: 判断是否是超时的error错误
   233  func isTimeout(err error) bool {
   234  	if err == nil {
   235  		return false
   236  	}
   237  	if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
   238  		return true
   239  	}
   240  	return false
   241  }
   242  
   243  // RecoveryBuffer 用于回收已经不使用的 *[]byte
   244  // 如果使用已经回收的资源,可能会造成panic,请注意
   245  func RecoveryBuffer(data *[]byte) {
   246  	buffer.PutBytes(data)
   247  }
   248  
   249  // SetRecvBufferInterval 读取缓存间隔时间
   250  func (c *Conn) SetRecvBufferInterval(t time.Duration) {
   251  	c.recvBufferInterval = t
   252  }
   253  
   254  // NewConnByNetConn 通过原始的 net.Conn 链接建立 Conn 封装对象
   255  func NewConnByNetConn(conn net.Conn) *Conn {
   256  	return &Conn{
   257  		Conn:               conn,
   258  		reader:             bufio.NewReader(conn),
   259  		sendTimeout:        time.Time{},
   260  		recvTimeout:        time.Time{},
   261  		recvBufferInterval: DefaultWaitTimeout,
   262  	}
   263  }
   264  
   265  // newNetConn 新建conn
   266  func newNetConn(addr string, timeout *time.Duration) (net.Conn, error) {
   267  	if timeout == nil {
   268  		timeout = &DefaultConnTimeout
   269  	}
   270  	return net.DialTimeout("tcp", addr, *timeout)
   271  }
   272  
   273  // newNetConnTLS
   274  func newNetConnTLS(addr string, tlsConfig *tls.Config, timeout *time.Duration) (net.Conn, error) {
   275  	if timeout == nil {
   276  		timeout = &DefaultConnTimeout
   277  	}
   278  	dialer := &net.Dialer{Timeout: *timeout}
   279  	return tls.DialWithDialer(dialer, "tcp", addr, tlsConfig)
   280  }
   281  
   282  // NewConn 通过原始拨号建立
   283  func NewConn(addr string, timeout *time.Duration) (*Conn, error) {
   284  	if conn, err := newNetConn(addr, timeout); err != nil {
   285  		return nil, err
   286  	} else {
   287  		return NewConnByNetConn(conn), nil
   288  	}
   289  }
   290  
   291  // NewConnTLS 通过tls建立
   292  func NewConnTLS(addr string, tlsConfig *tls.Config, timeout *time.Duration) (*Conn, error) {
   293  	if conn, err := newNetConnTLS(addr, tlsConfig, timeout); err != nil {
   294  		return nil, err
   295  	} else {
   296  		return NewConnByNetConn(conn), nil
   297  	}
   298  }