github.com/MerlinKodo/sing-tun@v0.1.15/stack_system_nat.go (about)

     1  package tun
     2  
     3  import (
     4  	"context"
     5  	"net/netip"
     6  	"sync"
     7  	"time"
     8  )
     9  
// TCPNat maps TCP connections onto locally allocated NAT ports and back again.
// addrMap goes from a connection's source address to its NAT port; portMap goes
// from a NAT port to the session that owns it.
type TCPNat struct {
	// portIndex is the next NAT port candidate handed out by Lookup; a value of
	// 0 means the counter wrapped past 65535 and allocation restarts at 10000.
	portIndex  uint16
	// portAccess guards portMap.
	portAccess sync.RWMutex
	// addrAccess guards addrMap and the portIndex counter.
	addrAccess sync.RWMutex
	// addrMap maps a source address to the NAT port assigned to it.
	addrMap    map[netip.AddrPort]uint16
	// portMap maps a NAT port to the session it was assigned to.
	portMap    map[uint16]*TCPSession
}
    17  
// TCPSession describes one TCP connection that has been assigned a NAT port.
type TCPSession struct {
	// Source is the source address passed to Lookup; it is the session's key
	// in TCPNat.addrMap.
	Source      netip.AddrPort
	// Destination is the destination passed to Lookup when the port was
	// allocated.
	Destination netip.AddrPort
	// LastActive is set when the session is created and refreshed by
	// LookupBack; checkTimeout evicts sessions whose LastActive is older than
	// the configured timeout.
	LastActive  time.Time
}
    23  
    24  func NewNat(ctx context.Context, timeout time.Duration) *TCPNat {
    25  	natMap := &TCPNat{
    26  		portIndex: 10000,
    27  		addrMap:   make(map[netip.AddrPort]uint16),
    28  		portMap:   make(map[uint16]*TCPSession),
    29  	}
    30  	go natMap.loopCheckTimeout(ctx, timeout)
    31  	return natMap
    32  }
    33  
    34  func (n *TCPNat) loopCheckTimeout(ctx context.Context, timeout time.Duration) {
    35  	ticker := time.NewTicker(timeout)
    36  	defer ticker.Stop()
    37  	for {
    38  		select {
    39  		case <-ticker.C:
    40  			n.checkTimeout(timeout)
    41  		case <-ctx.Done():
    42  			return
    43  		}
    44  	}
    45  }
    46  
    47  func (n *TCPNat) checkTimeout(timeout time.Duration) {
    48  	now := time.Now()
    49  	n.portAccess.Lock()
    50  	defer n.portAccess.Unlock()
    51  	n.addrAccess.Lock()
    52  	defer n.addrAccess.Unlock()
    53  	for natPort, session := range n.portMap {
    54  		if now.Sub(session.LastActive) > timeout {
    55  			delete(n.addrMap, session.Source)
    56  			delete(n.portMap, natPort)
    57  		}
    58  	}
    59  }
    60  
    61  func (n *TCPNat) LookupBack(port uint16) *TCPSession {
    62  	n.portAccess.RLock()
    63  	session := n.portMap[port]
    64  	n.portAccess.RUnlock()
    65  	if session != nil {
    66  		session.LastActive = time.Now()
    67  	}
    68  	return session
    69  }
    70  
    71  func (n *TCPNat) Lookup(source netip.AddrPort, destination netip.AddrPort) uint16 {
    72  	n.addrAccess.RLock()
    73  	port, loaded := n.addrMap[source]
    74  	n.addrAccess.RUnlock()
    75  	if loaded {
    76  		return port
    77  	}
    78  	n.addrAccess.Lock()
    79  	nextPort := n.portIndex
    80  	if nextPort == 0 {
    81  		nextPort = 10000
    82  		n.portIndex = 10001
    83  	} else {
    84  		n.portIndex++
    85  	}
    86  	n.addrMap[source] = nextPort
    87  	n.addrAccess.Unlock()
    88  	n.portAccess.Lock()
    89  	n.portMap[nextPort] = &TCPSession{
    90  		Source:      source,
    91  		Destination: destination,
    92  		LastActive:  time.Now(),
    93  	}
    94  	n.portAccess.Unlock()
    95  	return nextPort
    96  }