github.com/Asutorufa/yuhaiin@v0.3.6-0.20240502055049-7984da7023a0/pkg/net/proxy/wireguard/bind.go (about) 1 package wireguard 2 3 import ( 4 "context" 5 "errors" 6 "net" 7 "net/netip" 8 "strconv" 9 "sync" 10 11 "github.com/Asutorufa/yuhaiin/pkg/net/dialer" 12 "github.com/Asutorufa/yuhaiin/pkg/net/netapi" 13 "github.com/Asutorufa/yuhaiin/pkg/utils/yerror" 14 "github.com/tailscale/wireguard-go/conn" 15 ) 16 17 var _ conn.Endpoint = (*Endpoint)(nil) 18 19 type Endpoint netip.AddrPort 20 21 func (e Endpoint) ClearSrc() {} 22 func (e Endpoint) SrcToString() string { return "" } 23 func (e Endpoint) DstToString() string { return (netip.AddrPort)(e).String() } 24 func (e Endpoint) DstToBytes() []byte { return yerror.Ignore((netip.AddrPort)(e).MarshalBinary()) } 25 func (e Endpoint) DstIP() netip.Addr { return (netip.AddrPort)(e).Addr() } 26 func (e Endpoint) SrcIP() netip.Addr { return netip.Addr{} } 27 28 type netBindClient struct { 29 mu sync.Mutex 30 conn net.PacketConn 31 reserved []byte 32 } 33 34 func newNetBindClient(reserved []byte) *netBindClient { return &netBindClient{reserved: reserved} } 35 36 func (n *netBindClient) ParseEndpoint(s string) (conn.Endpoint, error) { 37 addrPort, err := netip.ParseAddrPort(s) 38 if err == nil { 39 return Endpoint(addrPort), nil 40 } 41 42 ipStr, port, err := net.SplitHostPort(s) 43 if err != nil { 44 return nil, err 45 } 46 47 portNum, err := strconv.ParseUint(port, 10, 16) 48 if err != nil { 49 return nil, err 50 } 51 52 ips, err := netapi.Bootstrap.LookupIP(context.TODO(), ipStr) 53 if err != nil { 54 return nil, err 55 } 56 57 ip, ok := netip.AddrFromSlice(ips[0]) 58 if !ok { 59 return nil, errors.New("failed to parse ip: " + ipStr) 60 } 61 62 return Endpoint(netip.AddrPortFrom(ip.Unmap(), uint16(portNum))), nil 63 } 64 65 func (bind *netBindClient) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error) { 66 return []conn.ReceiveFunc{bind.receive}, uport, nil 67 } 68 69 func (bind *netBindClient) Close() error { 70 if bind.conn != nil { 71 return bind.conn.Close() 72 } 73 return nil 74 } 75 76 func (bind *netBindClient) connect() (net.PacketConn, error) { 77 conn := bind.conn 78 if conn != nil { 79 return conn, nil 80 } 81 82 bind.mu.Lock() 83 defer bind.mu.Unlock() 84 85 if bind.conn != nil { 86 return bind.conn, nil 87 } 88 89 conn, err := dialer.ListenPacket("udp", "") 90 if err != nil { 91 return nil, err 92 } 93 94 bind.conn = conn 95 96 return conn, nil 97 } 98 99 func (bind *netBindClient) receive(packets [][]byte, sizes []int, eps []conn.Endpoint) (n int, err error) { 100 conn, err := bind.connect() 101 if err != nil { 102 return 0, err 103 } 104 105 n, addr, err := conn.ReadFrom(packets[0]) 106 if err != nil { 107 return 0, err 108 } 109 110 var addrPort netip.AddrPort 111 uaddr, ok := addr.(*net.UDPAddr) 112 if ok { 113 addrPort = uaddr.AddrPort() 114 } else { 115 naddr, err := netapi.ParseSysAddr(addr) 116 if err != nil { 117 return 0, err 118 } 119 120 ar := naddr.AddrPort(context.Background()) 121 if ar.Err != nil { 122 return 0, ar.Err 123 } 124 125 addrPort = ar.V 126 } 127 128 eps[0] = Endpoint(addrPort) 129 if n > 3 { 130 copy(packets[0][1:4], []byte{0, 0, 0}) 131 } 132 sizes[0] = n 133 134 return 1, nil 135 } 136 137 func (bind *netBindClient) Send(buffs [][]byte, endpoint conn.Endpoint) error { 138 ep, ok := endpoint.(Endpoint) 139 if !ok { 140 return conn.ErrWrongEndpointType 141 } 142 143 addr := netip.AddrPort(ep) 144 145 conn, err := bind.connect() 146 if err != nil { 147 return err 148 } 149 150 for _, buff := range buffs { 151 if len(buff) > 3 && len(bind.reserved) == 3 { 152 copy(buff[1:], bind.reserved) 153 } 154 155 _, err = conn.WriteTo(buff, net.UDPAddrFromAddrPort(addr)) 156 if err != nil { 157 return err 158 } 159 } 160 161 return nil 162 } 163 164 func (bind *netBindClient) SetMark(mark uint32) error { return nil } 165 func (bind *netBindClient) BatchSize() int { return 1 }