github.com/LagrangeDev/LagrangeGo@v0.0.0-20240512064304-ad4a85e10cb4/client/internal/network/conn.go (about)

     1  package network
     2  
     3  // from https://github.com/Mrs4s/MiraiGo/blob/master/client/internal/network/conn.go
     4  
     5  import (
     6  	"encoding/binary"
     7  	"io"
     8  	"net"
     9  	"sync"
    10  
    11  	"github.com/pkg/errors"
    12  )
    13  
    14  type TCPClient struct {
    15  	lock                 sync.RWMutex
    16  	conn                 net.Conn
    17  	connected            bool
    18  	plannedDisconnect    func(*TCPClient)
    19  	unexpectedDisconnect func(*TCPClient, error)
    20  }
    21  
    22  var ErrConnectionClosed = errors.New("connection closed")
    23  
    24  // PlannedDisconnect 预料中的断开连接
    25  // 如调用 Close() Connect()
    26  func (t *TCPClient) PlannedDisconnect(f func(*TCPClient)) {
    27  	t.lock.Lock()
    28  	defer t.lock.Unlock()
    29  	t.plannedDisconnect = f
    30  }
    31  
    32  // UnexpectedDisconnect 未预料的断开连接
    33  func (t *TCPClient) UnexpectedDisconnect(f func(*TCPClient, error)) {
    34  	t.lock.Lock()
    35  	defer t.lock.Unlock()
    36  	t.unexpectedDisconnect = f
    37  }
    38  
    39  func (t *TCPClient) Connect(addr string) error {
    40  	t.Close()
    41  	conn, err := net.Dial("tcp", addr)
    42  	if err != nil {
    43  		return errors.Wrap(err, "dial tcp error")
    44  	}
    45  	t.lock.Lock()
    46  	defer t.lock.Unlock()
    47  	t.conn = conn
    48  	t.connected = true
    49  	return nil
    50  }
    51  
    52  func (t *TCPClient) Write(buf []byte) error {
    53  	if conn := t.getConn(); conn != nil {
    54  		_, err := conn.Write(buf)
    55  		if err != nil {
    56  			t.unexpectedClose(err)
    57  			return ErrConnectionClosed
    58  		}
    59  		return nil
    60  	}
    61  
    62  	return ErrConnectionClosed
    63  }
    64  
    65  func (t *TCPClient) ReadBytes(len int) ([]byte, error) {
    66  	buf := make([]byte, len)
    67  	if conn := t.getConn(); conn != nil {
    68  		_, err := io.ReadFull(conn, buf)
    69  		if err != nil {
    70  			// time.Sleep(time.Millisecond * 100) // 服务器会发送offline包后立即断开连接, 此时还没解析, 可能还是得加锁
    71  			t.unexpectedClose(err)
    72  			return nil, ErrConnectionClosed
    73  		}
    74  		return buf, nil
    75  	}
    76  
    77  	return nil, ErrConnectionClosed
    78  }
    79  
    80  func (t *TCPClient) ReadInt32() (int32, error) {
    81  	b, err := t.ReadBytes(4)
    82  	if err != nil {
    83  		return 0, err
    84  	}
    85  	return int32(binary.BigEndian.Uint32(b)), nil
    86  }
    87  
    88  func (t *TCPClient) Close() {
    89  	t.close()
    90  	t.invokePlannedDisconnect()
    91  }
    92  
    93  func (t *TCPClient) unexpectedClose(err error) {
    94  	t.close()
    95  	t.invokeUnexpectedDisconnect(err)
    96  }
    97  
    98  func (t *TCPClient) close() {
    99  	t.lock.Lock()
   100  	defer t.lock.Unlock()
   101  	if t.conn != nil {
   102  		_ = t.conn.Close()
   103  		t.conn = nil
   104  	}
   105  }
   106  
   107  func (t *TCPClient) invokePlannedDisconnect() {
   108  	t.lock.RLock()
   109  	defer t.lock.RUnlock()
   110  	if t.plannedDisconnect != nil && t.connected {
   111  		go t.plannedDisconnect(t)
   112  		t.connected = false
   113  	}
   114  }
   115  
   116  func (t *TCPClient) invokeUnexpectedDisconnect(err error) {
   117  	t.lock.RLock()
   118  	defer t.lock.RUnlock()
   119  	if t.unexpectedDisconnect != nil && t.connected {
   120  		go t.unexpectedDisconnect(t, err)
   121  		t.connected = false
   122  	}
   123  }
   124  
   125  func (t *TCPClient) getConn() net.Conn {
   126  	t.lock.RLock()
   127  	defer t.lock.RUnlock()
   128  	return t.conn
   129  }