github.com/Asutorufa/yuhaiin@v0.3.6-0.20240502055049-7984da7023a0/pkg/net/proxy/wireguard/wireguard.go (about)

     1  /* SPDX-License-Identifier: MIT
     2   *
     3   * Copyright (C) 2017-2022 WireGuard LLC. All Rights Reserved.
     4   */
     5  
     6  package wireguard
     7  
     8  import (
     9  	"bytes"
    10  	"context"
    11  	"encoding/base64"
    12  	"encoding/hex"
    13  	"fmt"
    14  	"log/slog"
    15  	"net"
    16  	"sync"
    17  	"sync/atomic"
    18  	"time"
    19  
    20  	"github.com/Asutorufa/yuhaiin/pkg/log"
    21  	"github.com/Asutorufa/yuhaiin/pkg/net/netapi"
    22  	"github.com/Asutorufa/yuhaiin/pkg/protos/node/point"
    23  	"github.com/Asutorufa/yuhaiin/pkg/protos/node/protocol"
    24  	"github.com/tailscale/wireguard-go/device"
    25  	"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
    26  )
    27  
// Wireguard is a proxy that opens TCP and UDP connections through a
// WireGuard tunnel. The tunnel is created on first use by initNet and torn
// down again by collect once no connection has used it for a while.
type Wireguard struct {
	netapi.EmptyDispatch
	// net is the network stack connections are dialed on; it is nil while
	// the tunnel is down.
	net  *Net
	// bind is the bind handed to the device; closed and cleared by collect.
	bind *netBindClient

	// conf is the WireGuard configuration the tunnel is built from.
	conf *protocol.Wireguard
	// mu is held while the tunnel is created (initNet) and while it is
	// checked for idleness and torn down (collect).
	mu   sync.Mutex

	// count is the number of connections handed out by Conn and PacketConn
	// that have not been closed yet.
	count atomic.Int64

	// lastNewConn is the time the most recent connection was handed out.
	lastNewConn time.Time
	// idleTimeout is how long collect sleeps between idle checks.
	idleTimeout time.Duration

	// device is the running WireGuard device; nil while the tunnel is down.
	device *device.Device
}
    43  
