github.com/Asutorufa/yuhaiin@v0.3.6-0.20240502055049-7984da7023a0/pkg/net/nat/table.go (about) 1 package nat 2 3 import ( 4 "context" 5 "errors" 6 "fmt" 7 "io" 8 "net" 9 "os" 10 "time" 11 12 "github.com/Asutorufa/yuhaiin/pkg/log" 13 "github.com/Asutorufa/yuhaiin/pkg/net/netapi" 14 "github.com/Asutorufa/yuhaiin/pkg/utils/pool" 15 "github.com/Asutorufa/yuhaiin/pkg/utils/singleflight" 16 "github.com/Asutorufa/yuhaiin/pkg/utils/syncmap" 17 ) 18 19 var IdleTimeout = time.Minute * 3 20 var MaxSegmentSize = pool.MaxSegmentSize 21 22 func NewTable(dialer netapi.Proxy) *Table { 23 return &Table{dialer: dialer} 24 } 25 26 type Table struct { 27 dialer netapi.Proxy 28 cache syncmap.SyncMap[string, *SourceTable] 29 sf singleflight.Group[string, *SourceTable] 30 } 31 32 func (u *Table) write(ctx context.Context, t *SourceTable, pkt *netapi.Packet) error { 33 key := pkt.Dst.String() 34 35 // ! we need write to same ip when use fakeip/domain, eg: quic will need it to create stream 36 uaddr, ok := t.udpAddrCache.Load(key) 37 if !ok { 38 var err error 39 uaddr, err, _ = t.sf.Do(key, func() (*net.UDPAddr, error) { 40 realAddr, err := u.dialer.Dispatch(ctx, pkt.Dst) 41 if err != nil { 42 return nil, fmt.Errorf("dispatch addr failed: %w", err) 43 } 44 45 ur := realAddr.UDPAddr(ctx) 46 if ur.Err != nil { 47 return nil, ur.Err 48 } 49 50 uaddr = ur.V 51 52 t.udpAddrCache.LoadOrStore(key, uaddr) 53 54 if !pkt.Dst.IsFqdn() { 55 // map fakeip/hosts 56 if uaddrStr := uaddr.String(); uaddrStr != key { 57 // TODO: maybe two dst(fake ip) have same uaddr, need help 58 t.originAddrStore.LoadOrStore(uaddrStr, pkt.Dst) 59 } 60 } 61 62 return uaddr, nil 63 }) 64 if err != nil { 65 return err 66 } 67 } 68 69 _, err := t.dstPacketConn.WriteTo(pkt.Payload.Bytes(), uaddr) 70 _ = t.dstPacketConn.SetReadDeadline(time.Now().Add(IdleTimeout)) 71 return err 72 } 73 74 func (u *Table) Write(ctx context.Context, pkt *netapi.Packet) error { 75 defer pkt.Payload.Free() 76 77 key := pkt.Src.String() 78 79 t, ok := u.cache.Load(key) 80 if ok { 81 return u.write(ctx, t, pkt) 82 } 83 84 t, err, _ := u.sf.Do(key, func() (*SourceTable, error) { 85 netapi.StoreFromContext(ctx). 86 Add(netapi.SourceKey{}, pkt.Src). 87 Add(netapi.DestinationKey{}, pkt.Dst) 88 89 dstpconn, err := u.dialer.PacketConn(ctx, pkt.Dst) 90 if err != nil { 91 return nil, fmt.Errorf("dial %s failed: %w", pkt.Dst, err) 92 } 93 94 table, _ := u.cache.LoadOrStore(key, &SourceTable{dstPacketConn: dstpconn}) 95 96 go func() { 97 log.IfErr("udp remote to local", 98 func() error { return u.writeBack(pkt, table) }, 99 net.ErrClosed, 100 io.EOF, 101 os.ErrDeadlineExceeded, 102 ) 103 u.cache.Delete(key) 104 dstpconn.Close() 105 }() 106 107 return table, nil 108 }) 109 if err != nil { 110 return err 111 } 112 113 if err = u.write(ctx, t, pkt); err != nil { 114 return fmt.Errorf("write data to remote failed: %w", err) 115 } 116 117 return nil 118 } 119 120 func (u *Table) writeBack(pkt *netapi.Packet, table *SourceTable) error { 121 data := pool.GetBytes(MaxSegmentSize) 122 defer pool.PutBytes(data) 123 124 for { 125 _ = table.dstPacketConn.SetReadDeadline(time.Now().Add(IdleTimeout)) 126 n, from, err := table.dstPacketConn.ReadFrom(data) 127 if err != nil { 128 if errors.Is(err, context.DeadlineExceeded) || 129 errors.Is(err, context.Canceled) || 130 errors.Is(err, os.ErrDeadlineExceeded) { 131 return nil 132 } 133 return fmt.Errorf("read from proxy failed: %w", err) 134 } 135 136 faddr, err := netapi.ParseSysAddr(from) 137 if err != nil { 138 return fmt.Errorf("parse addr failed: %w", err) 139 } 140 141 if !faddr.IsFqdn() { 142 if addr, ok := table.originAddrStore.Load(faddr.String()); ok { 143 // TODO: maybe two dst(fake ip) have same uaddr, need help 144 from = addr 145 } 146 } 147 148 // write back to client with source address 149 if _, err := pkt.WriteBack(data[:n], from); err != nil { 150 return fmt.Errorf("write back to client failed: %w", err) 151 } 152 } 153 } 154 155 func (u *Table) Close() error { 156 u.cache.Range(func(_ string, value *SourceTable) bool { 157 value.dstPacketConn.Close() 158 return true 159 }) 160 161 return nil 162 } 163 164 type SourceTable struct { 165 dstPacketConn net.PacketConn 166 originAddrStore syncmap.SyncMap[string, netapi.Address] 167 udpAddrCache syncmap.SyncMap[string, *net.UDPAddr] 168 sf singleflight.Group[string, *net.UDPAddr] 169 }