github.com/database64128/shadowsocks-go@v1.7.0/router/router.go (about)

     1  package router
     2  
     3  import (
     4  	"fmt"
     5  
     6  	"github.com/database64128/shadowsocks-go/dns"
     7  	"github.com/database64128/shadowsocks-go/domainset"
     8  	"github.com/database64128/shadowsocks-go/prefixset"
     9  	"github.com/database64128/shadowsocks-go/zerocopy"
    10  	"github.com/oschwald/geoip2-golang"
    11  	"go.uber.org/zap"
    12  	"go4.org/netipx"
    13  )
    14  
// Config is the configuration for a Router.
type Config struct {
	// DefaultTCPClientName selects the TCP client of the catch-all "default"
	// route. "reject" leaves it unset; "" picks the only configured TCP client
	// when exactly one exists (otherwise leaves it unset); any other value must
	// name an entry in the TCP client map passed to Config.Router.
	DefaultTCPClientName  string             `json:"defaultTCPClientName"`
	// DefaultUDPClientName is the UDP counterpart of DefaultTCPClientName and
	// follows the same rules against the UDP client map.
	DefaultUDPClientName  string             `json:"defaultUDPClientName"`
	// GeoLite2CountryDbPath is the optional path to a GeoLite2 country database.
	// When empty, no database is opened.
	GeoLite2CountryDbPath string             `json:"geoLite2CountryDbPath"`
	// DomainSets are built by Config.Router, indexed by their Name, and made
	// available to route construction.
	DomainSets            []domainset.Config `json:"domainSets"`
	// PrefixSets are built into IP sets by Config.Router, indexed by their Name,
	// and made available to route construction.
	PrefixSets            []prefixset.Config `json:"prefixSets"`
	// Routes are evaluated in the order given, before the catch-all default route.
	Routes                []RouteConfig      `json:"routes"`
}
    24  
    25  // Router creates a router from the RouterConfig.
    26  func (rc *Config) Router(logger *zap.Logger, resolvers []*dns.Resolver, resolverMap map[string]*dns.Resolver, tcpClientMap map[string]zerocopy.TCPClient, udpClientMap map[string]zerocopy.UDPClient, serverIndexByName map[string]int) (*Router, error) {
    27  	defaultRoute := Route{name: "default"}
    28  
    29  	switch rc.DefaultTCPClientName {
    30  	case "reject":
    31  	case "":
    32  		if len(tcpClientMap) == 1 {
    33  			for _, tcpClient := range tcpClientMap {
    34  				defaultRoute.tcpClient = tcpClient
    35  			}
    36  		}
    37  	default:
    38  		defaultRoute.tcpClient = tcpClientMap[rc.DefaultTCPClientName]
    39  		if defaultRoute.tcpClient == nil {
    40  			return nil, fmt.Errorf("default TCP client not found: %s", rc.DefaultTCPClientName)
    41  		}
    42  	}
    43  
    44  	switch rc.DefaultUDPClientName {
    45  	case "reject":
    46  	case "":
    47  		if len(udpClientMap) == 1 {
    48  			for _, udpClient := range udpClientMap {
    49  				defaultRoute.udpClient = udpClient
    50  			}
    51  		}
    52  	default:
    53  		defaultRoute.udpClient = udpClientMap[rc.DefaultUDPClientName]
    54  		if defaultRoute.udpClient == nil {
    55  			return nil, fmt.Errorf("default UDP client not found: %s", rc.DefaultUDPClientName)
    56  		}
    57  	}
    58  
    59  	var (
    60  		geoip *geoip2.Reader
    61  		err   error
    62  	)
    63  
    64  	if rc.GeoLite2CountryDbPath != "" {
    65  		geoip, err = geoip2.Open(rc.GeoLite2CountryDbPath)
    66  		if err != nil {
    67  			return nil, err
    68  		}
    69  	}
    70  
    71  	domainSetMap := make(map[string]domainset.DomainSet, len(rc.DomainSets))
    72  
    73  	for i := range rc.DomainSets {
    74  		domainSet, err := rc.DomainSets[i].DomainSet()
    75  		if err != nil {
    76  			return nil, err
    77  		}
    78  		domainSetMap[rc.DomainSets[i].Name] = domainSet
    79  	}
    80  
    81  	prefixSetMap := make(map[string]*netipx.IPSet, len(rc.PrefixSets))
    82  
    83  	for i := range rc.PrefixSets {
    84  		s, err := rc.PrefixSets[i].IPSet()
    85  		if err != nil {
    86  			return nil, err
    87  		}
    88  		prefixSetMap[rc.PrefixSets[i].Name] = s
    89  	}
    90  
    91  	routes := make([]Route, len(rc.Routes)+1)
    92  
    93  	for i := range rc.Routes {
    94  		route, err := rc.Routes[i].Route(geoip, logger, resolvers, resolverMap, tcpClientMap, udpClientMap, serverIndexByName, domainSetMap, prefixSetMap)
    95  		if err != nil {
    96  			return nil, err
    97  		}
    98  		routes[i] = route
    99  	}
   100  
   101  	routes[len(rc.Routes)] = defaultRoute
   102  
   103  	return &Router{
   104  		geoip:  geoip,
   105  		logger: logger,
   106  		routes: routes,
   107  	}, nil
   108  }
   109  
