github.com/database64128/shadowsocks-go@v1.10.2-0.20240315062903-143a773533f1/router/router.go (about)

     1  package router
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  
     7  	"github.com/database64128/shadowsocks-go/dns"
     8  	"github.com/database64128/shadowsocks-go/domainset"
     9  	"github.com/database64128/shadowsocks-go/prefixset"
    10  	"github.com/database64128/shadowsocks-go/zerocopy"
    11  	"github.com/oschwald/geoip2-golang"
    12  	"go.uber.org/zap"
    13  	"go4.org/netipx"
    14  )
    15  
// Config is the configuration for a Router.
//
// DefaultTCPClientName and DefaultUDPClientName select the client used by the
// default route for each protocol. The value "reject" leaves the default route
// without a client; an empty value picks the only client in the corresponding
// client map when that map has exactly one entry; any other value must name an
// existing client.
//
// GeoLite2CountryDbPath, when non-empty, is the path of the GeoLite2 country
// database opened while building the Router.
//
// DomainSets and PrefixSets are built and indexed by their configured names
// so that Routes can refer to them. Routes are evaluated in the order listed,
// ahead of the default route.
type Config struct {
	DefaultTCPClientName  string             `json:"defaultTCPClientName"`
	DefaultUDPClientName  string             `json:"defaultUDPClientName"`
	GeoLite2CountryDbPath string             `json:"geoLite2CountryDbPath"`
	DomainSets            []domainset.Config `json:"domainSets"`
	PrefixSets            []prefixset.Config `json:"prefixSets"`
	Routes                []RouteConfig      `json:"routes"`
}
    25  
    26  // Router creates a router from the RouterConfig.
    27  func (rc *Config) Router(logger *zap.Logger, resolvers []dns.SimpleResolver, resolverMap map[string]dns.SimpleResolver, tcpClientMap map[string]zerocopy.TCPClient, udpClientMap map[string]zerocopy.UDPClient, serverIndexByName map[string]int) (*Router, error) {
    28  	defaultRoute := Route{name: "default"}
    29  
    30  	switch rc.DefaultTCPClientName {
    31  	case "reject":
    32  	case "":
    33  		if len(tcpClientMap) == 1 {
    34  			for _, tcpClient := range tcpClientMap {
    35  				defaultRoute.tcpClient = tcpClient
    36  			}
    37  		}
    38  	default:
    39  		defaultRoute.tcpClient = tcpClientMap[rc.DefaultTCPClientName]
    40  		if defaultRoute.tcpClient == nil {
    41  			return nil, fmt.Errorf("default TCP client not found: %s", rc.DefaultTCPClientName)
    42  		}
    43  	}
    44  
    45  	switch rc.DefaultUDPClientName {
    46  	case "reject":
    47  	case "":
    48  		if len(udpClientMap) == 1 {
    49  			for _, udpClient := range udpClientMap {
    50  				defaultRoute.udpClient = udpClient
    51  			}
    52  		}
    53  	default:
    54  		defaultRoute.udpClient = udpClientMap[rc.DefaultUDPClientName]
    55  		if defaultRoute.udpClient == nil {
    56  			return nil, fmt.Errorf("default UDP client not found: %s", rc.DefaultUDPClientName)
    57  		}
    58  	}
    59  
    60  	var (
    61  		geoip *geoip2.Reader
    62  		err   error
    63  	)
    64  
    65  	if rc.GeoLite2CountryDbPath != "" {
    66  		geoip, err = geoip2.Open(rc.GeoLite2CountryDbPath)
    67  		if err != nil {
    68  			return nil, err
    69  		}
    70  	}
    71  
    72  	domainSetMap := make(map[string]domainset.DomainSet, len(rc.DomainSets))
    73  
    74  	for i := range rc.DomainSets {
    75  		domainSet, err := rc.DomainSets[i].DomainSet()
    76  		if err != nil {
    77  			return nil, err
    78  		}
    79  		domainSetMap[rc.DomainSets[i].Name] = domainSet
    80  	}
    81  
    82  	prefixSetMap := make(map[string]*netipx.IPSet, len(rc.PrefixSets))
    83  
    84  	for i := range rc.PrefixSets {
    85  		s, err := rc.PrefixSets[i].IPSet()
    86  		if err != nil {
    87  			return nil, err
    88  		}
    89  		prefixSetMap[rc.PrefixSets[i].Name] = s
    90  	}
    91  
    92  	routes := make([]Route, len(rc.Routes)+1)
    93  
    94  	for i := range rc.Routes {
    95  		route, err := rc.Routes[i].Route(geoip, logger, resolvers, resolverMap, tcpClientMap, udpClientMap, serverIndexByName, domainSetMap, prefixSetMap)
    96  		if err != nil {
    97  			return nil, err
    98  		}
    99  		routes[i] = route
   100  	}
   101  
   102  	routes[len(rc.Routes)] = defaultRoute
   103  
   104  	return &Router{
   105  		geoip:  geoip,
   106  		logger: logger,
   107  		routes: routes,
   108  	}, nil
   109  }
   110  
