github.com/kelleygo/clashcore@v1.0.2/transport/snell/snell.go (about) 1 package snell 2 3 import ( 4 "encoding/binary" 5 "errors" 6 "fmt" 7 "io" 8 "net" 9 "sync" 10 11 "github.com/kelleygo/clashcore/common/pool" 12 "github.com/kelleygo/clashcore/transport/shadowsocks/shadowaead" 13 "github.com/kelleygo/clashcore/transport/socks5" 14 ) 15 16 const ( 17 Version1 = 1 18 Version2 = 2 19 Version3 = 3 20 DefaultSnellVersion = Version1 21 22 // max packet length 23 maxLength = 0x3FFF 24 ) 25 26 const ( 27 CommandPing byte = 0 28 CommandConnect byte = 1 29 CommandConnectV2 byte = 5 30 CommandUDP byte = 6 31 CommondUDPForward byte = 1 32 33 CommandTunnel byte = 0 34 CommandPong byte = 1 35 CommandError byte = 2 36 37 Version byte = 1 38 ) 39 40 var endSignal = []byte{} 41 42 type Snell struct { 43 net.Conn 44 buffer [1]byte 45 reply bool 46 } 47 48 func (s *Snell) Read(b []byte) (int, error) { 49 if s.reply { 50 return s.Conn.Read(b) 51 } 52 53 s.reply = true 54 if _, err := io.ReadFull(s.Conn, s.buffer[:]); err != nil { 55 return 0, err 56 } 57 58 if s.buffer[0] == CommandTunnel { 59 return s.Conn.Read(b) 60 } else if s.buffer[0] != CommandError { 61 return 0, errors.New("command not support") 62 } 63 64 // CommandError 65 // 1 byte error code 66 if _, err := io.ReadFull(s.Conn, s.buffer[:]); err != nil { 67 return 0, err 68 } 69 errcode := int(s.buffer[0]) 70 71 // 1 byte error message length 72 if _, err := io.ReadFull(s.Conn, s.buffer[:]); err != nil { 73 return 0, err 74 } 75 length := int(s.buffer[0]) 76 msg := make([]byte, length) 77 78 if _, err := io.ReadFull(s.Conn, msg); err != nil { 79 return 0, err 80 } 81 82 return 0, fmt.Errorf("server reported code: %d, message: %s", errcode, string(msg)) 83 } 84 85 func WriteHeader(conn net.Conn, host string, port uint, version int) error { 86 buf := pool.GetBuffer() 87 defer pool.PutBuffer(buf) 88 buf.WriteByte(Version) 89 if version == Version2 { 90 buf.WriteByte(CommandConnectV2) 91 } else { 92 buf.WriteByte(CommandConnect) 93 } 94 95 // clientID length & id 96 buf.WriteByte(0) 97 98 // host & port 99 buf.WriteByte(uint8(len(host))) 100 buf.WriteString(host) 101 binary.Write(buf, binary.BigEndian, uint16(port)) 102 103 if _, err := conn.Write(buf.Bytes()); err != nil { 104 return err 105 } 106 107 return nil 108 } 109 110 func WriteUDPHeader(conn net.Conn, version int) error { 111 if version < Version3 { 112 return errors.New("unsupport UDP version") 113 } 114 115 // version, command, clientID length 116 _, err := conn.Write([]byte{Version, CommandUDP, 0x00}) 117 return err 118 } 119 120 // HalfClose works only on version2 121 func HalfClose(conn net.Conn) error { 122 if _, err := conn.Write(endSignal); err != nil { 123 return err 124 } 125 126 if s, ok := conn.(*Snell); ok { 127 s.reply = false 128 } 129 return nil 130 } 131 132 func StreamConn(conn net.Conn, psk []byte, version int) *Snell { 133 var cipher shadowaead.Cipher 134 if version != Version1 { 135 cipher = NewAES128GCM(psk) 136 } else { 137 cipher = NewChacha20Poly1305(psk) 138 } 139 return &Snell{Conn: shadowaead.NewConn(conn, cipher)} 140 } 141 142 func PacketConn(conn net.Conn) net.PacketConn { 143 return &packetConn{ 144 Conn: conn, 145 } 146 } 147 148 func writePacket(w io.Writer, socks5Addr, payload []byte) (int, error) { 149 buf := pool.GetBuffer() 150 defer pool.PutBuffer(buf) 151 152 // compose snell UDP address format (refer: icpz/snell-server-reversed) 153 // a brand new wheel to replace socks5 address format, well done Yachen 154 buf.WriteByte(CommondUDPForward) 155 switch socks5Addr[0] { 156 case socks5.AtypDomainName: 157 hostLen := socks5Addr[1] 158 buf.Write(socks5Addr[1 : 1+1+hostLen+2]) 159 case socks5.AtypIPv4: 160 buf.Write([]byte{0x00, 0x04}) 161 buf.Write(socks5Addr[1 : 1+net.IPv4len+2]) 162 case socks5.AtypIPv6: 163 buf.Write([]byte{0x00, 0x06}) 164 buf.Write(socks5Addr[1 : 1+net.IPv6len+2]) 165 } 166 167 buf.Write(payload) 168 _, err := w.Write(buf.Bytes()) 169 if err != nil { 170 return 0, err 171 } 172 return len(payload), nil 173 } 174 175 func WritePacket(w io.Writer, socks5Addr, payload []byte) (int, error) { 176 if len(payload) <= maxLength { 177 return writePacket(w, socks5Addr, payload) 178 } 179 180 offset := 0 181 total := len(payload) 182 for { 183 cursor := offset + maxLength 184 if cursor > total { 185 cursor = total 186 } 187 188 n, err := writePacket(w, socks5Addr, payload[offset:cursor]) 189 if err != nil { 190 return offset + n, err 191 } 192 193 offset = cursor 194 if offset == total { 195 break 196 } 197 } 198 199 return total, nil 200 } 201 202 func ReadPacket(r io.Reader, payload []byte) (net.Addr, int, error) { 203 buf := pool.Get(pool.UDPBufferSize) 204 defer pool.Put(buf) 205 206 n, err := r.Read(buf) 207 headLen := 1 208 if err != nil { 209 return nil, 0, err 210 } 211 if n < headLen { 212 return nil, 0, errors.New("insufficient UDP length") 213 } 214 215 // parse snell UDP response address format 216 switch buf[0] { 217 case 0x04: 218 headLen += net.IPv4len + 2 219 if n < headLen { 220 err = errors.New("insufficient UDP length") 221 break 222 } 223 buf[0] = socks5.AtypIPv4 224 case 0x06: 225 headLen += net.IPv6len + 2 226 if n < headLen { 227 err = errors.New("insufficient UDP length") 228 break 229 } 230 buf[0] = socks5.AtypIPv6 231 default: 232 err = errors.New("ip version invalid") 233 } 234 235 if err != nil { 236 return nil, 0, err 237 } 238 239 addr := socks5.SplitAddr(buf[0:]) 240 if addr == nil { 241 return nil, 0, errors.New("remote address invalid") 242 } 243 uAddr := addr.UDPAddr() 244 if uAddr == nil { 245 return nil, 0, errors.New("parse addr error") 246 } 247 248 length := len(payload) 249 if n-headLen < length { 250 length = n - headLen 251 } 252 copy(payload[:], buf[headLen:headLen+length]) 253 254 return uAddr, length, nil 255 } 256 257 type packetConn struct { 258 net.Conn 259 rMux sync.Mutex 260 wMux sync.Mutex 261 } 262 263 func (pc *packetConn) WriteTo(b []byte, addr net.Addr) (int, error) { 264 pc.wMux.Lock() 265 defer pc.wMux.Unlock() 266 267 return WritePacket(pc, socks5.ParseAddr(addr.String()), b) 268 } 269 270 func (pc *packetConn) ReadFrom(b []byte) (int, net.Addr, error) { 271 pc.rMux.Lock() 272 defer pc.rMux.Unlock() 273 274 addr, n, err := ReadPacket(pc.Conn, b) 275 if err != nil { 276 return 0, nil, err 277 } 278 279 return n, addr, nil 280 }