github.com/xxf098/lite-proxy@v0.15.1-0.20230422081941-12c69f323218/transport/trojan/trojan.go (about) 1 package trojan 2 3 import ( 4 "bytes" 5 "context" 6 "crypto/sha256" 7 "crypto/tls" 8 "encoding/binary" 9 "encoding/hex" 10 "errors" 11 "io" 12 "net" 13 "net/http" 14 "sync" 15 16 C "github.com/xxf098/lite-proxy/constant" 17 "github.com/xxf098/lite-proxy/transport/socks5" 18 "github.com/xxf098/lite-proxy/transport/vmess" 19 ) 20 21 const ( 22 // max packet length 23 maxLength = 8192 24 ) 25 26 var ( 27 defaultALPN = []string{"h2", "http/1.1"} 28 defaultWebsocketALPN = []string{"http/1.1"} 29 crlf = []byte{'\r', '\n'} 30 31 bufPool = sync.Pool{New: func() interface{} { return &bytes.Buffer{} }} 32 ) 33 34 type Command = byte 35 36 var ( 37 CommandTCP byte = 1 38 CommandUDP byte = 3 39 ) 40 41 type Option struct { 42 Password string 43 ALPN []string 44 ServerName string 45 SkipCertVerify bool 46 ClientSessionCache tls.ClientSessionCache 47 } 48 49 type WebsocketOption struct { 50 Host string 51 Port string 52 Path string 53 Headers http.Header 54 } 55 56 type Trojan struct { 57 option *Option 58 hexPassword []byte 59 } 60 61 func (t *Trojan) StreamConn(conn net.Conn) (net.Conn, error) { 62 alpn := defaultALPN 63 if len(t.option.ALPN) != 0 { 64 alpn = t.option.ALPN 65 } 66 67 tlsConfig := &tls.Config{ 68 NextProtos: alpn, 69 MinVersion: tls.VersionTLS12, 70 InsecureSkipVerify: t.option.SkipCertVerify, 71 ServerName: t.option.ServerName, 72 ClientSessionCache: t.option.ClientSessionCache, 73 } 74 75 tlsConn := tls.Client(conn, tlsConfig) 76 ctx, cancel := context.WithTimeout(context.Background(), C.DefaultTLSTimeout) 77 defer cancel() 78 if err := tlsConn.HandshakeContext(ctx); err != nil { 79 return nil, err 80 } 81 82 return tlsConn, nil 83 } 84 85 func (t *Trojan) StreamWebsocketConn(conn net.Conn, wsOptions *WebsocketOption) (net.Conn, error) { 86 alpn := defaultWebsocketALPN 87 if len(t.option.ALPN) != 0 { 88 alpn = t.option.ALPN 89 } 90 91 tlsConfig := &tls.Config{ 92 NextProtos: alpn, 93 MinVersion: tls.VersionTLS12, 94 InsecureSkipVerify: t.option.SkipCertVerify, 95 ServerName: t.option.ServerName, 96 } 97 98 return vmess.StreamWebsocketConn(conn, &vmess.WebsocketConfig{ 99 Host: wsOptions.Host, 100 Port: wsOptions.Port, 101 Path: wsOptions.Path, 102 Headers: wsOptions.Headers, 103 TLS: true, 104 TLSConfig: tlsConfig, 105 }) 106 } 107 108 func (t *Trojan) WriteHeader(w io.Writer, command Command, socks5Addr []byte) error { 109 buf := bufPool.Get().(*bytes.Buffer) 110 defer bufPool.Put(buf) 111 defer buf.Reset() 112 113 buf.Write(t.hexPassword) 114 buf.Write(crlf) 115 116 buf.WriteByte(command) 117 buf.Write(socks5Addr) 118 buf.Write(crlf) 119 120 _, err := w.Write(buf.Bytes()) 121 return err 122 } 123 124 func (t *Trojan) PacketConn(conn net.Conn) net.PacketConn { 125 return &PacketConn{ 126 Conn: conn, 127 } 128 } 129 130 func writePacket(w io.Writer, socks5Addr, payload []byte) (int, error) { 131 buf := bufPool.Get().(*bytes.Buffer) 132 defer bufPool.Put(buf) 133 defer buf.Reset() 134 135 buf.Write(socks5Addr) 136 binary.Write(buf, binary.BigEndian, uint16(len(payload))) 137 buf.Write(crlf) 138 buf.Write(payload) 139 140 return w.Write(buf.Bytes()) 141 } 142 143 func WritePacket(w io.Writer, socks5Addr, payload []byte) (int, error) { 144 if len(payload) <= maxLength { 145 return writePacket(w, socks5Addr, payload) 146 } 147 148 offset := 0 149 total := len(payload) 150 for { 151 cursor := offset + maxLength 152 if cursor > total { 153 cursor = total 154 } 155 156 n, err := writePacket(w, socks5Addr, payload[offset:cursor]) 157 if err != nil { 158 return offset + n, err 159 } 160 161 offset = cursor 162 if offset == total { 163 break 164 } 165 } 166 167 return total, nil 168 } 169 170 func ReadPacket(r io.Reader, payload []byte) (net.Addr, int, int, error) { 171 addr, err := socks5.ReadAddr(r, payload) 172 if err != nil { 173 return nil, 0, 0, errors.New("read addr error") 174 } 175 uAddr := addr.UDPAddr() 176 177 if _, err = io.ReadFull(r, payload[:2]); err != nil { 178 return nil, 0, 0, errors.New("read length error") 179 } 180 181 total := int(binary.BigEndian.Uint16(payload[:2])) 182 if total > maxLength { 183 return nil, 0, 0, errors.New("packet invalid") 184 } 185 186 // read crlf 187 if _, err = io.ReadFull(r, payload[:2]); err != nil { 188 return nil, 0, 0, errors.New("read crlf error") 189 } 190 191 length := len(payload) 192 if total < length { 193 length = total 194 } 195 196 if _, err = io.ReadFull(r, payload[:length]); err != nil { 197 return nil, 0, 0, errors.New("read packet error") 198 } 199 200 return uAddr, length, total - length, nil 201 } 202 203 func New(option *Option) *Trojan { 204 return &Trojan{option, hexSha224([]byte(option.Password))} 205 } 206 207 type PacketConn struct { 208 net.Conn 209 remain int 210 rAddr net.Addr 211 mux sync.Mutex 212 } 213 214 func (pc *PacketConn) WriteTo(b []byte, addr net.Addr) (int, error) { 215 return WritePacket(pc, socks5.ParseAddr(addr.String()), b) 216 } 217 218 func (pc *PacketConn) ReadFrom(b []byte) (int, net.Addr, error) { 219 pc.mux.Lock() 220 defer pc.mux.Unlock() 221 if pc.remain != 0 { 222 length := len(b) 223 if pc.remain < length { 224 length = pc.remain 225 } 226 227 n, err := pc.Conn.Read(b[:length]) 228 if err != nil { 229 return 0, nil, err 230 } 231 232 pc.remain -= n 233 addr := pc.rAddr 234 if pc.remain == 0 { 235 pc.rAddr = nil 236 } 237 238 return n, addr, nil 239 } 240 241 addr, n, remain, err := ReadPacket(pc.Conn, b) 242 if err != nil { 243 return 0, nil, err 244 } 245 246 if remain != 0 { 247 pc.remain = remain 248 pc.rAddr = addr 249 } 250 251 return n, addr, nil 252 } 253 254 func hexSha224(data []byte) []byte { 255 buf := make([]byte, 56) 256 hash := sha256.New224() 257 hash.Write(data) 258 hex.Encode(buf, hash.Sum(nil)) 259 return buf 260 }