// Router looks up the destination client for requests received by servers.
//
// geoip is the optional GeoLite2 country database reader; it is nil when no
// database path was configured. logger receives debug messages about matched
// routes. routes holds the configured routes in order, followed by the
// default route as the last element.
type Router struct {
	geoip  *geoip2.Reader
	logger *zap.Logger
	routes []Route
}
   117  
   118  // Close closes the router.
   119  func (r *Router) Close() error {
   120  	if r.geoip != nil {
   121  		return r.geoip.Close()
   122  	}
   123  	return nil
   124  }
   125  
   126  // GetTCPClient returns the zerocopy.TCPClient for a TCP request received by server
   127  // from sourceAddrPort to targetAddr.
   128  func (r *Router) GetTCPClient(ctx context.Context, requestInfo RequestInfo) (zerocopy.TCPClient, error) {
   129  	route, err := r.match(ctx, protocolTCP, requestInfo)
   130  	if err != nil {
   131  		return nil, err
   132  	}
   133  
   134  	if ce := r.logger.Check(zap.DebugLevel, "Matched route for TCP connection"); ce != nil {
   135  		ce.Write(
   136  			zap.Int("serverIndex", requestInfo.ServerIndex),
   137  			zap.String("username", requestInfo.Username),
   138  			zap.Stringer("sourceAddrPort", requestInfo.SourceAddrPort),
   139  			zap.Stringer("targetAddress", requestInfo.TargetAddr),
   140  			zap.Stringer("route", route),
   141  		)
   142  	}
   143  
   144  	return route.TCPClient()
   145  }
   146  
   147  // GetUDPClient returns the zerocopy.UDPClient for a UDP session received by server.
   148  // The first received packet of the session is from sourceAddrPort to targetAddr.
   149  func (r *Router) GetUDPClient(ctx context.Context, requestInfo RequestInfo) (zerocopy.UDPClient, error) {
   150  	route, err := r.match(ctx, protocolUDP, requestInfo)
   151  	if err != nil {
   152  		return nil, err
   153  	}
   154  
   155  	if ce := r.logger.Check(zap.DebugLevel, "Matched route for UDP session"); ce != nil {
   156  		ce.Write(
   157  			zap.Int("serverIndex", requestInfo.ServerIndex),
   158  			zap.String("username", requestInfo.Username),
   159  			zap.Stringer("sourceAddrPort", requestInfo.SourceAddrPort),
   160  			zap.Stringer("targetAddress", requestInfo.TargetAddr),
   161  			zap.Stringer("route", route),
   162  		)
   163  	}
   164  
   165  	return route.UDPClient()
   166  }
   167  
   168  // match returns the matched route for the new TCP request or UDP session.
   169  func (r *Router) match(ctx context.Context, network protocol, requestInfo RequestInfo) (*Route, error) {
   170  	for i := range r.routes {
   171  		matched, err := r.routes[i].Match(ctx, network, requestInfo)
   172  		if err != nil {
   173  			return nil, err
   174  		}
   175  		if matched {
   176  			return &r.routes[i], nil
   177  		}
   178  	}
   179  	panic("did not match default route")
   180  }