github.com/yaling888/clash@v1.53.0/transport/snell/snell.go (about) 1 package snell 2 3 import ( 4 "errors" 5 "fmt" 6 "io" 7 "net" 8 "sync" 9 10 "github.com/yaling888/clash/common/pool" 11 "github.com/yaling888/clash/transport/shadowsocks/shadowaead" 12 "github.com/yaling888/clash/transport/socks5" 13 ) 14 15 const ( 16 Version1 = 1 17 Version2 = 2 18 Version3 = 3 19 DefaultSnellVersion = Version1 20 21 // max packet length 22 maxLength = 0x3FFF 23 ) 24 25 const ( 26 CommandPing byte = 0 27 CommandConnect byte = 1 28 CommandConnectV2 byte = 5 29 CommandUDP byte = 6 30 CommondUDPForward byte = 1 31 32 CommandTunnel byte = 0 33 CommandPong byte = 1 34 CommandError byte = 2 35 36 Version byte = 1 37 ) 38 39 var endSignal = []byte{} 40 41 type Snell struct { 42 net.Conn 43 buffer [1]byte 44 reply bool 45 } 46 47 func (s *Snell) Read(b []byte) (int, error) { 48 if s.reply { 49 return s.Conn.Read(b) 50 } 51 52 s.reply = true 53 if _, err := io.ReadFull(s.Conn, s.buffer[:]); err != nil { 54 return 0, err 55 } 56 57 if s.buffer[0] == CommandTunnel { 58 return s.Conn.Read(b) 59 } else if s.buffer[0] != CommandError { 60 return 0, errors.New("command not support") 61 } 62 63 // CommandError 64 // 1 byte error code 65 if _, err := io.ReadFull(s.Conn, s.buffer[:]); err != nil { 66 return 0, err 67 } 68 errcode := int(s.buffer[0]) 69 70 // 1 byte error message length 71 if _, err := io.ReadFull(s.Conn, s.buffer[:]); err != nil { 72 return 0, err 73 } 74 length := int(s.buffer[0]) 75 msg := make([]byte, length) 76 77 if _, err := io.ReadFull(s.Conn, msg); err != nil { 78 return 0, err 79 } 80 81 return 0, fmt.Errorf("server reported code: %d, message: %s", errcode, string(msg)) 82 } 83 84 func WriteHeader(conn net.Conn, host string, port uint, version int) error { 85 buf := pool.BufferWriter{} 86 87 buf.PutUint8(Version) 88 if version == Version2 { 89 buf.PutUint8(CommandConnectV2) 90 } else { 91 buf.PutUint8(CommandConnect) 92 } 93 94 // clientID length & id 95 buf.PutUint8(0) 96 97 // host & port 98 buf.PutUint8(uint8(len(host))) 99 buf.PutString(host) 100 buf.PutUint16be(uint16(port)) 101 102 if _, err := conn.Write(buf.Bytes()); err != nil { 103 return err 104 } 105 106 return nil 107 } 108 109 func WriteUDPHeader(conn net.Conn, version int) error { 110 if version < Version3 { 111 return errors.New("unsupport UDP version") 112 } 113 114 // version, command, clientID length 115 _, err := conn.Write([]byte{Version, CommandUDP, 0x00}) 116 return err 117 } 118 119 // HalfClose works only on version2 120 func HalfClose(conn net.Conn) error { 121 if _, err := conn.Write(endSignal); err != nil { 122 return err 123 } 124 125 if s, ok := conn.(*Snell); ok { 126 s.reply = false 127 } 128 return nil 129 } 130 131 func StreamConn(conn net.Conn, psk []byte, version int) *Snell { 132 var cipher shadowaead.Cipher 133 if version != Version1 { 134 cipher = NewAES128GCM(psk) 135 } else { 136 cipher = NewChacha20Poly1305(psk) 137 } 138 return &Snell{Conn: shadowaead.NewConn(conn, cipher)} 139 } 140 141 func PacketConn(conn net.Conn) net.PacketConn { 142 return &packetConn{ 143 Conn: conn, 144 } 145 } 146 147 func writePacket(w io.Writer, socks5Addr, payload []byte) (int, error) { 148 buf := pool.GetBufferWriter() 149 defer pool.PutBufferWriter(buf) 150 151 // compose snell UDP address format (refer: icpz/snell-server-reversed) 152 // a brand-new wheel to replace socks5 address format, well done Yachen 153 buf.PutUint8(CommondUDPForward) 154 switch socks5Addr[0] { 155 case socks5.AtypDomainName: 156 hostLen := socks5Addr[1] 157 buf.PutSlice(socks5Addr[1 : 1+1+hostLen+2]) 158 case socks5.AtypIPv4: 159 buf.PutSlice([]byte{0x00, 0x04}) 160 buf.PutSlice(socks5Addr[1 : 1+net.IPv4len+2]) 161 case socks5.AtypIPv6: 162 buf.PutSlice([]byte{0x00, 0x06}) 163 buf.PutSlice(socks5Addr[1 : 1+net.IPv6len+2]) 164 } 165 166 buf.PutSlice(payload) 167 _, err := w.Write(buf.Bytes()) 168 if err != nil { 169 return 0, err 170 } 171 return len(payload), nil 172 } 173 174 func WritePacket(w io.Writer, socks5Addr, payload []byte) (int, error) { 175 if len(payload) <= maxLength { 176 return writePacket(w, socks5Addr, payload) 177 } 178 179 offset := 0 180 total := len(payload) 181 for { 182 cursor := offset + maxLength 183 if cursor > total { 184 cursor = total 185 } 186 187 n, err := writePacket(w, socks5Addr, payload[offset:cursor]) 188 if err != nil { 189 return offset + n, err 190 } 191 192 offset = cursor 193 if offset == total { 194 break 195 } 196 } 197 198 return total, nil 199 } 200 201 func ReadPacket(r io.Reader, payload []byte) (net.Addr, int, error) { 202 bufP := pool.GetNetBuf() 203 defer pool.PutNetBuf(bufP) 204 205 n, err := r.Read(*bufP) 206 headLen := 1 207 if err != nil { 208 return nil, 0, err 209 } 210 if n < headLen { 211 return nil, 0, errors.New("insufficient UDP length") 212 } 213 214 // parse snell UDP response address format 215 switch (*bufP)[0] { 216 case 0x04: 217 headLen += net.IPv4len + 2 218 if n < headLen { 219 err = errors.New("insufficient UDP length") 220 break 221 } 222 (*bufP)[0] = socks5.AtypIPv4 223 case 0x06: 224 headLen += net.IPv6len + 2 225 if n < headLen { 226 err = errors.New("insufficient UDP length") 227 break 228 } 229 (*bufP)[0] = socks5.AtypIPv6 230 default: 231 err = errors.New("ip version invalid") 232 } 233 234 if err != nil { 235 return nil, 0, err 236 } 237 238 addr := socks5.SplitAddr((*bufP)[0:]) 239 if addr == nil { 240 return nil, 0, errors.New("remote address invalid") 241 } 242 uAddr := addr.UDPAddr() 243 if uAddr == nil { 244 return nil, 0, errors.New("parse addr error") 245 } 246 247 length := len(payload) 248 if n-headLen < length { 249 length = n - headLen 250 } 251 copy(payload[:], (*bufP)[headLen:headLen+length]) 252 253 return uAddr, length, nil 254 } 255 256 type packetConn struct { 257 net.Conn 258 rMux sync.Mutex 259 wMux sync.Mutex 260 } 261 262 func (pc *packetConn) WriteTo(b []byte, addr net.Addr) (int, error) { 263 pc.wMux.Lock() 264 defer pc.wMux.Unlock() 265 266 return WritePacket(pc, socks5.ParseAddr(addr.String()), b) 267 } 268 269 func (pc *packetConn) ReadFrom(b []byte) (int, net.Addr, error) { 270 pc.rMux.Lock() 271 defer pc.rMux.Unlock() 272 273 addr, n, err := ReadPacket(pc.Conn, b) 274 if err != nil { 275 return 0, nil, err 276 } 277 278 return n, addr, nil 279 }