// Router looks up the destination client for requests received by servers.
type Router struct {
	// geoip is the optional GeoLite2 country database. It is nil when no
	// database path was configured, and Close closes it when present.
	geoip  *geoip2.Reader
	// logger receives debug-level records describing each matched route.
	logger *zap.Logger
	// routes is checked in order by match; the first matching route wins.
	routes []Route
}
   116  
   117  // Close closes the router.
   118  func (r *Router) Close() error {
   119  	if r.geoip != nil {
   120  		return r.geoip.Close()
   121  	}
   122  	return nil
   123  }
   124  
   125  // GetTCPClient returns the zerocopy.TCPClient for a TCP request received by server
   126  // from sourceAddrPort to targetAddr.
   127  func (r *Router) GetTCPClient(requestInfo RequestInfo) (zerocopy.TCPClient, error) {
   128  	route, err := r.match(protocolTCP, requestInfo)
   129  	if err != nil {
   130  		return nil, err
   131  	}
   132  
   133  	if ce := r.logger.Check(zap.DebugLevel, "Matched route for TCP connection"); ce != nil {
   134  		ce.Write(
   135  			zap.Int("serverIndex", requestInfo.ServerIndex),
   136  			zap.String("username", requestInfo.Username),
   137  			zap.Stringer("sourceAddrPort", requestInfo.SourceAddrPort),
   138  			zap.Stringer("targetAddress", requestInfo.TargetAddr),
   139  			zap.Stringer("route", route),
   140  		)
   141  	}
   142  
   143  	return route.TCPClient()
   144  }
   145  
   146  // GetUDPClient returns the zerocopy.UDPClient for a UDP session received by server.
   147  // The first received packet of the session is from sourceAddrPort to targetAddr.
   148  func (r *Router) GetUDPClient(requestInfo RequestInfo) (zerocopy.UDPClient, error) {
   149  	route, err := r.match(protocolUDP, requestInfo)
   150  	if err != nil {
   151  		return nil, err
   152  	}
   153  
   154  	if ce := r.logger.Check(zap.DebugLevel, "Matched route for UDP session"); ce != nil {
   155  		ce.Write(
   156  			zap.Int("serverIndex", requestInfo.ServerIndex),
   157  			zap.String("username", requestInfo.Username),
   158  			zap.Stringer("sourceAddrPort", requestInfo.SourceAddrPort),
   159  			zap.Stringer("targetAddress", requestInfo.TargetAddr),
   160  			zap.Stringer("route", route),
   161  		)
   162  	}
   163  
   164  	return route.UDPClient()
   165  }
   166  
   167  // match returns the matched route for the new TCP request or UDP session.
   168  func (r *Router) match(network protocol, requestInfo RequestInfo) (*Route, error) {
   169  	for i := range r.routes {
   170  		matched, err := r.routes[i].Match(network, requestInfo)
   171  		if err != nil {
   172  			return nil, err
   173  		}
   174  		if matched {
   175  			return &r.routes[i], nil
   176  		}
   177  	}
   178  	panic("did not match default route")
   179  }