github.com/kelleygo/clashcore@v1.0.2/transport/trojan/trojan.go (about) 1 package trojan 2 3 import ( 4 "context" 5 "crypto/sha256" 6 "crypto/tls" 7 "encoding/binary" 8 "encoding/hex" 9 "errors" 10 "io" 11 "net" 12 "net/http" 13 "sync" 14 15 N "github.com/kelleygo/clashcore/common/net" 16 "github.com/kelleygo/clashcore/common/pool" 17 "github.com/kelleygo/clashcore/component/ca" 18 tlsC "github.com/kelleygo/clashcore/component/tls" 19 C "github.com/kelleygo/clashcore/constant" 20 "github.com/kelleygo/clashcore/transport/socks5" 21 "github.com/kelleygo/clashcore/transport/vmess" 22 ) 23 24 const ( 25 // max packet length 26 maxLength = 8192 27 ) 28 29 var ( 30 defaultALPN = []string{"h2", "http/1.1"} 31 defaultWebsocketALPN = []string{"http/1.1"} 32 33 crlf = []byte{'\r', '\n'} 34 ) 35 36 type Command = byte 37 38 const ( 39 CommandTCP byte = 1 40 CommandUDP byte = 3 41 42 // deprecated XTLS commands, as souvenirs 43 commandXRD byte = 0xf0 // XTLS direct mode 44 commandXRO byte = 0xf1 // XTLS origin mode 45 ) 46 47 type Option struct { 48 Password string 49 ALPN []string 50 ServerName string 51 SkipCertVerify bool 52 Fingerprint string 53 ClientFingerprint string 54 Reality *tlsC.RealityConfig 55 } 56 57 type WebsocketOption struct { 58 Host string 59 Port string 60 Path string 61 Headers http.Header 62 V2rayHttpUpgrade bool 63 V2rayHttpUpgradeFastOpen bool 64 } 65 66 type Trojan struct { 67 option *Option 68 hexPassword []byte 69 } 70 71 func (t *Trojan) StreamConn(ctx context.Context, conn net.Conn) (net.Conn, error) { 72 alpn := defaultALPN 73 if len(t.option.ALPN) != 0 { 74 alpn = t.option.ALPN 75 } 76 tlsConfig := &tls.Config{ 77 NextProtos: alpn, 78 MinVersion: tls.VersionTLS12, 79 InsecureSkipVerify: t.option.SkipCertVerify, 80 ServerName: t.option.ServerName, 81 } 82 83 var err error 84 tlsConfig, err = ca.GetSpecifiedFingerprintTLSConfig(tlsConfig, t.option.Fingerprint) 85 if err != nil { 86 return nil, err 87 } 88 89 if len(t.option.ClientFingerprint) != 0 
{ 90 if t.option.Reality == nil { 91 utlsConn, valid := vmess.GetUTLSConn(conn, t.option.ClientFingerprint, tlsConfig) 92 if valid { 93 ctx, cancel := context.WithTimeout(context.Background(), C.DefaultTLSTimeout) 94 defer cancel() 95 96 err := utlsConn.(*tlsC.UConn).HandshakeContext(ctx) 97 return utlsConn, err 98 } 99 } else { 100 ctx, cancel := context.WithTimeout(context.Background(), C.DefaultTLSTimeout) 101 defer cancel() 102 return tlsC.GetRealityConn(ctx, conn, t.option.ClientFingerprint, tlsConfig, t.option.Reality) 103 } 104 } 105 if t.option.Reality != nil { 106 return nil, errors.New("REALITY is based on uTLS, please set a client-fingerprint") 107 } 108 109 tlsConn := tls.Client(conn, tlsConfig) 110 111 // fix tls handshake not timeout 112 ctx, cancel := context.WithTimeout(context.Background(), C.DefaultTLSTimeout) 113 defer cancel() 114 115 err = tlsConn.HandshakeContext(ctx) 116 return tlsConn, err 117 } 118 119 func (t *Trojan) StreamWebsocketConn(ctx context.Context, conn net.Conn, wsOptions *WebsocketOption) (net.Conn, error) { 120 alpn := defaultWebsocketALPN 121 if len(t.option.ALPN) != 0 { 122 alpn = t.option.ALPN 123 } 124 125 tlsConfig := &tls.Config{ 126 NextProtos: alpn, 127 MinVersion: tls.VersionTLS12, 128 InsecureSkipVerify: t.option.SkipCertVerify, 129 ServerName: t.option.ServerName, 130 } 131 132 return vmess.StreamWebsocketConn(ctx, conn, &vmess.WebsocketConfig{ 133 Host: wsOptions.Host, 134 Port: wsOptions.Port, 135 Path: wsOptions.Path, 136 Headers: wsOptions.Headers, 137 V2rayHttpUpgrade: wsOptions.V2rayHttpUpgrade, 138 V2rayHttpUpgradeFastOpen: wsOptions.V2rayHttpUpgradeFastOpen, 139 TLS: true, 140 TLSConfig: tlsConfig, 141 ClientFingerprint: t.option.ClientFingerprint, 142 }) 143 } 144 145 func (t *Trojan) WriteHeader(w io.Writer, command Command, socks5Addr []byte) error { 146 buf := pool.GetBuffer() 147 defer pool.PutBuffer(buf) 148 149 buf.Write(t.hexPassword) 150 buf.Write(crlf) 151 152 buf.WriteByte(command) 153 
buf.Write(socks5Addr) 154 buf.Write(crlf) 155 156 _, err := w.Write(buf.Bytes()) 157 return err 158 } 159 160 func (t *Trojan) PacketConn(conn net.Conn) net.PacketConn { 161 return &PacketConn{ 162 Conn: conn, 163 } 164 } 165 166 func writePacket(w io.Writer, socks5Addr, payload []byte) (int, error) { 167 buf := pool.GetBuffer() 168 defer pool.PutBuffer(buf) 169 170 buf.Write(socks5Addr) 171 binary.Write(buf, binary.BigEndian, uint16(len(payload))) 172 buf.Write(crlf) 173 buf.Write(payload) 174 175 return w.Write(buf.Bytes()) 176 } 177 178 func WritePacket(w io.Writer, socks5Addr, payload []byte) (int, error) { 179 if len(payload) <= maxLength { 180 return writePacket(w, socks5Addr, payload) 181 } 182 183 offset := 0 184 total := len(payload) 185 for { 186 cursor := offset + maxLength 187 if cursor > total { 188 cursor = total 189 } 190 191 n, err := writePacket(w, socks5Addr, payload[offset:cursor]) 192 if err != nil { 193 return offset + n, err 194 } 195 196 offset = cursor 197 if offset == total { 198 break 199 } 200 } 201 202 return total, nil 203 } 204 205 func ReadPacket(r io.Reader, payload []byte) (net.Addr, int, int, error) { 206 addr, err := socks5.ReadAddr(r, payload) 207 if err != nil { 208 return nil, 0, 0, errors.New("read addr error") 209 } 210 uAddr := addr.UDPAddr() 211 if uAddr == nil { 212 return nil, 0, 0, errors.New("parse addr error") 213 } 214 215 if _, err = io.ReadFull(r, payload[:2]); err != nil { 216 return nil, 0, 0, errors.New("read length error") 217 } 218 219 total := int(binary.BigEndian.Uint16(payload[:2])) 220 if total > maxLength { 221 return nil, 0, 0, errors.New("packet invalid") 222 } 223 224 // read crlf 225 if _, err = io.ReadFull(r, payload[:2]); err != nil { 226 return nil, 0, 0, errors.New("read crlf error") 227 } 228 229 length := len(payload) 230 if total < length { 231 length = total 232 } 233 234 if _, err = io.ReadFull(r, payload[:length]); err != nil { 235 return nil, 0, 0, errors.New("read packet error") 236 } 237 
238 return uAddr, length, total - length, nil 239 } 240 241 func New(option *Option) *Trojan { 242 return &Trojan{option, hexSha224([]byte(option.Password))} 243 } 244 245 var _ N.EnhancePacketConn = (*PacketConn)(nil) 246 247 type PacketConn struct { 248 net.Conn 249 remain int 250 rAddr net.Addr 251 mux sync.Mutex 252 } 253 254 func (pc *PacketConn) WriteTo(b []byte, addr net.Addr) (int, error) { 255 return WritePacket(pc, socks5.ParseAddr(addr.String()), b) 256 } 257 258 func (pc *PacketConn) ReadFrom(b []byte) (int, net.Addr, error) { 259 pc.mux.Lock() 260 defer pc.mux.Unlock() 261 if pc.remain != 0 { 262 length := len(b) 263 if pc.remain < length { 264 length = pc.remain 265 } 266 267 n, err := pc.Conn.Read(b[:length]) 268 if err != nil { 269 return 0, nil, err 270 } 271 272 pc.remain -= n 273 addr := pc.rAddr 274 if pc.remain == 0 { 275 pc.rAddr = nil 276 } 277 278 return n, addr, nil 279 } 280 281 addr, n, remain, err := ReadPacket(pc.Conn, b) 282 if err != nil { 283 return 0, nil, err 284 } 285 286 if remain != 0 { 287 pc.remain = remain 288 pc.rAddr = addr 289 } 290 291 return n, addr, nil 292 } 293 294 func (pc *PacketConn) WaitReadFrom() (data []byte, put func(), addr net.Addr, err error) { 295 pc.mux.Lock() 296 defer pc.mux.Unlock() 297 298 destination, err := socks5.ReadAddr0(pc.Conn) 299 if err != nil { 300 return nil, nil, nil, err 301 } 302 addr = destination.UDPAddr() 303 304 data = pool.Get(pool.UDPBufferSize) 305 put = func() { 306 _ = pool.Put(data) 307 } 308 309 _, err = io.ReadFull(pc.Conn, data[:2+2]) // u16be length + CR LF 310 if err != nil { 311 if put != nil { 312 put() 313 } 314 return nil, nil, nil, err 315 } 316 length := binary.BigEndian.Uint16(data) 317 318 if length > 0 { 319 data = data[:length] 320 _, err = io.ReadFull(pc.Conn, data) 321 if err != nil { 322 if put != nil { 323 put() 324 } 325 return nil, nil, nil, err 326 } 327 } else { 328 if put != nil { 329 put() 330 } 331 return nil, nil, addr, nil 332 } 333 334 return 335 } 
// hexSha224 computes the SHA-224 digest of data. It returns the lowercase hex
// encoding of that digest as a freshly allocated 56-byte slice. New uses it to
// derive the password token stored in Trojan.hexPassword.
func hexSha224(data []byte) []byte {
	digest := sha256.Sum224(data)
	encoded := make([]byte, hex.EncodedLen(len(digest)))
	hex.Encode(encoded, digest[:])
	return encoded
}