github.com/metacubex/mihomo@v1.18.5/transport/trojan/trojan.go (about) 1 package trojan 2 3 import ( 4 "context" 5 "crypto/sha256" 6 "crypto/tls" 7 "encoding/binary" 8 "encoding/hex" 9 "errors" 10 "io" 11 "net" 12 "net/http" 13 "sync" 14 15 N "github.com/metacubex/mihomo/common/net" 16 "github.com/metacubex/mihomo/common/pool" 17 "github.com/metacubex/mihomo/component/ca" 18 tlsC "github.com/metacubex/mihomo/component/tls" 19 C "github.com/metacubex/mihomo/constant" 20 "github.com/metacubex/mihomo/transport/socks5" 21 "github.com/metacubex/mihomo/transport/vmess" 22 ) 23 24 const ( 25 // max packet length 26 maxLength = 8192 27 ) 28 29 var ( 30 defaultALPN = []string{"h2", "http/1.1"} 31 defaultWebsocketALPN = []string{"http/1.1"} 32 33 crlf = []byte{'\r', '\n'} 34 ) 35 36 type Command = byte 37 38 const ( 39 CommandTCP byte = 1 40 CommandUDP byte = 3 41 42 // deprecated XTLS commands, as souvenirs 43 commandXRD byte = 0xf0 // XTLS direct mode 44 commandXRO byte = 0xf1 // XTLS origin mode 45 ) 46 47 type Option struct { 48 Password string 49 ALPN []string 50 ServerName string 51 SkipCertVerify bool 52 Fingerprint string 53 ClientFingerprint string 54 Reality *tlsC.RealityConfig 55 } 56 57 type WebsocketOption struct { 58 Host string 59 Port string 60 Path string 61 Headers http.Header 62 V2rayHttpUpgrade bool 63 V2rayHttpUpgradeFastOpen bool 64 } 65 66 type Trojan struct { 67 option *Option 68 hexPassword []byte 69 } 70 71 func (t *Trojan) StreamConn(ctx context.Context, conn net.Conn) (net.Conn, error) { 72 alpn := defaultALPN 73 if len(t.option.ALPN) != 0 { 74 alpn = t.option.ALPN 75 } 76 tlsConfig := &tls.Config{ 77 NextProtos: alpn, 78 MinVersion: tls.VersionTLS12, 79 InsecureSkipVerify: t.option.SkipCertVerify, 80 ServerName: t.option.ServerName, 81 } 82 83 var err error 84 tlsConfig, err = ca.GetSpecifiedFingerprintTLSConfig(tlsConfig, t.option.Fingerprint) 85 if err != nil { 86 return nil, err 87 } 88 89 if len(t.option.ClientFingerprint) != 0 { 90 if t.option.Reality == nil { 91 utlsConn, valid := vmess.GetUTLSConn(conn, t.option.ClientFingerprint, tlsConfig) 92 if valid { 93 ctx, cancel := context.WithTimeout(context.Background(), C.DefaultTLSTimeout) 94 defer cancel() 95 96 err := utlsConn.(*tlsC.UConn).HandshakeContext(ctx) 97 return utlsConn, err 98 } 99 } else { 100 ctx, cancel := context.WithTimeout(context.Background(), C.DefaultTLSTimeout) 101 defer cancel() 102 return tlsC.GetRealityConn(ctx, conn, t.option.ClientFingerprint, tlsConfig, t.option.Reality) 103 } 104 } 105 if t.option.Reality != nil { 106 return nil, errors.New("REALITY is based on uTLS, please set a client-fingerprint") 107 } 108 109 tlsConn := tls.Client(conn, tlsConfig) 110 111 // fix tls handshake not timeout 112 ctx, cancel := context.WithTimeout(context.Background(), C.DefaultTLSTimeout) 113 defer cancel() 114 115 err = tlsConn.HandshakeContext(ctx) 116 return tlsConn, err 117 } 118 119 func (t *Trojan) StreamWebsocketConn(ctx context.Context, conn net.Conn, wsOptions *WebsocketOption) (net.Conn, error) { 120 alpn := defaultWebsocketALPN 121 if len(t.option.ALPN) != 0 { 122 alpn = t.option.ALPN 123 } 124 125 tlsConfig := &tls.Config{ 126 NextProtos: alpn, 127 MinVersion: tls.VersionTLS12, 128 InsecureSkipVerify: t.option.SkipCertVerify, 129 ServerName: t.option.ServerName, 130 } 131 132 var err error 133 tlsConfig, err = ca.GetSpecifiedFingerprintTLSConfig(tlsConfig, t.option.Fingerprint) 134 if err != nil { 135 return nil, err 136 } 137 138 return vmess.StreamWebsocketConn(ctx, conn, &vmess.WebsocketConfig{ 139 Host: wsOptions.Host, 140 Port: wsOptions.Port, 141 Path: wsOptions.Path, 142 Headers: wsOptions.Headers, 143 V2rayHttpUpgrade: wsOptions.V2rayHttpUpgrade, 144 V2rayHttpUpgradeFastOpen: wsOptions.V2rayHttpUpgradeFastOpen, 145 TLS: true, 146 TLSConfig: tlsConfig, 147 ClientFingerprint: t.option.ClientFingerprint, 148 }) 149 } 150 151 func (t *Trojan) WriteHeader(w io.Writer, command Command, socks5Addr []byte) error { 152 buf := pool.GetBuffer() 153 defer pool.PutBuffer(buf) 154 155 buf.Write(t.hexPassword) 156 buf.Write(crlf) 157 158 buf.WriteByte(command) 159 buf.Write(socks5Addr) 160 buf.Write(crlf) 161 162 _, err := w.Write(buf.Bytes()) 163 return err 164 } 165 166 func (t *Trojan) PacketConn(conn net.Conn) net.PacketConn { 167 return &PacketConn{ 168 Conn: conn, 169 } 170 } 171 172 func writePacket(w io.Writer, socks5Addr, payload []byte) (int, error) { 173 buf := pool.GetBuffer() 174 defer pool.PutBuffer(buf) 175 176 buf.Write(socks5Addr) 177 binary.Write(buf, binary.BigEndian, uint16(len(payload))) 178 buf.Write(crlf) 179 buf.Write(payload) 180 181 return w.Write(buf.Bytes()) 182 } 183 184 func WritePacket(w io.Writer, socks5Addr, payload []byte) (int, error) { 185 if len(payload) <= maxLength { 186 return writePacket(w, socks5Addr, payload) 187 } 188 189 offset := 0 190 total := len(payload) 191 for { 192 cursor := offset + maxLength 193 if cursor > total { 194 cursor = total 195 } 196 197 n, err := writePacket(w, socks5Addr, payload[offset:cursor]) 198 if err != nil { 199 return offset + n, err 200 } 201 202 offset = cursor 203 if offset == total { 204 break 205 } 206 } 207 208 return total, nil 209 } 210 211 func ReadPacket(r io.Reader, payload []byte) (net.Addr, int, int, error) { 212 addr, err := socks5.ReadAddr(r, payload) 213 if err != nil { 214 return nil, 0, 0, errors.New("read addr error") 215 } 216 uAddr := addr.UDPAddr() 217 if uAddr == nil { 218 return nil, 0, 0, errors.New("parse addr error") 219 } 220 221 if _, err = io.ReadFull(r, payload[:2]); err != nil { 222 return nil, 0, 0, errors.New("read length error") 223 } 224 225 total := int(binary.BigEndian.Uint16(payload[:2])) 226 if total > maxLength { 227 return nil, 0, 0, errors.New("packet invalid") 228 } 229 230 // read crlf 231 if _, err = io.ReadFull(r, payload[:2]); err != nil { 232 return nil, 0, 0, errors.New("read crlf error") 233 } 234 235 length := len(payload) 236 if total < length { 237 length = total 238 } 239 240 if _, err = io.ReadFull(r, payload[:length]); err != nil { 241 return nil, 0, 0, errors.New("read packet error") 242 } 243 244 return uAddr, length, total - length, nil 245 } 246 247 func New(option *Option) *Trojan { 248 return &Trojan{option, hexSha224([]byte(option.Password))} 249 } 250 251 var _ N.EnhancePacketConn = (*PacketConn)(nil) 252 253 type PacketConn struct { 254 net.Conn 255 remain int 256 rAddr net.Addr 257 mux sync.Mutex 258 } 259 260 func (pc *PacketConn) WriteTo(b []byte, addr net.Addr) (int, error) { 261 return WritePacket(pc, socks5.ParseAddr(addr.String()), b) 262 } 263 264 func (pc *PacketConn) ReadFrom(b []byte) (int, net.Addr, error) { 265 pc.mux.Lock() 266 defer pc.mux.Unlock() 267 if pc.remain != 0 { 268 length := len(b) 269 if pc.remain < length { 270 length = pc.remain 271 } 272 273 n, err := pc.Conn.Read(b[:length]) 274 if err != nil { 275 return 0, nil, err 276 } 277 278 pc.remain -= n 279 addr := pc.rAddr 280 if pc.remain == 0 { 281 pc.rAddr = nil 282 } 283 284 return n, addr, nil 285 } 286 287 addr, n, remain, err := ReadPacket(pc.Conn, b) 288 if err != nil { 289 return 0, nil, err 290 } 291 292 if remain != 0 { 293 pc.remain = remain 294 pc.rAddr = addr 295 } 296 297 return n, addr, nil 298 } 299 300 func (pc *PacketConn) WaitReadFrom() (data []byte, put func(), addr net.Addr, err error) { 301 pc.mux.Lock() 302 defer pc.mux.Unlock() 303 304 destination, err := socks5.ReadAddr0(pc.Conn) 305 if err != nil { 306 return nil, nil, nil, err 307 } 308 addr = destination.UDPAddr() 309 310 data = pool.Get(pool.UDPBufferSize) 311 put = func() { 312 _ = pool.Put(data) 313 } 314 315 _, err = io.ReadFull(pc.Conn, data[:2+2]) // u16be length + CR LF 316 if err != nil { 317 if put != nil { 318 put() 319 } 320 return nil, nil, nil, err 321 } 322 length := binary.BigEndian.Uint16(data) 323 324 if length > 0 { 325 data = data[:length] 326 _, err = io.ReadFull(pc.Conn, data) 327 if err != nil { 328 if put != nil { 329 put() 330 } 331 return nil, nil, nil, err 332 } 333 } else { 334 if put != nil { 335 put() 336 } 337 return nil, nil, addr, nil 338 } 339 340 return 341 } 342 343 func hexSha224(data []byte) []byte { 344 buf := make([]byte, 56) 345 hash := sha256.Sum224(data) 346 hex.Encode(buf, hash[:]) 347 return buf 348 }