github.com/Mrs4s/MiraiGo@v0.0.0-20240226124653-54bdd873e3fe/client/internal/network/conn.go (about)

     1  package network
     2  
     3  import (
     4  	"encoding/binary"
     5  	"io"
     6  	"net"
     7  	"sync"
     8  
     9  	"github.com/pkg/errors"
    10  )
    11  
    12  type TCPClient struct {
    13  	lock                 sync.RWMutex
    14  	conn                 net.Conn
    15  	connected            bool
    16  	plannedDisconnect    func(*TCPClient)
    17  	unexpectedDisconnect func(*TCPClient, error)
    18  }
    19  
    20  var ErrConnectionClosed = errors.New("connection closed")
    21  
    22  // PlannedDisconnect 预料中的断开连接
    23  // 如调用 Close() Connect()
    24  func (t *TCPClient) PlannedDisconnect(f func(*TCPClient)) {
    25  	t.lock.Lock()
    26  	defer t.lock.Unlock()
    27  	t.plannedDisconnect = f
    28  }
    29  
    30  // UnexpectedDisconnect 未预料的断开连接
    31  func (t *TCPClient) UnexpectedDisconnect(f func(*TCPClient, error)) {
    32  	t.lock.Lock()
    33  	defer t.lock.Unlock()
    34  	t.unexpectedDisconnect = f
    35  }
    36  
    37  func (t *TCPClient) Connect(addr string) error {
    38  	t.Close()
    39  	conn, err := net.Dial("tcp", addr)
    40  	if err != nil {
    41  		return errors.Wrap(err, "dial tcp error")
    42  	}
    43  	t.lock.Lock()
    44  	defer t.lock.Unlock()
    45  	t.conn = conn
    46  	t.connected = true
    47  	return nil
    48  }
    49  
    50  func (t *TCPClient) Write(buf []byte) error {
    51  	if conn := t.getConn(); conn != nil {
    52  		_, err := conn.Write(buf)
    53  		if err != nil {
    54  			t.unexpectedClose(err)
    55  			return ErrConnectionClosed
    56  		}
    57  		return nil
    58  	}
    59  
    60  	return ErrConnectionClosed
    61  }
    62  
    63  func (t *TCPClient) ReadBytes(len int) ([]byte, error) {
    64  	buf := make([]byte, len)
    65  	if conn := t.getConn(); conn != nil {
    66  		_, err := io.ReadFull(conn, buf)
    67  		if err != nil {
    68  			// time.Sleep(time.Millisecond * 100) // 服务器会发送offline包后立即断开连接, 此时还没解析, 可能还是得加锁
    69  			t.unexpectedClose(err)
    70  			return nil, ErrConnectionClosed
    71  		}
    72  		return buf, nil
    73  	}
    74  
    75  	return nil, ErrConnectionClosed
    76  }
    77  
    78  func (t *TCPClient) ReadInt32() (int32, error) {
    79  	b, err := t.ReadBytes(4)
    80  	if err != nil {
    81  		return 0, err
    82  	}
    83  	return int32(binary.BigEndian.Uint32(b)), nil
    84  }
    85  
    86  func (t *TCPClient) Close() {
    87  	t.close()
    88  	t.invokePlannedDisconnect()
    89  }
    90  
    91  func (t *TCPClient) unexpectedClose(err error) {
    92  	t.close()
    93  	t.invokeUnexpectedDisconnect(err)
    94  }
    95  
    96  func (t *TCPClient) close() {
    97  	t.lock.Lock()
    98  	defer t.lock.Unlock()
    99  	if t.conn != nil {
   100  		_ = t.conn.Close()
   101  		t.conn = nil
   102  	}
   103  }
   104  
   105  func (t *TCPClient) invokePlannedDisconnect() {
   106  	t.lock.RLock()
   107  	defer t.lock.RUnlock()
   108  	if t.plannedDisconnect != nil && t.connected {
   109  		go t.plannedDisconnect(t)
   110  		t.connected = false
   111  	}
   112  }
   113  
   114  func (t *TCPClient) invokeUnexpectedDisconnect(err error) {
   115  	t.lock.RLock()
   116  	defer t.lock.RUnlock()
   117  	if t.unexpectedDisconnect != nil && t.connected {
   118  		go t.unexpectedDisconnect(t, err)
   119  		t.connected = false
   120  	}
   121  }
   122  
   123  func (t *TCPClient) getConn() net.Conn {
   124  	t.lock.RLock()
   125  	defer t.lock.RUnlock()
   126  	return t.conn
   127  }