github.com/bepass-org/wireguard-go@v1.0.4-rc2.0.20240304192354-ebce6572bc24/wiresocks/udpfw.go (about)

     1  package wiresocks
     2  
     3  import (
     4  	"context"
     5  	"encoding/binary"
     6  	"fmt"
     7  	"io"
     8  	"net"
     9  	"sync"
    10  )
    11  
    12  type Socks5UDPForwarder struct {
    13  	socks5Server string
    14  	destAddr     *net.UDPAddr
    15  	proxyUDPAddr *net.UDPAddr
    16  	conn         *net.UDPConn
    17  	listener     *net.UDPConn
    18  	clientAddr   *net.UDPAddr
    19  }
    20  
    21  func NewVtunUDPForwarder(localBind, dest string, vtun *VirtualTun, mtu int, ctx context.Context) error {
    22  	localAddr, err := net.ResolveUDPAddr("udp", localBind)
    23  	if err != nil {
    24  		return err
    25  	}
    26  
    27  	destAddr, err := net.ResolveUDPAddr("udp", dest)
    28  	if err != nil {
    29  		return err
    30  	}
    31  
    32  	listener, err := net.ListenUDP("udp", localAddr)
    33  	if err != nil {
    34  		return err
    35  	}
    36  
    37  	rconn, err := vtun.Tnet.DialUDP(nil, destAddr)
    38  	if err != nil {
    39  		return err
    40  	}
    41  
    42  	var clientAddr *net.UDPAddr
    43  	var wg sync.WaitGroup
    44  	wg.Add(2)
    45  
    46  	go func() {
    47  		buffer := make([]byte, mtu)
    48  		for {
    49  			select {
    50  			case <-ctx.Done():
    51  				wg.Done()
    52  				return
    53  			default:
    54  				n, cAddr, err := listener.ReadFrom(buffer)
    55  				if err != nil {
    56  					continue
    57  				}
    58  
    59  				clientAddr = cAddr.(*net.UDPAddr)
    60  
    61  				rconn.WriteTo(buffer[:n], destAddr)
    62  			}
    63  		}
    64  	}()
    65  	go func() {
    66  		buffer := make([]byte, mtu)
    67  		for {
    68  			select {
    69  			case <-ctx.Done():
    70  				wg.Done()
    71  				return
    72  			default:
    73  				n, _, err := rconn.ReadFrom(buffer)
    74  				if err != nil {
    75  					continue
    76  				}
    77  				if clientAddr != nil {
    78  					listener.WriteTo(buffer[:n], clientAddr)
    79  				}
    80  			}
    81  		}
    82  	}()
    83  	go func() {
    84  		wg.Wait()
    85  		_ = listener.Close()
    86  		_ = rconn.Close()
    87  	}()
    88  	return nil
    89  }
    90  
    91  func NewSocks5UDPForwarder(localBind, socks5Server, dest string) (*Socks5UDPForwarder, error) {
    92  	localAddr, err := net.ResolveUDPAddr("udp", localBind)
    93  	if err != nil {
    94  		return nil, err
    95  	}
    96  
    97  	destAddr, err := net.ResolveUDPAddr("udp", dest)
    98  	if err != nil {
    99  		return nil, err
   100  	}
   101  
   102  	listener, err := net.ListenUDP("udp", localAddr)
   103  	if err != nil {
   104  		return nil, err
   105  	}
   106  
   107  	tcpConn, err := net.Dial("tcp", socks5Server)
   108  	if err != nil {
   109  		return nil, err
   110  	}
   111  	defer tcpConn.Close()
   112  
   113  	if err := socks5Handshake(tcpConn); err != nil {
   114  		return nil, err
   115  	}
   116  
   117  	proxyUDPAddr, err := requestUDPAssociate(tcpConn)
   118  	if err != nil {
   119  		return nil, err
   120  	}
   121  
   122  	udpConn, err := net.DialUDP("udp", nil, proxyUDPAddr)
   123  	if err != nil {
   124  		return nil, err
   125  	}
   126  
   127  	return &Socks5UDPForwarder{
   128  		socks5Server: socks5Server,
   129  		destAddr:     destAddr,
   130  		proxyUDPAddr: proxyUDPAddr,
   131  		conn:         udpConn,
   132  		listener:     listener,
   133  	}, nil
   134  }
   135  
   136  func (f *Socks5UDPForwarder) Start() {
   137  	go f.listenAndServe()
   138  	go f.receiveFromProxy()
   139  }
   140  
   141  func socks5Handshake(conn net.Conn) error {
   142  	// Send greeting
   143  	_, err := conn.Write([]byte{0x05, 0x01, 0x00}) // SOCKS5, 1 authentication method, No authentication
   144  	if err != nil {
   145  		return err
   146  	}
   147  
   148  	// Receive server response
   149  	resp := make([]byte, 2)
   150  	if _, err := io.ReadFull(conn, resp); err != nil {
   151  		return err
   152  	}
   153  
   154  	if resp[0] != 0x05 || resp[1] != 0x00 {
   155  		return fmt.Errorf("invalid SOCKS5 authentication response")
   156  	}
   157  	return nil
   158  }
   159  
   160  func (f *Socks5UDPForwarder) listenAndServe() {
   161  	for {
   162  		buffer := make([]byte, 4096)
   163  		// Listen for incoming UDP packets
   164  		n, clientAddr, err := f.listener.ReadFromUDP(buffer)
   165  		if err != nil {
   166  			fmt.Printf("Error reading from listener: %v\n", err)
   167  			continue
   168  		}
   169  
   170  		// Store client address for response mapping
   171  		f.clientAddr = clientAddr
   172  
   173  		// Forward packet to destination via SOCKS5 proxy
   174  		go f.forwardPacketToRemote(buffer[:n])
   175  	}
   176  }
   177  
   178  func (f *Socks5UDPForwarder) forwardPacketToRemote(data []byte) {
   179  	packet := make([]byte, 10+len(data))
   180  	packet[0] = 0x00 // Reserved
   181  	packet[1] = 0x00 // Reserved
   182  	packet[2] = 0x00 // Fragment
   183  	packet[3] = 0x01 // Address type (IPv4)
   184  	copy(packet[4:8], f.destAddr.IP.To4())
   185  	binary.BigEndian.PutUint16(packet[8:10], uint16(f.destAddr.Port))
   186  	copy(packet[10:], data)
   187  
   188  	_, err := f.conn.Write(packet)
   189  	if err != nil {
   190  		fmt.Printf("Error forwarding packet to remote: %v\n", err)
   191  	}
   192  }
   193  
   194  func (f *Socks5UDPForwarder) receiveFromProxy() {
   195  	for {
   196  		buffer := make([]byte, 4096)
   197  		n, err := f.conn.Read(buffer)
   198  		if err != nil {
   199  			fmt.Printf("Error reading from proxy connection: %v\n", err)
   200  			continue
   201  		}
   202  
   203  		// Forward the packet to the original client
   204  		f.listener.WriteToUDP(buffer[10:n], f.clientAddr)
   205  	}
   206  }
   207  
   208  func requestUDPAssociate(conn net.Conn) (*net.UDPAddr, error) {
   209  	// Send UDP associate request with local address and port set to zero
   210  	req := []byte{0x05, 0x03, 0x00, 0x01, 0, 0, 0, 0, 0, 0} // Command: UDP Associate
   211  	if _, err := conn.Write(req); err != nil {
   212  		return nil, err
   213  	}
   214  
   215  	// Receive response
   216  	resp := make([]byte, 10)
   217  	if _, err := io.ReadFull(conn, resp); err != nil {
   218  		return nil, err
   219  	}
   220  
   221  	if resp[1] != 0x00 {
   222  		return nil, fmt.Errorf("UDP ASSOCIATE request failed")
   223  	}
   224  
   225  	// Parse the proxy UDP address
   226  	bindIP := net.IP(resp[4:8])
   227  	bindPort := binary.BigEndian.Uint16(resp[8:10])
   228  
   229  	return &net.UDPAddr{IP: bindIP, Port: int(bindPort)}, nil
   230  }