github.com/Asutorufa/yuhaiin@v0.3.6-0.20240502055049-7984da7023a0/pkg/net/nat/table.go (about)

     1  package nat
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"io"
     8  	"net"
     9  	"os"
    10  	"time"
    11  
    12  	"github.com/Asutorufa/yuhaiin/pkg/log"
    13  	"github.com/Asutorufa/yuhaiin/pkg/net/netapi"
    14  	"github.com/Asutorufa/yuhaiin/pkg/utils/pool"
    15  	"github.com/Asutorufa/yuhaiin/pkg/utils/singleflight"
    16  	"github.com/Asutorufa/yuhaiin/pkg/utils/syncmap"
    17  )
    18  
    // IdleTimeout bounds how long an upstream session may sit idle. It is used
    // as the read deadline on each session's upstream connection, pushed
    // forward before every read of a reply and after every forwarded write.
    var IdleTimeout = 3 * time.Minute
    20  var MaxSegmentSize = pool.MaxSegmentSize
    21  
    22  func NewTable(dialer netapi.Proxy) *Table {
    23  	return &Table{dialer: dialer}
    24  }
    25  
    26  type Table struct {
    27  	dialer netapi.Proxy
    28  	cache  syncmap.SyncMap[string, *SourceTable]
    29  	sf     singleflight.Group[string, *SourceTable]
    30  }
    31  
    32  func (u *Table) write(ctx context.Context, t *SourceTable, pkt *netapi.Packet) error {
    33  	key := pkt.Dst.String()
    34  
    35  	// ! we need write to same ip when use fakeip/domain, eg: quic will need it to create stream
    36  	uaddr, ok := t.udpAddrCache.Load(key)
    37  	if !ok {
    38  		var err error
    39  		uaddr, err, _ = t.sf.Do(key, func() (*net.UDPAddr, error) {
    40  			realAddr, err := u.dialer.Dispatch(ctx, pkt.Dst)
    41  			if err != nil {
    42  				return nil, fmt.Errorf("dispatch addr failed: %w", err)
    43  			}
    44  
    45  			ur := realAddr.UDPAddr(ctx)
    46  			if ur.Err != nil {
    47  				return nil, ur.Err
    48  			}
    49  
    50  			uaddr = ur.V
    51  
    52  			t.udpAddrCache.LoadOrStore(key, uaddr)
    53  
    54  			if !pkt.Dst.IsFqdn() {
    55  				// map fakeip/hosts
    56  				if uaddrStr := uaddr.String(); uaddrStr != key {
    57  					// TODO: maybe two dst(fake ip) have same uaddr, need help
    58  					t.originAddrStore.LoadOrStore(uaddrStr, pkt.Dst)
    59  				}
    60  			}
    61  
    62  			return uaddr, nil
    63  		})
    64  		if err != nil {
    65  			return err
    66  		}
    67  	}
    68  
    69  	_, err := t.dstPacketConn.WriteTo(pkt.Payload.Bytes(), uaddr)
    70  	_ = t.dstPacketConn.SetReadDeadline(time.Now().Add(IdleTimeout))
    71  	return err
    72  }
    73  
    74  func (u *Table) Write(ctx context.Context, pkt *netapi.Packet) error {
    75  	defer pkt.Payload.Free()
    76  
    77  	key := pkt.Src.String()
    78  
    79  	t, ok := u.cache.Load(key)
    80  	if ok {
    81  		return u.write(ctx, t, pkt)
    82  	}
    83  
    84  	t, err, _ := u.sf.Do(key, func() (*SourceTable, error) {
    85  		netapi.StoreFromContext(ctx).
    86  			Add(netapi.SourceKey{}, pkt.Src).
    87  			Add(netapi.DestinationKey{}, pkt.Dst)
    88  
    89  		dstpconn, err := u.dialer.PacketConn(ctx, pkt.Dst)
    90  		if err != nil {
    91  			return nil, fmt.Errorf("dial %s failed: %w", pkt.Dst, err)
    92  		}
    93  
    94  		table, _ := u.cache.LoadOrStore(key, &SourceTable{dstPacketConn: dstpconn})
    95  
    96  		go func() {
    97  			log.IfErr("udp remote to local",
    98  				func() error { return u.writeBack(pkt, table) },
    99  				net.ErrClosed,
   100  				io.EOF,
   101  				os.ErrDeadlineExceeded,
   102  			)
   103  			u.cache.Delete(key)
   104  			dstpconn.Close()
   105  		}()
   106  
   107  		return table, nil
   108  	})
   109  	if err != nil {
   110  		return err
   111  	}
   112  
   113  	if err = u.write(ctx, t, pkt); err != nil {
   114  		return fmt.Errorf("write data to remote failed: %w", err)
   115  	}
   116  
   117  	return nil
   118  }
   119  
   120  func (u *Table) writeBack(pkt *netapi.Packet, table *SourceTable) error {
   121  	data := pool.GetBytes(MaxSegmentSize)
   122  	defer pool.PutBytes(data)
   123  
   124  	for {
   125  		_ = table.dstPacketConn.SetReadDeadline(time.Now().Add(IdleTimeout))
   126  		n, from, err := table.dstPacketConn.ReadFrom(data)
   127  		if err != nil {
   128  			if errors.Is(err, context.DeadlineExceeded) ||
   129  				errors.Is(err, context.Canceled) ||
   130  				errors.Is(err, os.ErrDeadlineExceeded) {
   131  				return nil
   132  			}
   133  			return fmt.Errorf("read from proxy failed: %w", err)
   134  		}
   135  
   136  		faddr, err := netapi.ParseSysAddr(from)
   137  		if err != nil {
   138  			return fmt.Errorf("parse addr failed: %w", err)
   139  		}
   140  
   141  		if !faddr.IsFqdn() {
   142  			if addr, ok := table.originAddrStore.Load(faddr.String()); ok {
   143  				// TODO: maybe two dst(fake ip) have same uaddr, need help
   144  				from = addr
   145  			}
   146  		}
   147  
   148  		// write back to client with source address
   149  		if _, err := pkt.WriteBack(data[:n], from); err != nil {
   150  			return fmt.Errorf("write back to client failed: %w", err)
   151  		}
   152  	}
   153  }
   154  
   155  func (u *Table) Close() error {
   156  	u.cache.Range(func(_ string, value *SourceTable) bool {
   157  		value.dstPacketConn.Close()
   158  		return true
   159  	})
   160  
   161  	return nil
   162  }
   163  
   164  type SourceTable struct {
   165  	dstPacketConn   net.PacketConn
   166  	originAddrStore syncmap.SyncMap[string, netapi.Address]
   167  	udpAddrCache    syncmap.SyncMap[string, *net.UDPAddr]
   168  	sf              singleflight.Group[string, *net.UDPAddr]
   169  }