github.com/imgk/caddy-trojan@v0.0.0-20221206043256-2631719e16c8/trojan/trojan_udp.go (about)

     1  package trojan
     2  
     3  import (
     4  	"bytes"
     5  	"errors"
     6  	"io"
     7  	"net"
     8  	"os"
     9  	"time"
    10  
    11  	"github.com/imgk/caddy-trojan/socks"
    12  
    13  	"github.com/imgk/memory-go"
    14  )
    15  
    16  // HandleUDP is ...
    17  // [AddrType(1 byte)][Addr(max 256 byte)][Port(2 byte)][Len(2 byte)][0x0d, 0x0a][Data(max 65535 byte)]
    18  func HandleUDP(r io.Reader, w io.Writer, timeout time.Duration, d Dialer) (int64, int64, error) {
    19  	rc, err := d.ListenPacket("udp", "")
    20  	if err != nil {
    21  		return 0, 0, err
    22  	}
    23  	defer rc.Close()
    24  
    25  	type Result struct {
    26  		Num int64
    27  		Err error
    28  	}
    29  
    30  	errCh := make(chan Result, 0)
    31  	go func(rc net.PacketConn, r io.Reader, errCh chan Result) (nr int64, err error) {
    32  		defer func() {
    33  			if errors.Is(err, io.EOF) || errors.Is(err, os.ErrDeadlineExceeded) {
    34  				err = nil
    35  			}
    36  			errCh <- Result{Num: nr, Err: err}
    37  		}()
    38  
    39  		// save previous address
    40  		ptr, bb := memory.Alloc[byte](socks.MaxAddrLen)
    41  		defer memory.Free(ptr)
    42  
    43  		tt := (*net.UDPAddr)(nil)
    44  
    45  		ptr, b := memory.Alloc[byte](64*1024 + socks.MaxAddrLen)
    46  		defer memory.Free(ptr)
    47  
    48  		for {
    49  			raddr, er := socks.ReadAddrBuffer(r, b)
    50  			if er != nil {
    51  				err = er
    52  				break
    53  			}
    54  
    55  			l := raddr.Len()
    56  
    57  			if !bytes.Equal(bb, raddr.Bytes()) {
    58  				addr, er := socks.ResolveUDPAddr(raddr)
    59  				if er != nil {
    60  					err = er
    61  					break
    62  				}
    63  				bb = raddr.AppendTo(bb[:0])
    64  				tt = addr
    65  			}
    66  
    67  			if _, er := io.ReadFull(r, b[l:l+4]); er != nil {
    68  				err = er
    69  				break
    70  			}
    71  
    72  			l += (int(b[l])<<8 | int(b[l+1]))
    73  			nr += int64(l) + 4
    74  
    75  			buf := b[raddr.Len():l]
    76  			if _, er := io.ReadFull(r, buf); er != nil {
    77  				err = er
    78  				break
    79  			}
    80  
    81  			if _, ew := rc.WriteTo(buf, tt); ew != nil {
    82  				err = ew
    83  				break
    84  			}
    85  		}
    86  		rc.SetReadDeadline(time.Now())
    87  		return
    88  	}(rc, r, errCh)
    89  
    90  	nr, nw, err := func(rc net.PacketConn, w io.Writer, errCh chan Result, timeout time.Duration) (_, nw int64, err error) {
    91  		ptr, b := memory.Alloc[byte](64*1024 + socks.MaxAddrLen + 4)
    92  		defer memory.Free(ptr)
    93  
    94  		b[socks.MaxAddrLen+2] = 0x0d
    95  		b[socks.MaxAddrLen+3] = 0x0a
    96  		for {
    97  			rc.SetReadDeadline(time.Now().Add(timeout))
    98  			n, addr, er := rc.ReadFrom(b[socks.MaxAddrLen+4:])
    99  			if er != nil {
   100  				err = er
   101  				break
   102  			}
   103  
   104  			b[socks.MaxAddrLen] = byte(n >> 8)
   105  			b[socks.MaxAddrLen+1] = byte(n)
   106  
   107  			l := func(bb []byte, addr *net.UDPAddr) int64 {
   108  				if ipv4 := addr.IP.To4(); ipv4 != nil {
   109  					const offset = socks.MaxAddrLen - (1 + net.IPv4len + 2)
   110  					bb[offset] = socks.AddrTypeIPv4
   111  					copy(bb[offset+1:], ipv4)
   112  					bb[offset+1+net.IPv4len], bb[offset+1+net.IPv4len+1] = byte(addr.Port>>8), byte(addr.Port)
   113  					return 1 + net.IPv4len + 2
   114  				} else {
   115  					const offset = socks.MaxAddrLen - (1 + net.IPv6len + 2)
   116  					bb[offset] = socks.AddrTypeIPv6
   117  					copy(bb[offset+1:], addr.IP.To16())
   118  					bb[offset+1+net.IPv6len], bb[offset+1+net.IPv6len+1] = byte(addr.Port>>8), byte(addr.Port)
   119  					return 1 + net.IPv6len + 2
   120  				}
   121  			}(b[:socks.MaxAddrLen], addr.(*net.UDPAddr))
   122  			nw += 4 + int64(n) + l
   123  
   124  			if _, ew := w.Write(b[socks.MaxAddrLen-l : socks.MaxAddrLen+4+n]); ew != nil {
   125  				err = ew
   126  				break
   127  			}
   128  		}
   129  		rc.SetWriteDeadline(time.Now())
   130  
   131  		if errors.Is(err, io.EOF) || errors.Is(err, os.ErrDeadlineExceeded) {
   132  			r := <-errCh
   133  			return r.Num, nw, r.Err
   134  		}
   135  		r := <-errCh
   136  		return r.Num, nw, err
   137  	}(rc, w, errCh, timeout)
   138  
   139  	return nr, nw, err
   140  }