github.com/Asutorufa/yuhaiin@v0.3.6-0.20240502055049-7984da7023a0/pkg/net/proxy/trojan/client.go (about) 1 package trojan 2 3 import ( 4 "bytes" 5 "context" 6 "crypto/sha256" 7 "encoding/binary" 8 "encoding/hex" 9 "fmt" 10 "io" 11 "net" 12 "sync" 13 14 "github.com/Asutorufa/yuhaiin/pkg/net/netapi" 15 "github.com/Asutorufa/yuhaiin/pkg/net/proxy/socks5/tools" 16 "github.com/Asutorufa/yuhaiin/pkg/protos/node/point" 17 "github.com/Asutorufa/yuhaiin/pkg/protos/node/protocol" 18 "github.com/Asutorufa/yuhaiin/pkg/protos/statistic" 19 "github.com/Asutorufa/yuhaiin/pkg/utils/pool" 20 ) 21 22 const ( 23 MaxPacketSize = 1024 * 8 24 ) 25 26 type Command byte 27 28 const ( 29 Connect Command = 1 // TCP 30 Associate Command = 3 // UDP 31 Mux Command = 0x7f 32 ) 33 34 var crlf = []byte{'\r', '\n'} 35 36 func (c *Client) WriteHeader(conn net.Conn, cmd Command, addr netapi.Address) (err error) { 37 buf := pool.GetBytesWriter(pool.DefaultSize) 38 defer buf.Free() 39 40 _, _ = buf.Write(c.password) 41 _, _ = buf.Write(crlf) 42 buf.WriteByte(byte(cmd)) 43 tools.EncodeAddr(addr, buf) 44 _, _ = buf.Write(crlf) 45 46 _, err = conn.Write(buf.Bytes()) 47 return 48 } 49 50 // modified from https://github.com/p4gefau1t/trojan-go/blob/master/tunnel/trojan/client.go 51 type Client struct { 52 proxy netapi.Proxy 53 netapi.EmptyDispatch 54 password []byte 55 } 56 57 func init() { 58 point.RegisterProtocol(NewClient) 59 } 60 61 func NewClient(config *protocol.Protocol_Trojan) point.WrapProxy { 62 return func(dialer netapi.Proxy) (netapi.Proxy, error) { 63 return &Client{ 64 password: hexSha224([]byte(config.Trojan.Password)), 65 proxy: dialer, 66 }, nil 67 } 68 } 69 70 func (c *Client) Conn(ctx context.Context, addr netapi.Address) (net.Conn, error) { 71 conn, err := c.proxy.Conn(ctx, addr) 72 if err != nil { 73 return nil, err 74 } 75 76 if err = c.WriteHeader(conn, Connect, addr); err != nil { 77 conn.Close() 78 return nil, fmt.Errorf("write header failed: %w", err) 79 } 80 return conn, nil 81 } 82 83 func (c *Client) PacketConn(ctx context.Context, addr netapi.Address) (net.PacketConn, error) { 84 conn, err := c.proxy.Conn(ctx, addr) 85 if err != nil { 86 return nil, err 87 } 88 if err = c.WriteHeader(conn, Associate, addr); err != nil { 89 conn.Close() 90 return nil, fmt.Errorf("write header failed: %w", err) 91 } 92 return &PacketConn{Conn: conn}, nil 93 } 94 95 type PacketConn struct { 96 net.Conn 97 98 remain int 99 addr netapi.Address 100 mux sync.Mutex 101 } 102 103 func (c *PacketConn) WriteTo(payload []byte, addr net.Addr) (int, error) { 104 taddr, err := netapi.ParseSysAddr(addr) 105 if err != nil { 106 return 0, fmt.Errorf("failed to parse addr: %w", err) 107 } 108 109 w := pool.GetBuffer() 110 defer pool.PutBuffer(w) 111 112 tools.EncodeAddr(taddr, w) 113 addrSize := w.Len() 114 115 b := bytes.NewBuffer(payload) 116 117 for b.Len() > 0 { 118 data := b.Next(MaxPacketSize) 119 120 w.Truncate(addrSize) 121 122 binary.Write(w, binary.BigEndian, uint16(len(data))) 123 124 w.Write(crlf) // crlf 125 126 w.Write(data) 127 128 _, err = c.Conn.Write(w.Bytes()) 129 if err != nil { 130 return len(payload) - b.Len() + len(data), fmt.Errorf("write to %v failed: %w", addr, err) 131 } 132 } 133 134 return len(payload), nil 135 } 136 137 func (c *PacketConn) ReadFrom(payload []byte) (n int, _ net.Addr, err error) { 138 c.mux.Lock() 139 defer c.mux.Unlock() 140 141 if c.remain > 0 { 142 z := min(len(payload), c.remain) 143 144 n, err := c.Conn.Read(payload[:z]) 145 if err != nil { 146 return 0, c.addr, err 147 } 148 149 c.remain -= n 150 return n, c.addr, err 151 } 152 153 addr, err := tools.ResolveAddr(c.Conn) 154 if err != nil { 155 return 0, nil, fmt.Errorf("failed to resolve udp packet addr: %w", err) 156 } 157 158 c.addr = addr.Address(statistic.Type_udp) 159 160 var length uint16 161 if err = binary.Read(c.Conn, binary.BigEndian, &length); err != nil { 162 return 0, nil, fmt.Errorf("read length failed: %w", err) 163 } 164 if length > MaxPacketSize { 165 return 0, nil, fmt.Errorf("invalid packet size") 166 } 167 168 crlf := [2]byte{} 169 if _, err := io.ReadFull(c.Conn, crlf[:]); err != nil { 170 return 0, nil, fmt.Errorf("read crlf failed: %w", err) 171 } 172 173 plen := min(int(length), len(payload)) 174 c.remain = int(length) - plen 175 176 n, err = io.ReadFull(c.Conn, payload[:plen]) 177 return n, c.addr, err 178 } 179 180 func hexSha224(data []byte) []byte { 181 buf := make([]byte, 56) 182 hash := sha256.New224() 183 hash.Write(data) 184 hex.Encode(buf, hash.Sum(nil)) 185 return buf 186 }