github.com/Asutorufa/yuhaiin@v0.3.6-0.20240502055049-7984da7023a0/pkg/net/proxy/shadowsocksr/protocol/base.go (about) 1 package protocol 2 3 import ( 4 "bytes" 5 crand "crypto/rand" 6 "fmt" 7 "math/rand/v2" 8 "net" 9 "strings" 10 "sync" 11 "sync/atomic" 12 13 "github.com/Asutorufa/yuhaiin/pkg/net/proxy/shadowsocksr/cipher" 14 "github.com/Asutorufa/yuhaiin/pkg/utils/pool" 15 ) 16 17 type protocol interface { 18 EncryptStream(dst *bytes.Buffer, data []byte) error 19 DecryptStream(dst *bytes.Buffer, data []byte) (int, error) 20 EncryptPacket(data []byte) ([]byte, error) 21 DecryptPacket(data []byte) ([]byte, error) 22 23 GetOverhead() int 24 } 25 26 type errorProtocol struct{ error } 27 28 func NewErrorProtocol(err error) protocol { return &errorProtocol{err} } 29 func (e *errorProtocol) EncryptStream(dst *bytes.Buffer, data []byte) error { return e.error } 30 func (e *errorProtocol) DecryptStream(dst *bytes.Buffer, data []byte) (int, error) { 31 return 0, e.error 32 } 33 func (e *errorProtocol) EncryptPacket(data []byte) ([]byte, error) { return nil, e.error } 34 func (e *errorProtocol) DecryptPacket(data []byte) ([]byte, error) { return nil, e.error } 35 func (e *errorProtocol) GetOverhead() int { return 0 } 36 37 type AuthData struct { 38 clientID [4]byte 39 connectionID atomic.Uint32 40 41 mu sync.Mutex 42 } 43 44 func NewAuth() *AuthData { return &AuthData{} } 45 46 func (a *AuthData) nextAuth() { 47 if a.connectionID.Load() <= 0xFF000000 && a.connectionID.Load() != 0 { 48 a.connectionID.Add(1) 49 return 50 } 51 52 a.mu.Lock() 53 defer a.mu.Unlock() 54 crand.Read(a.clientID[:]) 55 a.connectionID.Store(rand.Uint32() & 0xFFFFFF) 56 } 57 58 type packetConn struct { 59 protocol protocol 60 net.PacketConn 61 } 62 63 func newPacketConn(conn net.PacketConn, p protocol) net.PacketConn { return &packetConn{p, conn} } 64 65 func (c *packetConn) WriteTo(b []byte, addr net.Addr) (int, error) { 66 data, err := c.protocol.EncryptPacket(b) 67 if err != nil { 68 return 0, err 69 } 70 _, err = c.PacketConn.WriteTo(data, addr) 71 return len(b), err 72 } 73 74 func (c *packetConn) ReadFrom(b []byte) (int, net.Addr, error) { 75 n, addr, err := c.PacketConn.ReadFrom(b) 76 if err != nil { 77 return n, addr, err 78 } 79 decoded, err := c.protocol.DecryptPacket(b[:n]) 80 if err != nil { 81 return n, addr, err 82 } 83 copy(b, decoded) 84 return len(decoded), addr, nil 85 } 86 87 func (c *packetConn) Close() error { return c.PacketConn.Close() } 88 89 type conn struct { 90 protocol protocol 91 net.Conn 92 93 ciphertext, plaintext bytes.Buffer 94 } 95 96 func newConn(c net.Conn, p protocol) net.Conn { 97 return &conn{ 98 Conn: c, 99 protocol: p, 100 } 101 } 102 103 func (c *conn) Read(b []byte) (n int, err error) { 104 if c.plaintext.Len() > 0 { 105 return c.plaintext.Read(b) 106 } 107 108 n, err = c.Conn.Read(b) 109 if err != nil { 110 return 0, err 111 } 112 113 c.ciphertext.Write(b[:n]) 114 length, err := c.protocol.DecryptStream(&c.plaintext, c.ciphertext.Bytes()) 115 if err != nil { 116 c.ciphertext.Reset() 117 return 0, err 118 } 119 c.ciphertext.Next(length) 120 121 n, _ = c.plaintext.Read(b) 122 return n, nil 123 } 124 125 func (c *conn) Write(b []byte) (n int, err error) { 126 buf := pool.GetBuffer() 127 defer pool.PutBuffer(buf) 128 129 if err = c.protocol.EncryptStream(buf, b); err != nil { 130 return 0, err 131 } 132 if _, err = c.Conn.Write(buf.Bytes()); err != nil { 133 return 0, err 134 } 135 return len(b), nil 136 } 137 138 var ProtocolMethod = map[string]func(Protocol) protocol{ 139 "auth_aes128_sha1": NewAuthAES128SHA1, 140 "auth_aes128_md5": NewAuthAES128MD5, 141 "auth_chain_a": NewAuthChainA, 142 "auth_chain_b": NewAuthChainB, 143 "origin": NewOrigin, 144 "auth_sha1_v4": NewAuthSHA1v4, 145 "verify_sha1": NewVerifySHA1, 146 "ota": NewVerifySHA1, 147 } 148 149 type Protocol struct { 150 *cipher.Cipher 151 152 HeadSize int 153 TcpMss int 154 ObfsOverhead int 155 Name string 156 Param string 157 IV []byte 158 159 Auth *AuthData 160 } 161 162 func (s Protocol) stream() (protocol, error) { 163 c, ok := ProtocolMethod[strings.ToLower(s.Name)] 164 if ok { 165 return c(s), nil 166 } 167 return nil, fmt.Errorf("protocol %s not found", s.Name) 168 } 169 170 func (s Protocol) Stream(c net.Conn, iv []byte) (net.Conn, error) { 171 z := s 172 z.IV = iv 173 174 p, err := z.stream() 175 if err != nil { 176 return nil, err 177 } 178 return newConn(c, p), nil 179 } 180 181 func (s Protocol) Packet(c net.PacketConn) (net.PacketConn, error) { 182 p, err := s.stream() 183 if err != nil { 184 return nil, err 185 } 186 return newPacketConn(c, p), nil 187 } 188 189 func (s *Protocol) SetHeadLen(data []byte, defaultValue int) { 190 s.HeadSize = GetHeadSize(data, defaultValue) 191 } 192 193 // https://github.com/shadowsocksrr/shadowsocksr/blob/fd723a92c488d202b407323f0512987346944136/shadowsocks/obfsplugin/plain.py#L93 194 func GetHeadSize(data []byte, defaultValue int) int { 195 if len(data) < 2 { 196 return defaultValue 197 } 198 headType := data[0] & 0x07 199 switch headType { 200 case 1: 201 // IPv4 1+4+2 202 return 7 203 case 4: 204 // IPv6 1+16+2 205 return 19 206 case 3: 207 // domain name, variant length 208 return 4 + int(data[1]) 209 } 210 211 return defaultValue 212 }