github.com/xraypb/xray-core@v1.6.6/app/router/command/command_test.go (about)

     1  package command_test
     2  
     3  import (
     4  	"context"
     5  	"testing"
     6  	"time"
     7  
     8  	"github.com/golang/mock/gomock"
     9  	"github.com/google/go-cmp/cmp"
    10  	"github.com/google/go-cmp/cmp/cmpopts"
    11  	"github.com/xraypb/xray-core/app/router"
    12  	. "github.com/xraypb/xray-core/app/router/command"
    13  	"github.com/xraypb/xray-core/app/stats"
    14  	"github.com/xraypb/xray-core/common"
    15  	"github.com/xraypb/xray-core/common/net"
    16  	"github.com/xraypb/xray-core/features/routing"
    17  	"github.com/xraypb/xray-core/testing/mocks"
    18  	"google.golang.org/grpc"
    19  	"google.golang.org/grpc/test/bufconn"
    20  )
    21  
    22  func TestServiceSubscribeRoutingStats(t *testing.T) {
    23  	c := stats.NewChannel(&stats.ChannelConfig{
    24  		SubscriberLimit: 1,
    25  		BufferSize:      0,
    26  		Blocking:        true,
    27  	})
    28  	common.Must(c.Start())
    29  	defer c.Close()
    30  
    31  	lis := bufconn.Listen(1024 * 1024)
    32  	bufDialer := func(context.Context, string) (net.Conn, error) {
    33  		return lis.Dial()
    34  	}
    35  
    36  	testCases := []*RoutingContext{
    37  		{InboundTag: "in", OutboundTag: "out"},
    38  		{TargetIPs: [][]byte{{1, 2, 3, 4}}, TargetPort: 8080, OutboundTag: "out"},
    39  		{TargetDomain: "example.com", TargetPort: 443, OutboundTag: "out"},
    40  		{SourcePort: 9999, TargetPort: 9999, OutboundTag: "out"},
    41  		{Network: net.Network_UDP, OutboundGroupTags: []string{"outergroup", "innergroup"}, OutboundTag: "out"},
    42  		{Protocol: "bittorrent", OutboundTag: "blocked"},
    43  		{User: "example@example.com", OutboundTag: "out"},
    44  		{SourceIPs: [][]byte{{127, 0, 0, 1}}, Attributes: map[string]string{"attr": "value"}, OutboundTag: "out"},
    45  	}
    46  	errCh := make(chan error)
    47  
    48  	// Server goroutine
    49  	go func() {
    50  		server := grpc.NewServer()
    51  		RegisterRoutingServiceServer(server, NewRoutingServer(nil, c))
    52  		errCh <- server.Serve(lis)
    53  	}()
    54  
    55  	// Publisher goroutine
    56  	go func() {
    57  		publishTestCases := func() error {
    58  			ctx, cancel := context.WithTimeout(context.Background(), time.Second)
    59  			defer cancel()
    60  			for { // Wait until there's one subscriber in routing stats channel
    61  				if len(c.Subscribers()) > 0 {
    62  					break
    63  				}
    64  				if ctx.Err() != nil {
    65  					return ctx.Err()
    66  				}
    67  			}
    68  			for _, tc := range testCases {
    69  				c.Publish(context.Background(), AsRoutingRoute(tc))
    70  				time.Sleep(time.Millisecond)
    71  			}
    72  			return nil
    73  		}
    74  
    75  		if err := publishTestCases(); err != nil {
    76  			errCh <- err
    77  		}
    78  	}()
    79  
    80  	// Client goroutine
    81  	go func() {
    82  		defer lis.Close()
    83  		conn, err := grpc.DialContext(context.Background(), "bufnet", grpc.WithContextDialer(bufDialer), grpc.WithInsecure())
    84  		if err != nil {
    85  			errCh <- err
    86  			return
    87  		}
    88  		defer conn.Close()
    89  		client := NewRoutingServiceClient(conn)
    90  
    91  		// Test retrieving all fields
    92  		testRetrievingAllFields := func() error {
    93  			streamCtx, streamClose := context.WithCancel(context.Background())
    94  
    95  			// Test the unsubscription of stream works well
    96  			defer func() {
    97  				streamClose()
    98  				timeOutCtx, timeout := context.WithTimeout(context.Background(), time.Second)
    99  				defer timeout()
   100  				for { // Wait until there's no subscriber in routing stats channel
   101  					if len(c.Subscribers()) == 0 {
   102  						break
   103  					}
   104  					if timeOutCtx.Err() != nil {
   105  						t.Error("unexpected subscribers not decreased in channel", timeOutCtx.Err())
   106  					}
   107  				}
   108  			}()
   109  
   110  			stream, err := client.SubscribeRoutingStats(streamCtx, &SubscribeRoutingStatsRequest{})
   111  			if err != nil {
   112  				return err
   113  			}
   114  
   115  			for _, tc := range testCases {
   116  				msg, err := stream.Recv()
   117  				if err != nil {
   118  					return err
   119  				}
   120  				if r := cmp.Diff(msg, tc, cmpopts.IgnoreUnexported(RoutingContext{})); r != "" {
   121  					t.Error(r)
   122  				}
   123  			}
   124  
   125  			// Test that double subscription will fail
   126  			errStream, err := client.SubscribeRoutingStats(context.Background(), &SubscribeRoutingStatsRequest{
   127  				FieldSelectors: []string{"ip", "port", "domain", "outbound"},
   128  			})
   129  			if err != nil {
   130  				return err
   131  			}
   132  			if _, err := errStream.Recv(); err == nil {
   133  				t.Error("unexpected successful subscription")
   134  			}
   135  
   136  			return nil
   137  		}
   138  
   139  		if err := testRetrievingAllFields(); err != nil {
   140  			errCh <- err
   141  		}
   142  		errCh <- nil // Client passed all tests successfully
   143  	}()
   144  
   145  	// Wait for goroutines to complete
   146  	select {
   147  	case <-time.After(2 * time.Second):
   148  		t.Fatal("Test timeout after 2s")
   149  	case err := <-errCh:
   150  		if err != nil {
   151  			t.Fatal(err)
   152  		}
   153  	}
   154  }
   155  
   156  func TestServiceSubscribeSubsetOfFields(t *testing.T) {
   157  	c := stats.NewChannel(&stats.ChannelConfig{
   158  		SubscriberLimit: 1,
   159  		BufferSize:      0,
   160  		Blocking:        true,
   161  	})
   162  	common.Must(c.Start())
   163  	defer c.Close()
   164  
   165  	lis := bufconn.Listen(1024 * 1024)
   166  	bufDialer := func(context.Context, string) (net.Conn, error) {
   167  		return lis.Dial()
   168  	}
   169  
   170  	testCases := []*RoutingContext{
   171  		{InboundTag: "in", OutboundTag: "out"},
   172  		{TargetIPs: [][]byte{{1, 2, 3, 4}}, TargetPort: 8080, OutboundTag: "out"},
   173  		{TargetDomain: "example.com", TargetPort: 443, OutboundTag: "out"},
   174  		{SourcePort: 9999, TargetPort: 9999, OutboundTag: "out"},
   175  		{Network: net.Network_UDP, OutboundGroupTags: []string{"outergroup", "innergroup"}, OutboundTag: "out"},
   176  		{Protocol: "bittorrent", OutboundTag: "blocked"},
   177  		{User: "example@example.com", OutboundTag: "out"},
   178  		{SourceIPs: [][]byte{{127, 0, 0, 1}}, Attributes: map[string]string{"attr": "value"}, OutboundTag: "out"},
   179  	}
   180  	errCh := make(chan error)
   181  
   182  	// Server goroutine
   183  	go func() {
   184  		server := grpc.NewServer()
   185  		RegisterRoutingServiceServer(server, NewRoutingServer(nil, c))
   186  		errCh <- server.Serve(lis)
   187  	}()
   188  
   189  	// Publisher goroutine
   190  	go func() {
   191  		publishTestCases := func() error {
   192  			ctx, cancel := context.WithTimeout(context.Background(), time.Second)
   193  			defer cancel()
   194  			for { // Wait until there's one subscriber in routing stats channel
   195  				if len(c.Subscribers()) > 0 {
   196  					break
   197  				}
   198  				if ctx.Err() != nil {
   199  					return ctx.Err()
   200  				}
   201  			}
   202  			for _, tc := range testCases {
   203  				c.Publish(context.Background(), AsRoutingRoute(tc))
   204  				time.Sleep(time.Millisecond)
   205  			}
   206  			return nil
   207  		}
   208  
   209  		if err := publishTestCases(); err != nil {
   210  			errCh <- err
   211  		}
   212  	}()
   213  
   214  	// Client goroutine
   215  	go func() {
   216  		defer lis.Close()
   217  		conn, err := grpc.DialContext(context.Background(), "bufnet", grpc.WithContextDialer(bufDialer), grpc.WithInsecure())
   218  		if err != nil {
   219  			errCh <- err
   220  			return
   221  		}
   222  		defer conn.Close()
   223  		client := NewRoutingServiceClient(conn)
   224  
   225  		// Test retrieving only a subset of fields
   226  		testRetrievingSubsetOfFields := func() error {
   227  			streamCtx, streamClose := context.WithCancel(context.Background())
   228  			defer streamClose()
   229  			stream, err := client.SubscribeRoutingStats(streamCtx, &SubscribeRoutingStatsRequest{
   230  				FieldSelectors: []string{"ip", "port", "domain", "outbound"},
   231  			})
   232  			if err != nil {
   233  				return err
   234  			}
   235  
   236  			for _, tc := range testCases {
   237  				msg, err := stream.Recv()
   238  				if err != nil {
   239  					return err
   240  				}
   241  				stat := &RoutingContext{ // Only a subset of stats is retrieved
   242  					SourceIPs:         tc.SourceIPs,
   243  					TargetIPs:         tc.TargetIPs,
   244  					SourcePort:        tc.SourcePort,
   245  					TargetPort:        tc.TargetPort,
   246  					TargetDomain:      tc.TargetDomain,
   247  					OutboundGroupTags: tc.OutboundGroupTags,
   248  					OutboundTag:       tc.OutboundTag,
   249  				}
   250  				if r := cmp.Diff(msg, stat, cmpopts.IgnoreUnexported(RoutingContext{})); r != "" {
   251  					t.Error(r)
   252  				}
   253  			}
   254  
   255  			return nil
   256  		}
   257  		if err := testRetrievingSubsetOfFields(); err != nil {
   258  			errCh <- err
   259  		}
   260  		errCh <- nil // Client passed all tests successfully
   261  	}()
   262  
   263  	// Wait for goroutines to complete
   264  	select {
   265  	case <-time.After(2 * time.Second):
   266  		t.Fatal("Test timeout after 2s")
   267  	case err := <-errCh:
   268  		if err != nil {
   269  			t.Fatal(err)
   270  		}
   271  	}
   272  }
   273  
   274  func TestSerivceTestRoute(t *testing.T) {
   275  	c := stats.NewChannel(&stats.ChannelConfig{
   276  		SubscriberLimit: 1,
   277  		BufferSize:      16,
   278  		Blocking:        true,
   279  	})
   280  	common.Must(c.Start())
   281  	defer c.Close()
   282  
   283  	r := new(router.Router)
   284  	mockCtl := gomock.NewController(t)
   285  	defer mockCtl.Finish()
   286  	common.Must(r.Init(context.TODO(), &router.Config{
   287  		Rule: []*router.RoutingRule{
   288  			{
   289  				InboundTag: []string{"in"},
   290  				TargetTag:  &router.RoutingRule_Tag{Tag: "out"},
   291  			},
   292  			{
   293  				Protocol:  []string{"bittorrent"},
   294  				TargetTag: &router.RoutingRule_Tag{Tag: "blocked"},
   295  			},
   296  			{
   297  				PortList:  &net.PortList{Range: []*net.PortRange{{From: 8080, To: 8080}}},
   298  				TargetTag: &router.RoutingRule_Tag{Tag: "out"},
   299  			},
   300  			{
   301  				SourcePortList: &net.PortList{Range: []*net.PortRange{{From: 9999, To: 9999}}},
   302  				TargetTag:      &router.RoutingRule_Tag{Tag: "out"},
   303  			},
   304  			{
   305  				Domain:    []*router.Domain{{Type: router.Domain_Domain, Value: "com"}},
   306  				TargetTag: &router.RoutingRule_Tag{Tag: "out"},
   307  			},
   308  			{
   309  				SourceGeoip: []*router.GeoIP{{CountryCode: "private", Cidr: []*router.CIDR{{Ip: []byte{127, 0, 0, 0}, Prefix: 8}}}},
   310  				TargetTag:   &router.RoutingRule_Tag{Tag: "out"},
   311  			},
   312  			{
   313  				UserEmail: []string{"example@example.com"},
   314  				TargetTag: &router.RoutingRule_Tag{Tag: "out"},
   315  			},
   316  			{
   317  				Networks:  []net.Network{net.Network_UDP, net.Network_TCP},
   318  				TargetTag: &router.RoutingRule_Tag{Tag: "out"},
   319  			},
   320  		},
   321  	}, mocks.NewDNSClient(mockCtl), mocks.NewOutboundManager(mockCtl)))
   322  
   323  	lis := bufconn.Listen(1024 * 1024)
   324  	bufDialer := func(context.Context, string) (net.Conn, error) {
   325  		return lis.Dial()
   326  	}
   327  
   328  	errCh := make(chan error)
   329  
   330  	// Server goroutine
   331  	go func() {
   332  		server := grpc.NewServer()
   333  		RegisterRoutingServiceServer(server, NewRoutingServer(r, c))
   334  		errCh <- server.Serve(lis)
   335  	}()
   336  
   337  	// Client goroutine
   338  	go func() {
   339  		defer lis.Close()
   340  		conn, err := grpc.DialContext(context.Background(), "bufnet", grpc.WithContextDialer(bufDialer), grpc.WithInsecure())
   341  		if err != nil {
   342  			errCh <- err
   343  		}
   344  		defer conn.Close()
   345  		client := NewRoutingServiceClient(conn)
   346  
   347  		testCases := []*RoutingContext{
   348  			{InboundTag: "in", OutboundTag: "out"},
   349  			{TargetIPs: [][]byte{{1, 2, 3, 4}}, TargetPort: 8080, OutboundTag: "out"},
   350  			{TargetDomain: "example.com", TargetPort: 443, OutboundTag: "out"},
   351  			{SourcePort: 9999, TargetPort: 9999, OutboundTag: "out"},
   352  			{Network: net.Network_UDP, Protocol: "bittorrent", OutboundTag: "blocked"},
   353  			{User: "example@example.com", OutboundTag: "out"},
   354  			{SourceIPs: [][]byte{{127, 0, 0, 1}}, Attributes: map[string]string{"attr": "value"}, OutboundTag: "out"},
   355  		}
   356  
   357  		// Test simple TestRoute
   358  		testSimple := func() error {
   359  			for _, tc := range testCases {
   360  				route, err := client.TestRoute(context.Background(), &TestRouteRequest{RoutingContext: tc})
   361  				if err != nil {
   362  					return err
   363  				}
   364  				if r := cmp.Diff(route, tc, cmpopts.IgnoreUnexported(RoutingContext{})); r != "" {
   365  					t.Error(r)
   366  				}
   367  			}
   368  			return nil
   369  		}
   370  
   371  		// Test TestRoute with special options
   372  		testOptions := func() error {
   373  			sub, err := c.Subscribe()
   374  			if err != nil {
   375  				return err
   376  			}
   377  			for _, tc := range testCases {
   378  				route, err := client.TestRoute(context.Background(), &TestRouteRequest{
   379  					RoutingContext: tc,
   380  					FieldSelectors: []string{"ip", "port", "domain", "outbound"},
   381  					PublishResult:  true,
   382  				})
   383  				if err != nil {
   384  					return err
   385  				}
   386  				stat := &RoutingContext{ // Only a subset of stats is retrieved
   387  					SourceIPs:         tc.SourceIPs,
   388  					TargetIPs:         tc.TargetIPs,
   389  					SourcePort:        tc.SourcePort,
   390  					TargetPort:        tc.TargetPort,
   391  					TargetDomain:      tc.TargetDomain,
   392  					OutboundGroupTags: tc.OutboundGroupTags,
   393  					OutboundTag:       tc.OutboundTag,
   394  				}
   395  				if r := cmp.Diff(route, stat, cmpopts.IgnoreUnexported(RoutingContext{})); r != "" {
   396  					t.Error(r)
   397  				}
   398  				select { // Check that routing result has been published to statistics channel
   399  				case msg, received := <-sub:
   400  					if route, ok := msg.(routing.Route); received && ok {
   401  						if r := cmp.Diff(AsProtobufMessage(nil)(route), tc, cmpopts.IgnoreUnexported(RoutingContext{})); r != "" {
   402  							t.Error(r)
   403  						}
   404  					} else {
   405  						t.Error("unexpected failure in receiving published routing result for testcase", tc)
   406  					}
   407  				case <-time.After(100 * time.Millisecond):
   408  					t.Error("unexpected failure in receiving published routing result", tc)
   409  				}
   410  			}
   411  			return nil
   412  		}
   413  
   414  		if err := testSimple(); err != nil {
   415  			errCh <- err
   416  		}
   417  		if err := testOptions(); err != nil {
   418  			errCh <- err
   419  		}
   420  		errCh <- nil // Client passed all tests successfully
   421  	}()
   422  
   423  	// Wait for goroutines to complete
   424  	select {
   425  	case <-time.After(2 * time.Second):
   426  		t.Fatal("Test timeout after 2s")
   427  	case err := <-errCh:
   428  		if err != nil {
   429  			t.Fatal(err)
   430  		}
   431  	}
   432  }