// init registers NewClient as a protocol constructor via point.RegisterProtocol.
func init() {
	point.RegisterProtocol(NewClient)
}
    47  
    48  func NewClient(conf *protocol.Protocol_Wireguard) point.WrapProxy {
    49  	return func(p netapi.Proxy) (netapi.Proxy, error) {
    50  
    51  		if conf.Wireguard.IdleTimeout == 0 {
    52  			conf.Wireguard.IdleTimeout = 60 * 5
    53  		}
    54  		if conf.Wireguard.IdleTimeout <= 30 {
    55  			conf.Wireguard.IdleTimeout = 30
    56  		}
    57  
    58  		return &Wireguard{
    59  			conf:        conf.Wireguard,
    60  			idleTimeout: time.Duration(conf.Wireguard.IdleTimeout) * time.Second,
    61  		}, nil
    62  	}
    63  }
    64  
    65  func (w *Wireguard) collect() {
    66  	readyClose := false
    67  
    68  	for {
    69  		time.Sleep(w.idleTimeout)
    70  
    71  		br := func() bool {
    72  			w.mu.Lock()
    73  			defer w.mu.Unlock()
    74  
    75  			log.Debug("wireguard check idle timeout")
    76  
    77  			if w.count.Load() > 0 {
    78  				readyClose = false
    79  				return false
    80  			}
    81  
    82  			if !w.lastNewConn.IsZero() && time.Since(w.lastNewConn) < time.Minute {
    83  				readyClose = false
    84  				return false
    85  			}
    86  
    87  			if readyClose {
    88  				log.Debug("wireguard closing")
    89  				if w.device != nil {
    90  					w.device.Close()
    91  					w.device = nil
    92  				}
    93  
    94  				if w.bind != nil {
    95  					w.bind.Close()
    96  					w.bind = nil
    97  				}
    98  				log.Debug("wireguard closed")
    99  				w.net = nil
   100  				return true
   101  			}
   102  
   103  			log.Debug("wireguard ready to close")
   104  
   105  			readyClose = true
   106  			return false
   107  		}()
   108  
   109  		if br {
   110  			break
   111  		}
   112  	}
   113  }
   114  
   115  func (w *Wireguard) initNet() (*Net, error) {
   116  	net := w.net
   117  	if net != nil {
   118  		return net, nil
   119  	}
   120  
   121  	w.mu.Lock()
   122  	defer w.mu.Unlock()
   123  
   124  	if w.net != nil {
   125  		return w.net, nil
   126  	}
   127  
   128  	dev, bind, net, err := makeVirtualTun(w.conf)
   129  	if err != nil {
   130  		return nil, err
   131  	}
   132  
   133  	w.device = dev
   134  	w.net = net
   135  	w.bind = bind
   136  	go w.collect()
   137  
   138  	return net, nil
   139  }
   140  
   141  func (w *Wireguard) Conn(ctx context.Context, addr netapi.Address) (net.Conn, error) {
   142  	net, err := w.initNet()
   143  	if err != nil {
   144  		return nil, err
   145  	}
   146  
   147  	addrPort := addr.AddrPort(ctx)
   148  
   149  	if addrPort.Err != nil {
   150  		return nil, addrPort.Err
   151  	}
   152  
   153  	conn, err := net.DialContextTCPAddrPort(ctx, addrPort.V)
   154  	if err != nil {
   155  		return nil, err
   156  	}
   157  
   158  	w.count.Add(1)
   159  	w.lastNewConn = time.Now()
   160  
   161  	return &wrapGoNetTcpConn{w, conn}, nil
   162  }
   163  
   164  type wrapGoNetTcpConn struct {
   165  	wireguard *Wireguard
   166  	*gonet.TCPConn
   167  }
   168  
   169  func (w *wrapGoNetTcpConn) Close() error {
   170  	w.wireguard.count.Add(-1)
   171  	return w.TCPConn.Close()
   172  }
   173  
   174  func (w *Wireguard) PacketConn(ctx context.Context, addr netapi.Address) (net.PacketConn, error) {
   175  	net, err := w.initNet()
   176  	if err != nil {
   177  		return nil, err
   178  	}
   179  
   180  	goUC, err := net.ListenUDP(nil)
   181  	if err != nil {
   182  		return nil, err
   183  	}
   184  
   185  	w.count.Add(1)
   186  	w.lastNewConn = time.Now()
   187  
   188  	return &wrapGoNetUdpConn{w, goUC}, nil
   189  }
   190  
   191  type wrapGoNetUdpConn struct {
   192  	wireguard *Wireguard
   193  	*gonet.UDPConn
   194  }
   195  
   196  func (w *wrapGoNetUdpConn) Close() error {
   197  	w.wireguard.count.Add(-1)
   198  	return w.UDPConn.Close()
   199  }
   200  
   201  func (w *wrapGoNetUdpConn) WriteTo(buf []byte, addr net.Addr) (int, error) {
   202  	a, err := netapi.ParseSysAddr(addr)
   203  	if err != nil {
   204  		return 0, err
   205  	}
   206  
   207  	ur := a.UDPAddr(context.TODO())
   208  
   209  	if ur.Err != nil {
   210  		return 0, ur.Err
   211  	}
   212  
   213  	return w.UDPConn.WriteTo(buf, ur.V)
   214  }
   215  
   216  // creates a tun interface on netstack given a configuration
   217  func makeVirtualTun(h *protocol.Wireguard) (*device.Device, *netBindClient, *Net, error) {
   218  	endpoints, err := parseEndpoints(h)
   219  	if err != nil {
   220  		return nil, nil, nil, err
   221  	}
   222  	tun, tnet, err := CreateNetTUN(endpoints, int(h.Mtu))
   223  	if err != nil {
   224  		return nil, nil, nil, err
   225  	}
   226  
   227  	bind := newNetBindClient(h.GetReserved())
   228  	// dev := device.NewDevice(tun, conn.NewDefaultBind(), nil /* device.NewLogger(device.LogLevelVerbose, "") */)
   229  	dev := device.NewDevice(
   230  		tun,
   231  		bind,
   232  		&device.Logger{
   233  			Verbosef: func(format string, args ...any) {
   234  				log.Output(2, slog.LevelDebug, fmt.Sprintf(format, args...))
   235  			},
   236  			Errorf: func(format string, args ...any) {
   237  				log.Output(2, slog.LevelError, fmt.Sprintf(format, args...))
   238  			},
   239  		})
   240  
   241  	err = dev.IpcSet(createIPCRequest(h))
   242  	if err != nil {
   243  		dev.Close()
   244  		return nil, nil, nil, err
   245  	}
   246  
   247  	err = dev.Up()
   248  	if err != nil {
   249  		dev.Close()
   250  		return nil, nil, nil, err
   251  	}
   252  
   253  	return dev, bind, tnet, nil
   254  }
   255  
   256  func base64ToHex(s string) string {
   257  	data, _ := base64.StdEncoding.DecodeString(s)
   258  	return hex.EncodeToString(data)
   259  }
   260  
   261  // serialize the config into an IPC request
   262  func createIPCRequest(conf *protocol.Wireguard) string {
   263  	var request bytes.Buffer
   264  
   265  	request.WriteString(fmt.Sprintf("private_key=%s\n", base64ToHex(conf.SecretKey)))
   266  
   267  	for _, peer := range conf.Peers {
   268  		request.WriteString(fmt.Sprintf("public_key=%s\nendpoint=%s\n", base64ToHex(peer.PublicKey), peer.Endpoint))
   269  		if peer.KeepAlive != 0 {
   270  			request.WriteString(fmt.Sprintf("persistent_keepalive_interval=%d\n", peer.KeepAlive))
   271  		}
   272  		if peer.PreSharedKey != "" {
   273  			request.WriteString(fmt.Sprintf("preshared_key=%s\n", base64ToHex(peer.PreSharedKey)))
   274  		}
   275  
   276  		for _, ip := range peer.AllowedIps {
   277  			request.WriteString(fmt.Sprintf("allowed_ip=%s\n", ip))
   278  		}
   279  	}
   280  
   281  	return request.String()[:request.Len()]
   282  }