github.com/kelleygo/clashcore@v1.0.2/component/loopback/detector.go (about)

     1  package loopback
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  	"net/netip"
     7  
     8  	"github.com/kelleygo/clashcore/common/callback"
     9  	"github.com/kelleygo/clashcore/component/iface"
    10  	C "github.com/kelleygo/clashcore/constant"
    11  
    12  	"github.com/puzpuzpuz/xsync/v3"
    13  )
    14  
    15  var ErrReject = errors.New("reject loopback connection")
    16  
// Detector remembers the local endpoints of connections handed to NewConn and
// NewPacketConn, so that CheckConn / CheckPacketConn can reject a connection
// whose source matches one of those endpoints. Construct it with NewDetector;
// a zero Detector has nil maps.
type Detector struct {
	// connMap is keyed by the full local address:port of each stream
	// connection registered by NewConn; the entry is deleted by the close
	// callback installed there.
	connMap       *xsync.MapOf[netip.AddrPort, struct{}]
	// packetConnMap is keyed only by the local port of each packet
	// connection registered by NewPacketConn (the IP is not part of the key,
	// which is why CheckPacketConn also checks that the source IP is local).
	packetConnMap *xsync.MapOf[uint16, struct{}]
}
    21  
    22  func NewDetector() *Detector {
    23  	return &Detector{
    24  		connMap:       xsync.NewMapOf[netip.AddrPort, struct{}](),
    25  		packetConnMap: xsync.NewMapOf[uint16, struct{}](),
    26  	}
    27  }
    28  
    29  func (l *Detector) NewConn(conn C.Conn) C.Conn {
    30  	metadata := C.Metadata{}
    31  	if metadata.SetRemoteAddr(conn.LocalAddr()) != nil {
    32  		return conn
    33  	}
    34  	connAddr := metadata.AddrPort()
    35  	if !connAddr.IsValid() {
    36  		return conn
    37  	}
    38  	l.connMap.Store(connAddr, struct{}{})
    39  	return callback.NewCloseCallbackConn(conn, func() {
    40  		l.connMap.Delete(connAddr)
    41  	})
    42  }
    43  
    44  func (l *Detector) NewPacketConn(conn C.PacketConn) C.PacketConn {
    45  	metadata := C.Metadata{}
    46  	if metadata.SetRemoteAddr(conn.LocalAddr()) != nil {
    47  		return conn
    48  	}
    49  	connAddr := metadata.AddrPort()
    50  	if !connAddr.IsValid() {
    51  		return conn
    52  	}
    53  	port := connAddr.Port()
    54  	l.packetConnMap.Store(port, struct{}{})
    55  	return callback.NewCloseCallbackPacketConn(conn, func() {
    56  		l.packetConnMap.Delete(port)
    57  	})
    58  }
    59  
    60  func (l *Detector) CheckConn(metadata *C.Metadata) error {
    61  	connAddr := metadata.SourceAddrPort()
    62  	if !connAddr.IsValid() {
    63  		return nil
    64  	}
    65  	if _, ok := l.connMap.Load(connAddr); ok {
    66  		return fmt.Errorf("%w to: %s", ErrReject, metadata.RemoteAddress())
    67  	}
    68  	return nil
    69  }
    70  
    71  func (l *Detector) CheckPacketConn(metadata *C.Metadata) error {
    72  	connAddr := metadata.SourceAddrPort()
    73  	if !connAddr.IsValid() {
    74  		return nil
    75  	}
    76  
    77  	isLocalIp, err := iface.IsLocalIp(connAddr.Addr())
    78  	if err != nil {
    79  		return err
    80  	}
    81  	if !isLocalIp && !connAddr.Addr().IsLoopback() {
    82  		return nil
    83  	}
    84  
    85  	if _, ok := l.packetConnMap.Load(connAddr.Port()); ok {
    86  		return fmt.Errorf("%w to: %s", ErrReject, metadata.RemoteAddress())
    87  	}
    88  	return nil
    89  }