github.com/Asutorufa/yuhaiin@v0.3.6-0.20240502055049-7984da7023a0/pkg/net/proxy/wireguard/wireguard.go (about) 1 /* SPDX-License-Identifier: MIT 2 * 3 * Copyright (C) 2017-2022 WireGuard LLC. All Rights Reserved. 4 */ 5 6 package wireguard 7 8 import ( 9 "bytes" 10 "context" 11 "encoding/base64" 12 "encoding/hex" 13 "fmt" 14 "log/slog" 15 "net" 16 "sync" 17 "sync/atomic" 18 "time" 19 20 "github.com/Asutorufa/yuhaiin/pkg/log" 21 "github.com/Asutorufa/yuhaiin/pkg/net/netapi" 22 "github.com/Asutorufa/yuhaiin/pkg/protos/node/point" 23 "github.com/Asutorufa/yuhaiin/pkg/protos/node/protocol" 24 "github.com/tailscale/wireguard-go/device" 25 "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" 26 ) 27 28 type Wireguard struct { 29 netapi.EmptyDispatch 30 net *Net 31 bind *netBindClient 32 33 conf *protocol.Wireguard 34 mu sync.Mutex 35 36 count atomic.Int64 37 38 lastNewConn time.Time 39 idleTimeout time.Duration 40 41 device *device.Device 42 } 43 44 func init() { 45 point.RegisterProtocol(NewClient) 46 } 47 48 func NewClient(conf *protocol.Protocol_Wireguard) point.WrapProxy { 49 return func(p netapi.Proxy) (netapi.Proxy, error) { 50 51 if conf.Wireguard.IdleTimeout == 0 { 52 conf.Wireguard.IdleTimeout = 60 * 5 53 } 54 if conf.Wireguard.IdleTimeout <= 30 { 55 conf.Wireguard.IdleTimeout = 30 56 } 57 58 return &Wireguard{ 59 conf: conf.Wireguard, 60 idleTimeout: time.Duration(conf.Wireguard.IdleTimeout) * time.Second, 61 }, nil 62 } 63 } 64 65 func (w *Wireguard) collect() { 66 readyClose := false 67 68 for { 69 time.Sleep(w.idleTimeout) 70 71 br := func() bool { 72 w.mu.Lock() 73 defer w.mu.Unlock() 74 75 log.Debug("wireguard check idle timeout") 76 77 if w.count.Load() > 0 { 78 readyClose = false 79 return false 80 } 81 82 if !w.lastNewConn.IsZero() && time.Since(w.lastNewConn) < time.Minute { 83 readyClose = false 84 return false 85 } 86 87 if readyClose { 88 log.Debug("wireguard closing") 89 if w.device != nil { 90 w.device.Close() 91 w.device = nil 92 } 93 94 if w.bind != nil { 95 w.bind.Close() 96 w.bind = nil 97 } 98 log.Debug("wireguard closed") 99 w.net = nil 100 return true 101 } 102 103 log.Debug("wireguard ready to close") 104 105 readyClose = true 106 return false 107 }() 108 109 if br { 110 break 111 } 112 } 113 } 114 115 func (w *Wireguard) initNet() (*Net, error) { 116 net := w.net 117 if net != nil { 118 return net, nil 119 } 120 121 w.mu.Lock() 122 defer w.mu.Unlock() 123 124 if w.net != nil { 125 return w.net, nil 126 } 127 128 dev, bind, net, err := makeVirtualTun(w.conf) 129 if err != nil { 130 return nil, err 131 } 132 133 w.device = dev 134 w.net = net 135 w.bind = bind 136 go w.collect() 137 138 return net, nil 139 } 140 141 func (w *Wireguard) Conn(ctx context.Context, addr netapi.Address) (net.Conn, error) { 142 net, err := w.initNet() 143 if err != nil { 144 return nil, err 145 } 146 147 addrPort := addr.AddrPort(ctx) 148 149 if addrPort.Err != nil { 150 return nil, addrPort.Err 151 } 152 153 conn, err := net.DialContextTCPAddrPort(ctx, addrPort.V) 154 if err != nil { 155 return nil, err 156 } 157 158 w.count.Add(1) 159 w.lastNewConn = time.Now() 160 161 return &wrapGoNetTcpConn{w, conn}, nil 162 } 163 164 type wrapGoNetTcpConn struct { 165 wireguard *Wireguard 166 *gonet.TCPConn 167 } 168 169 func (w *wrapGoNetTcpConn) Close() error { 170 w.wireguard.count.Add(-1) 171 return w.TCPConn.Close() 172 } 173 174 func (w *Wireguard) PacketConn(ctx context.Context, addr netapi.Address) (net.PacketConn, error) { 175 net, err := w.initNet() 176 if err != nil { 177 return nil, err 178 } 179 180 goUC, err := net.ListenUDP(nil) 181 if err != nil { 182 return nil, err 183 } 184 185 w.count.Add(1) 186 w.lastNewConn = time.Now() 187 188 return &wrapGoNetUdpConn{w, goUC}, nil 189 } 190 191 type wrapGoNetUdpConn struct { 192 wireguard *Wireguard 193 *gonet.UDPConn 194 } 195 196 func (w *wrapGoNetUdpConn) Close() error { 197 w.wireguard.count.Add(-1) 198 return w.UDPConn.Close() 199 } 200 201 func (w *wrapGoNetUdpConn) WriteTo(buf []byte, addr net.Addr) (int, error) { 202 a, err := netapi.ParseSysAddr(addr) 203 if err != nil { 204 return 0, err 205 } 206 207 ur := a.UDPAddr(context.TODO()) 208 209 if ur.Err != nil { 210 return 0, ur.Err 211 } 212 213 return w.UDPConn.WriteTo(buf, ur.V) 214 } 215 216 // creates a tun interface on netstack given a configuration 217 func makeVirtualTun(h *protocol.Wireguard) (*device.Device, *netBindClient, *Net, error) { 218 endpoints, err := parseEndpoints(h) 219 if err != nil { 220 return nil, nil, nil, err 221 } 222 tun, tnet, err := CreateNetTUN(endpoints, int(h.Mtu)) 223 if err != nil { 224 return nil, nil, nil, err 225 } 226 227 bind := newNetBindClient(h.GetReserved()) 228 // dev := device.NewDevice(tun, conn.NewDefaultBind(), nil /* device.NewLogger(device.LogLevelVerbose, "") */) 229 dev := device.NewDevice( 230 tun, 231 bind, 232 &device.Logger{ 233 Verbosef: func(format string, args ...any) { 234 log.Output(2, slog.LevelDebug, fmt.Sprintf(format, args...)) 235 }, 236 Errorf: func(format string, args ...any) { 237 log.Output(2, slog.LevelError, fmt.Sprintf(format, args...)) 238 }, 239 }) 240 241 err = dev.IpcSet(createIPCRequest(h)) 242 if err != nil { 243 dev.Close() 244 return nil, nil, nil, err 245 } 246 247 err = dev.Up() 248 if err != nil { 249 dev.Close() 250 return nil, nil, nil, err 251 } 252 253 return dev, bind, tnet, nil 254 } 255 256 func base64ToHex(s string) string { 257 data, _ := base64.StdEncoding.DecodeString(s) 258 return hex.EncodeToString(data) 259 } 260 261 // serialize the config into an IPC request 262 func createIPCRequest(conf *protocol.Wireguard) string { 263 var request bytes.Buffer 264 265 request.WriteString(fmt.Sprintf("private_key=%s\n", base64ToHex(conf.SecretKey))) 266 267 for _, peer := range conf.Peers { 268 request.WriteString(fmt.Sprintf("public_key=%s\nendpoint=%s\n", base64ToHex(peer.PublicKey), peer.Endpoint)) 269 if peer.KeepAlive != 0 { 270 request.WriteString(fmt.Sprintf("persistent_keepalive_interval=%d\n", peer.KeepAlive)) 271 } 272 if peer.PreSharedKey != "" { 273 request.WriteString(fmt.Sprintf("preshared_key=%s\n", base64ToHex(peer.PreSharedKey))) 274 } 275 276 for _, ip := range peer.AllowedIps { 277 request.WriteString(fmt.Sprintf("allowed_ip=%s\n", ip)) 278 } 279 } 280 281 return request.String()[:request.Len()] 282 }