github.com/Asutorufa/yuhaiin@v0.3.6-0.20240502055049-7984da7023a0/pkg/net/proxy/shadowsocksr/protocol/base.go (about)

     1  package protocol
     2  
     3  import (
     4  	"bytes"
     5  	crand "crypto/rand"
     6  	"fmt"
     7  	"math/rand/v2"
     8  	"net"
     9  	"strings"
    10  	"sync"
    11  	"sync/atomic"
    12  
    13  	"github.com/Asutorufa/yuhaiin/pkg/net/proxy/shadowsocksr/cipher"
    14  	"github.com/Asutorufa/yuhaiin/pkg/utils/pool"
    15  )
    16  
    17  type protocol interface {
    18  	EncryptStream(dst *bytes.Buffer, data []byte) error
    19  	DecryptStream(dst *bytes.Buffer, data []byte) (int, error)
    20  	EncryptPacket(data []byte) ([]byte, error)
    21  	DecryptPacket(data []byte) ([]byte, error)
    22  
    23  	GetOverhead() int
    24  }
    25  
    26  type errorProtocol struct{ error }
    27  
    28  func NewErrorProtocol(err error) protocol                                   { return &errorProtocol{err} }
    29  func (e *errorProtocol) EncryptStream(dst *bytes.Buffer, data []byte) error { return e.error }
    30  func (e *errorProtocol) DecryptStream(dst *bytes.Buffer, data []byte) (int, error) {
    31  	return 0, e.error
    32  }
    33  func (e *errorProtocol) EncryptPacket(data []byte) ([]byte, error) { return nil, e.error }
    34  func (e *errorProtocol) DecryptPacket(data []byte) ([]byte, error) { return nil, e.error }
    35  func (e *errorProtocol) GetOverhead() int                          { return 0 }
    36  
    37  type AuthData struct {
    38  	clientID     [4]byte
    39  	connectionID atomic.Uint32
    40  
    41  	mu sync.Mutex
    42  }
    43  
    44  func NewAuth() *AuthData { return &AuthData{} }
    45  
    46  func (a *AuthData) nextAuth() {
    47  	if a.connectionID.Load() <= 0xFF000000 && a.connectionID.Load() != 0 {
    48  		a.connectionID.Add(1)
    49  		return
    50  	}
    51  
    52  	a.mu.Lock()
    53  	defer a.mu.Unlock()
    54  	crand.Read(a.clientID[:])
    55  	a.connectionID.Store(rand.Uint32() & 0xFFFFFF)
    56  }
    57  
    58  type packetConn struct {
    59  	protocol protocol
    60  	net.PacketConn
    61  }
    62  
    63  func newPacketConn(conn net.PacketConn, p protocol) net.PacketConn { return &packetConn{p, conn} }
    64  
    65  func (c *packetConn) WriteTo(b []byte, addr net.Addr) (int, error) {
    66  	data, err := c.protocol.EncryptPacket(b)
    67  	if err != nil {
    68  		return 0, err
    69  	}
    70  	_, err = c.PacketConn.WriteTo(data, addr)
    71  	return len(b), err
    72  }
    73  
    74  func (c *packetConn) ReadFrom(b []byte) (int, net.Addr, error) {
    75  	n, addr, err := c.PacketConn.ReadFrom(b)
    76  	if err != nil {
    77  		return n, addr, err
    78  	}
    79  	decoded, err := c.protocol.DecryptPacket(b[:n])
    80  	if err != nil {
    81  		return n, addr, err
    82  	}
    83  	copy(b, decoded)
    84  	return len(decoded), addr, nil
    85  }
    86  
    87  func (c *packetConn) Close() error { return c.PacketConn.Close() }
    88  
    89  type conn struct {
    90  	protocol protocol
    91  	net.Conn
    92  
    93  	ciphertext, plaintext bytes.Buffer
    94  }
    95  
    96  func newConn(c net.Conn, p protocol) net.Conn {
    97  	return &conn{
    98  		Conn:     c,
    99  		protocol: p,
   100  	}
   101  }
   102  
   103  func (c *conn) Read(b []byte) (n int, err error) {
   104  	if c.plaintext.Len() > 0 {
   105  		return c.plaintext.Read(b)
   106  	}
   107  
   108  	n, err = c.Conn.Read(b)
   109  	if err != nil {
   110  		return 0, err
   111  	}
   112  
   113  	c.ciphertext.Write(b[:n])
   114  	length, err := c.protocol.DecryptStream(&c.plaintext, c.ciphertext.Bytes())
   115  	if err != nil {
   116  		c.ciphertext.Reset()
   117  		return 0, err
   118  	}
   119  	c.ciphertext.Next(length)
   120  
   121  	n, _ = c.plaintext.Read(b)
   122  	return n, nil
   123  }
   124  
   125  func (c *conn) Write(b []byte) (n int, err error) {
   126  	buf := pool.GetBuffer()
   127  	defer pool.PutBuffer(buf)
   128  
   129  	if err = c.protocol.EncryptStream(buf, b); err != nil {
   130  		return 0, err
   131  	}
   132  	if _, err = c.Conn.Write(buf.Bytes()); err != nil {
   133  		return 0, err
   134  	}
   135  	return len(b), nil
   136  }
   137  
   138  var ProtocolMethod = map[string]func(Protocol) protocol{
   139  	"auth_aes128_sha1": NewAuthAES128SHA1,
   140  	"auth_aes128_md5":  NewAuthAES128MD5,
   141  	"auth_chain_a":     NewAuthChainA,
   142  	"auth_chain_b":     NewAuthChainB,
   143  	"origin":           NewOrigin,
   144  	"auth_sha1_v4":     NewAuthSHA1v4,
   145  	"verify_sha1":      NewVerifySHA1,
   146  	"ota":              NewVerifySHA1,
   147  }
   148  
