github.com/yaling888/clash@v1.53.0/transport/trojan/trojan.go (about) 1 package trojan 2 3 import ( 4 "context" 5 "crypto/sha256" 6 "crypto/tls" 7 "encoding/binary" 8 "encoding/hex" 9 "errors" 10 "fmt" 11 "io" 12 "net" 13 "net/http" 14 "sync" 15 16 "github.com/yaling888/clash/common/pool" 17 C "github.com/yaling888/clash/constant" 18 "github.com/yaling888/clash/transport/h2" 19 "github.com/yaling888/clash/transport/socks5" 20 "github.com/yaling888/clash/transport/vmess" 21 ) 22 23 const ( 24 // max packet length 25 maxLength = 8192 26 ) 27 28 var ( 29 defaultALPN = []string{"h2", "http/1.1"} 30 defaultWebsocketALPN = []string{"http/1.1"} 31 32 crlf = []byte{'\r', '\n'} 33 ) 34 35 type Command = byte 36 37 const ( 38 CommandTCP byte = 1 39 CommandUDP byte = 3 40 ) 41 42 type Option struct { 43 Password string 44 ALPN []string 45 ServerName string 46 SkipCertVerify bool 47 } 48 49 type HTTPOptions struct { 50 Host string 51 Port int 52 Hosts []string 53 Path string 54 Headers http.Header 55 } 56 57 type WebsocketOption struct { 58 Host string 59 Port string 60 Path string 61 Headers http.Header 62 } 63 64 type Trojan struct { 65 option *Option 66 hexPassword []byte 67 } 68 69 func (t *Trojan) StreamConn(conn net.Conn) (net.Conn, error) { 70 alpn := defaultALPN 71 if len(t.option.ALPN) != 0 { 72 alpn = t.option.ALPN 73 } 74 75 tlsConfig := &tls.Config{ 76 NextProtos: alpn, 77 MinVersion: tls.VersionTLS12, 78 InsecureSkipVerify: t.option.SkipCertVerify, 79 ServerName: t.option.ServerName, 80 } 81 82 tlsConn := tls.Client(conn, tlsConfig) 83 84 // fix tls handshake not timeout 85 ctx, cancel := context.WithTimeout(context.Background(), C.DefaultTLSTimeout) 86 defer cancel() 87 if err := tlsConn.HandshakeContext(ctx); err != nil { 88 return nil, err 89 } 90 91 return tlsConn, nil 92 } 93 94 func (t *Trojan) StreamH2Conn(conn net.Conn, h2Option *HTTPOptions) (net.Conn, error) { 95 tlsConfig := &tls.Config{ 96 NextProtos: []string{"h2"}, 97 MinVersion: tls.VersionTLS12, 98 InsecureSkipVerify: t.option.SkipCertVerify, 99 ServerName: t.option.ServerName, 100 } 101 102 tlsConn := tls.Client(conn, tlsConfig) 103 104 ctx, cancel := context.WithTimeout(context.Background(), C.DefaultTLSTimeout) 105 defer cancel() 106 if err := tlsConn.HandshakeContext(ctx); err != nil { 107 return nil, err 108 } 109 110 return h2.StreamH2Conn(tlsConn, &h2.Config{ 111 Hosts: h2Option.Hosts, 112 Path: h2Option.Path, 113 Headers: h2Option.Headers, 114 }) 115 } 116 117 func (t *Trojan) StreamWebsocketConn(conn net.Conn, wsOptions *WebsocketOption) (net.Conn, error) { 118 alpn := defaultWebsocketALPN 119 if len(t.option.ALPN) != 0 { 120 alpn = t.option.ALPN 121 } 122 123 tlsConfig := &tls.Config{ 124 NextProtos: alpn, 125 MinVersion: tls.VersionTLS12, 126 InsecureSkipVerify: t.option.SkipCertVerify, 127 ServerName: t.option.ServerName, 128 } 129 130 return vmess.StreamWebsocketConn(conn, &vmess.WebsocketConfig{ 131 Host: wsOptions.Host, 132 Port: wsOptions.Port, 133 Path: wsOptions.Path, 134 Headers: wsOptions.Headers, 135 TLS: true, 136 TLSConfig: tlsConfig, 137 }) 138 } 139 140 func (t *Trojan) WriteHeader(w io.Writer, command Command, socks5Addr []byte) error { 141 buf := pool.BufferWriter{} 142 143 buf.PutSlice(t.hexPassword) 144 buf.PutSlice(crlf) 145 146 buf.PutUint8(command) 147 buf.PutSlice(socks5Addr) 148 buf.PutSlice(crlf) 149 150 _, err := w.Write(buf.Bytes()) 151 return err 152 } 153 154 func (t *Trojan) PacketConn(conn net.Conn) net.PacketConn { 155 return &PacketConn{ 156 Conn: conn, 157 } 158 } 159 160 func writePacket(w io.Writer, socks5Addr, payload []byte) (n int, err error) { 161 bufP := pool.GetNetBuf() 162 defer pool.PutNetBuf(bufP) 163 164 n = len(payload) 165 t := copy(*bufP, socks5Addr) 166 binary.BigEndian.PutUint16((*bufP)[t:], uint16(n)) 167 t += 2 168 t += copy((*bufP)[t:], crlf) 169 t += copy((*bufP)[t:], payload) 170 171 delta := t - n 172 n, err = w.Write((*bufP)[:t]) 173 if n < t && err == nil { 174 err = io.ErrShortWrite 175 } 176 n = max(n-delta, 0) 177 return 178 } 179 180 func WritePacket(w io.Writer, socks5Addr, payload []byte) (n int, err error) { 181 total := len(payload) 182 if total <= maxLength { 183 return writePacket(w, socks5Addr, payload) 184 } 185 186 offset := 0 187 cursor := 0 188 for { 189 cursor = min(offset+maxLength, total) 190 191 n, err = writePacket(w, socks5Addr, payload[offset:cursor]) 192 193 offset = min(offset+n, total) 194 if err != nil || offset == total { 195 n = offset 196 return 197 } 198 } 199 } 200 201 func ReadPacket(r io.Reader, payload []byte) (addr *net.UDPAddr, n int, remain int, err error) { 202 var socAddr socks5.Addr 203 socAddr, err = socks5.ReadAddr(r, payload) 204 if err != nil { 205 if err != io.EOF { 206 err = fmt.Errorf("read addr error, %w", err) 207 } 208 return 209 } 210 addr = socAddr.UDPAddr() 211 if addr == nil { 212 err = errors.New("parse addr error") 213 return 214 } 215 216 if _, err = io.ReadFull(r, payload[:2]); err != nil { 217 if err != io.EOF { 218 err = fmt.Errorf("read length error, %w", err) 219 } 220 return 221 } 222 223 total := int(binary.BigEndian.Uint16(payload[:2])) 224 if total > maxLength { 225 err = errors.New("invalid packet") 226 return 227 } 228 229 // read crlf 230 if _, err = io.ReadFull(r, payload[:2]); err != nil { 231 if err != io.EOF { 232 err = fmt.Errorf("read crlf error, %w", err) 233 } 234 return 235 } 236 237 length := min(len(payload), total) 238 if length, err = io.ReadFull(r, payload[:length]); err != nil && err != io.EOF { 239 err = fmt.Errorf("read packet error, %w", err) 240 } 241 242 return addr, length, total - length, err 243 } 244 245 func New(option *Option) *Trojan { 246 return &Trojan{option, hexSha224([]byte(option.Password))} 247 } 248 249 type PacketConn struct { 250 net.Conn 251 remain int 252 rAddr net.Addr 253 mux sync.Mutex 254 } 255 256 func (pc *PacketConn) WriteTo(b []byte, addr net.Addr) (int, error) { 257 return WritePacket(pc, socks5.ParseAddr(addr.String()), b) 258 } 259 260 func (pc *PacketConn) ReadFrom(b []byte) (int, net.Addr, error) { 261 pc.mux.Lock() 262 defer pc.mux.Unlock() 263 if pc.remain != 0 { 264 length := min(len(b), pc.remain) 265 266 n, err := pc.Conn.Read(b[:length]) 267 268 pc.remain -= n 269 addr := pc.rAddr 270 if pc.remain == 0 { 271 pc.rAddr = nil 272 } 273 274 return n, addr, err 275 } 276 277 addr, n, remain, err := ReadPacket(pc.Conn, b) 278 if err == nil && remain > 0 { 279 pc.remain = remain 280 pc.rAddr = addr 281 } 282 283 return n, addr, err 284 } 285 286 func hexSha224(data []byte) []byte { 287 buf := make([]byte, 56) 288 hash := sha256.New224() 289 hash.Write(data) 290 hex.Encode(buf, hash.Sum(nil)) 291 return buf 292 }