// Protocol is the configuration from which a protocol implementation is
// built (see ProtocolMethod); Stream and Packet wrap connections with it.
type Protocol struct {
	*cipher.Cipher

	// HeadSize is set by SetHeadLen from GetHeadSize.
	HeadSize     int
	TcpMss       int
	ObfsOverhead int
	// Name is looked up case-insensitively in ProtocolMethod.
	Name  string
	Param string
	// IV is overridden per connection by Stream (on a copy); Packet uses it as-is.
	IV []byte

	// Auth is handed to the built protocol with the rest of the configuration
	// (NOTE(review): presumably shared across connections — confirm).
	Auth *AuthData
}
   161  
   162  func (s Protocol) stream() (protocol, error) {
   163  	c, ok := ProtocolMethod[strings.ToLower(s.Name)]
   164  	if ok {
   165  		return c(s), nil
   166  	}
   167  	return nil, fmt.Errorf("protocol %s not found", s.Name)
   168  }
   169  
   170  func (s Protocol) Stream(c net.Conn, iv []byte) (net.Conn, error) {
   171  	z := s
   172  	z.IV = iv
   173  
   174  	p, err := z.stream()
   175  	if err != nil {
   176  		return nil, err
   177  	}
   178  	return newConn(c, p), nil
   179  }
   180  
   181  func (s Protocol) Packet(c net.PacketConn) (net.PacketConn, error) {
   182  	p, err := s.stream()
   183  	if err != nil {
   184  		return nil, err
   185  	}
   186  	return newPacketConn(c, p), nil
   187  }
   188  
   189  func (s *Protocol) SetHeadLen(data []byte, defaultValue int) {
   190  	s.HeadSize = GetHeadSize(data, defaultValue)
   191  }
   192  
   193  // https://github.com/shadowsocksrr/shadowsocksr/blob/fd723a92c488d202b407323f0512987346944136/shadowsocks/obfsplugin/plain.py#L93
   194  func GetHeadSize(data []byte, defaultValue int) int {
   195  	if len(data) < 2 {
   196  		return defaultValue
   197  	}
   198  	headType := data[0] & 0x07
   199  	switch headType {
   200  	case 1:
   201  		// IPv4 1+4+2
   202  		return 7
   203  	case 4:
   204  		// IPv6 1+16+2
   205  		return 19
   206  	case 3:
   207  		// domain name, variant length
   208  		return 4 + int(data[1])
   209  	}
   210  
   211  	return defaultValue
   212  }