github.com/xmplusdev/xray-core@v1.8.10/app/router/command/command_test.go (about)

     1  package command_test
     2  
     3  import (
     4  	"context"
     5  	"testing"
     6  	"time"
     7  
     8  	"github.com/golang/mock/gomock"
     9  	"github.com/google/go-cmp/cmp"
    10  	"github.com/google/go-cmp/cmp/cmpopts"
    11  	"github.com/xmplusdev/xray-core/app/router"
    12  	. "github.com/xmplusdev/xray-core/app/router/command"
    13  	"github.com/xmplusdev/xray-core/app/stats"
    14  	"github.com/xmplusdev/xray-core/common"
    15  	"github.com/xmplusdev/xray-core/common/net"
    16  	"github.com/xmplusdev/xray-core/features/routing"
    17  	"github.com/xmplusdev/xray-core/testing/mocks"
    18  	"google.golang.org/grpc"
    19  	"google.golang.org/grpc/credentials/insecure"
    20  	"google.golang.org/grpc/test/bufconn"
    21  )
    22  
    23  func TestServiceSubscribeRoutingStats(t *testing.T) {
    24  	c := stats.NewChannel(&stats.ChannelConfig{
    25  		SubscriberLimit: 1,
    26  		BufferSize:      0,
    27  		Blocking:        true,
    28  	})
    29  	common.Must(c.Start())
    30  	defer c.Close()
    31  
    32  	lis := bufconn.Listen(1024 * 1024)
    33  	bufDialer := func(context.Context, string) (net.Conn, error) {
    34  		return lis.Dial()
    35  	}
    36  
    37  	testCases := []*RoutingContext{
    38  		{InboundTag: "in", OutboundTag: "out"},
    39  		{TargetIPs: [][]byte{{1, 2, 3, 4}}, TargetPort: 8080, OutboundTag: "out"},
    40  		{TargetDomain: "example.com", TargetPort: 443, OutboundTag: "out"},
    41  		{SourcePort: 9999, TargetPort: 9999, OutboundTag: "out"},
    42  		{Network: net.Network_UDP, OutboundGroupTags: []string{"outergroup", "innergroup"}, OutboundTag: "out"},
    43  		{Protocol: "bittorrent", OutboundTag: "blocked"},
    44  		{User: "example@example.com", OutboundTag: "out"},
    45  		{SourceIPs: [][]byte{{127, 0, 0, 1}}, Attributes: map[string]string{"attr": "value"}, OutboundTag: "out"},
    46  	}
    47  	errCh := make(chan error)
    48  
    49  	// Server goroutine
    50  	go func() {
    51  		server := grpc.NewServer()
    52  		RegisterRoutingServiceServer(server, NewRoutingServer(nil, c))
    53  		errCh <- server.Serve(lis)
    54  	}()
    55  
    56  	// Publisher goroutine
    57  	go func() {
    58  		publishTestCases := func() error {
    59  			ctx, cancel := context.WithTimeout(context.Background(), time.Second)
    60  			defer cancel()
    61  			for { // Wait until there's one subscriber in routing stats channel
    62  				if len(c.Subscribers()) > 0 {
    63  					break
    64  				}
    65  				if ctx.Err() != nil {
    66  					return ctx.Err()
    67  				}
    68  			}
    69  			for _, tc := range testCases {
    70  				c.Publish(context.Background(), AsRoutingRoute(tc))
    71  				time.Sleep(time.Millisecond)
    72  			}
    73  			return nil
    74  		}
    75  
    76  		if err := publishTestCases(); err != nil {
    77  			errCh <- err
    78  		}
    79  	}()
    80  
    81  	// Client goroutine
    82  	go func() {
    83  		defer lis.Close()
    84  		conn, err := grpc.DialContext(context.Background(), "bufnet", grpc.WithContextDialer(bufDialer), grpc.WithTransportCredentials(insecure.NewCredentials()))
    85  		if err != nil {
    86  			errCh <- err
    87  			return
    88  		}
    89  		defer conn.Close()
    90  		client := NewRoutingServiceClient(conn)
    91  
    92  		// Test retrieving all fields
    93  		testRetrievingAllFields := func() error {
    94  			streamCtx, streamClose := context.WithCancel(context.Background())
    95  
    96  			// Test the unsubscription of stream works well
    97  			defer func() {
    98  				streamClose()
    99  				timeOutCtx, timeout := context.WithTimeout(context.Background(), time.Second)
   100  				defer timeout()
   101  				for { // Wait until there's no subscriber in routing stats channel
   102  					if len(c.Subscribers()) == 0 {
   103  						break
   104  					}
   105  					if timeOutCtx.Err() != nil {
   106  						t.Error("unexpected subscribers not decreased in channel", timeOutCtx.Err())
   107  					}
   108  				}
   109  			}()
   110  
   111  			stream, err := client.SubscribeRoutingStats(streamCtx, &SubscribeRoutingStatsRequest{})
   112  			if err != nil {
   113  				return err
   114  			}
   115  
   116  			for _, tc := range testCases {
   117  				msg, err := stream.Recv()
   118  				if err != nil {
   119  					return err
   120  				}
   121  				if r := cmp.Diff(msg, tc, cmpopts.IgnoreUnexported(RoutingContext{})); r != "" {
   122  					t.Error(r)
   123  				}
   124  			}
   125  
   126  			// Test that double subscription will fail
   127  			errStream, err := client.SubscribeRoutingStats(context.Background(), &SubscribeRoutingStatsRequest{
   128  				FieldSelectors: []string{"ip", "port", "domain", "outbound"},
   129  			})
   130  			if err != nil {
   131  				return err
   132  			}
   133  			if _, err := errStream.Recv(); err == nil {
   134  				t.Error("unexpected successful subscription")
   135  			}
   136  
   137  			return nil
   138  		}
   139  
   140  		if err := testRetrievingAllFields(); err != nil {
   141  			errCh <- err
   142  		}
   143  		errCh <- nil // Client passed all tests successfully
   144  	}()
   145  
   146  	// Wait for goroutines to complete
   147  	select {
   148  	case <-time.After(2 * time.Second):
   149  		t.Fatal("Test timeout after 2s")
   150  	case err := <-errCh:
   151  		if err != nil {
   152  			t.Fatal(err)
   153  		}
   154  	}
   155  }
   156  
   157  func TestServiceSubscribeSubsetOfFields(t *testing.T) {
   158  	c := stats.NewChannel(&stats.ChannelConfig{
   159  		SubscriberLimit: 1,
   160  		BufferSize:      0,
   161  		Blocking:        true,
   162  	})
   163  	common.Must(c.Start())
   164  	defer c.Close()
   165  
   166  	lis := bufconn.Listen(1024 * 1024)
   167  	bufDialer := func(context.Context, string) (net.Conn, error) {
   168  		return lis.Dial()
   169  	}
   170  
   171  	testCases := []*RoutingContext{
   172  		{InboundTag: "in", OutboundTag: "out"},
   173  		{TargetIPs: [][]byte{{1, 2, 3, 4}}, TargetPort: 8080, OutboundTag: "out"},
   174  		{TargetDomain: "example.com", TargetPort: 443, OutboundTag: "out"},
   175  		{SourcePort: 9999, TargetPort: 9999, OutboundTag: "out"},
   176  		{Network: net.Network_UDP, OutboundGroupTags: []string{"outergroup", "innergroup"}, OutboundTag: "out"},
   177  		{Protocol: "bittorrent", OutboundTag: "blocked"},
   178  		{User: "example@example.com", OutboundTag: "out"},
   179  		{SourceIPs: [][]byte{{127, 0, 0, 1}}, Attributes: map[string]string{"attr": "value"}, OutboundTag: "out"},
   180  	}
   181  	errCh := make(chan error)
   182  
   183  	// Server goroutine
   184  	go func() {
   185  		server := grpc.NewServer()
   186  		RegisterRoutingServiceServer(server, NewRoutingServer(nil, c))
   187  		errCh <- server.Serve(lis)
   188  	}()
   189  
   190  	// Publisher goroutine
   191  	go func() {
   192  		publishTestCases := func() error {
   193  			ctx, cancel := context.WithTimeout(context.Background(), time.Second)
   194  			defer cancel()
   195  			for { // Wait until there's one subscriber in routing stats channel
   196  				if len(c.Subscribers()) > 0 {
   197  					break
   198  				}
   199  				if ctx.Err() != nil {
   200  					return ctx.Err()
   201  				}
   202  			}
   203  			for _, tc := range testCases {
   204  				c.Publish(context.Background(), AsRoutingRoute(tc))
   205  				time.Sleep(time.Millisecond)
   206  			}
   207  			return nil
   208  		}
   209  
   210  		if err := publishTestCases(); err != nil {
   211  			errCh <- err
   212  		}
   213  	}()
   214  
   215  	// Client goroutine
   216  	go func() {
   217  		defer lis.Close()
   218  		conn, err := grpc.DialContext(context.Background(), "bufnet", grpc.WithContextDialer(bufDialer), grpc.WithTransportCredentials(insecure.NewCredentials()))
   219  		if err != nil {
   220  			errCh <- err
   221  			return
   222  		}
   223  		defer conn.Close()
   224  		client := NewRoutingServiceClient(conn)
   225  
   226  		// Test retrieving only a subset of fields
   227  		testRetrievingSubsetOfFields := func() error {
   228  			streamCtx, streamClose := context.WithCancel(context.Background())
   229  			defer streamClose()
   230  			stream, err := client.SubscribeRoutingStats(streamCtx, &SubscribeRoutingStatsRequest{
   231  				FieldSelectors: []string{"ip", "port", "domain", "outbound"},
   232  			})
   233  			if err != nil {
   234  				return err
   235  			}
   236  
   237  			for _, tc := range testCases {
   238  				msg, err := stream.Recv()
   239  				if err != nil {
   240  					return err
   241  				}
   242  				stat := &RoutingContext{ // Only a subset of stats is retrieved
   243  					SourceIPs:         tc.SourceIPs,
   244  					TargetIPs:         tc.TargetIPs,
   245  					SourcePort:        tc.SourcePort,
   246  					TargetPort:        tc.TargetPort,
   247  					TargetDomain:      tc.TargetDomain,
   248  					OutboundGroupTags: tc.OutboundGroupTags,
   249  					OutboundTag:       tc.OutboundTag,
   250  				}
   251  				if r := cmp.Diff(msg, stat, cmpopts.IgnoreUnexported(RoutingContext{})); r != "" {
   252  					t.Error(r)
   253  				}
   254  			}
   255  
   256  			return nil
   257  		}
   258  		if err := testRetrievingSubsetOfFields(); err != nil {
   259  			errCh <- err
   260  		}
   261  		errCh <- nil // Client passed all tests successfully
   262  	}()
   263  
   264  	// Wait for goroutines to complete
   265  	select {
   266  	case <-time.After(2 * time.Second):
   267  		t.Fatal("Test timeout after 2s")
   268  	case err := <-errCh:
   269  		if err != nil {
   270  			t.Fatal(err)
   271  		}
   272  	}
   273  }
   274  
   275  func TestSerivceTestRoute(t *testing.T) {
   276  	c := stats.NewChannel(&stats.ChannelConfig{
   277  		SubscriberLimit: 1,
   278  		BufferSize:      16,
   279  		Blocking:        true,
   280  	})
   281  	common.Must(c.Start())
   282  	defer c.Close()
   283  
   284  	r := new(router.Router)
   285  	mockCtl := gomock.NewController(t)
   286  	defer mockCtl.Finish()
   287  	common.Must(r.Init(context.TODO(), &router.Config{
   288  		Rule: []*router.RoutingRule{
   289  			{
   290  				InboundTag: []string{"in"},
   291  				TargetTag:  &router.RoutingRule_Tag{Tag: "out"},
   292  			},
   293  			{
   294  				Protocol:  []string{"bittorrent"},
   295  				TargetTag: &router.RoutingRule_Tag{Tag: "blocked"},
   296  			},
   297  			{
   298  				PortList:  &net.PortList{Range: []*net.PortRange{{From: 8080, To: 8080}}},
   299  				TargetTag: &router.RoutingRule_Tag{Tag: "out"},
   300  			},
   301  			{
   302  				SourcePortList: &net.PortList{Range: []*net.PortRange{{From: 9999, To: 9999}}},
   303  				TargetTag:      &router.RoutingRule_Tag{Tag: "out"},
   304  			},
   305  			{
   306  				Domain:    []*router.Domain{{Type: router.Domain_Domain, Value: "com"}},
   307  				TargetTag: &router.RoutingRule_Tag{Tag: "out"},
   308  			},
   309  			{
   310  				SourceGeoip: []*router.GeoIP{{CountryCode: "private", Cidr: []*router.CIDR{{Ip: []byte{127, 0, 0, 0}, Prefix: 8}}}},
   311  				TargetTag:   &router.RoutingRule_Tag{Tag: "out"},
   312  			},
   313  			{
   314  				UserEmail: []string{"example@example.com"},
   315  				TargetTag: &router.RoutingRule_Tag{Tag: "out"},
   316  			},
   317  			{
   318  				Networks:  []net.Network{net.Network_UDP, net.Network_TCP},
   319  				TargetTag: &router.RoutingRule_Tag{Tag: "out"},
   320  			},
   321  		},
   322  	}, mocks.NewDNSClient(mockCtl), mocks.NewOutboundManager(mockCtl), nil))
   323  
   324  	lis := bufconn.Listen(1024 * 1024)
   325  	bufDialer := func(context.Context, string) (net.Conn, error) {
   326  		return lis.Dial()
   327  	}
   328  
   329  	errCh := make(chan error)
   330  
   331  	// Server goroutine
   332  	go func() {
   333  		server := grpc.NewServer()
   334  		RegisterRoutingServiceServer(server, NewRoutingServer(r, c))
   335  		errCh <- server.Serve(lis)
   336  	}()
   337  
   338  	// Client goroutine
   339  	go func() {
   340  		defer lis.Close()
   341  		conn, err := grpc.DialContext(context.Background(), "bufnet", grpc.WithContextDialer(bufDialer), grpc.WithTransportCredentials(insecure.NewCredentials()))
   342  		if err != nil {
   343  			errCh <- err
   344  		}
   345  		defer conn.Close()
   346  		client := NewRoutingServiceClient(conn)
   347  
   348  		testCases := []*RoutingContext{
   349  			{InboundTag: "in", OutboundTag: "out"},
   350  			{TargetIPs: [][]byte{{1, 2, 3, 4}}, TargetPort: 8080, OutboundTag: "out"},
   351  			{TargetDomain: "example.com", TargetPort: 443, OutboundTag: "out"},
   352  			{SourcePort: 9999, TargetPort: 9999, OutboundTag: "out"},
   353  			{Network: net.Network_UDP, Protocol: "bittorrent", OutboundTag: "blocked"},
   354  			{User: "example@example.com", OutboundTag: "out"},
   355  			{SourceIPs: [][]byte{{127, 0, 0, 1}}, Attributes: map[string]string{"attr": "value"}, OutboundTag: "out"},
   356  		}
   357  
   358  		// Test simple TestRoute
   359  		testSimple := func() error {
   360  			for _, tc := range testCases {
   361  				route, err := client.TestRoute(context.Background(), &TestRouteRequest{RoutingContext: tc})
   362  				if err != nil {
   363  					return err
   364  				}
   365  				if r := cmp.Diff(route, tc, cmpopts.IgnoreUnexported(RoutingContext{})); r != "" {
   366  					t.Error(r)
   367  				}
   368  			}
   369  			return nil
   370  		}
   371  
   372  		// Test TestRoute with special options
   373  		testOptions := func() error {
   374  			sub, err := c.Subscribe()
   375  			if err != nil {
   376  				return err
   377  			}
   378  			for _, tc := range testCases {
   379  				route, err := client.TestRoute(context.Background(), &TestRouteRequest{
   380  					RoutingContext: tc,
   381  					FieldSelectors: []string{"ip", "port", "domain", "outbound"},
   382  					PublishResult:  true,
   383  				})
   384  				if err != nil {
   385  					return err
   386  				}
   387  				stat := &RoutingContext{ // Only a subset of stats is retrieved
   388  					SourceIPs:         tc.SourceIPs,
   389  					TargetIPs:         tc.TargetIPs,
   390  					SourcePort:        tc.SourcePort,
   391  					TargetPort:        tc.TargetPort,
   392  					TargetDomain:      tc.TargetDomain,
   393  					OutboundGroupTags: tc.OutboundGroupTags,
   394  					OutboundTag:       tc.OutboundTag,
   395  				}
   396  				if r := cmp.Diff(route, stat, cmpopts.IgnoreUnexported(RoutingContext{})); r != "" {
   397  					t.Error(r)
   398  				}
   399  				select { // Check that routing result has been published to statistics channel
   400  				case msg, received := <-sub:
   401  					if route, ok := msg.(routing.Route); received && ok {
   402  						if r := cmp.Diff(AsProtobufMessage(nil)(route), tc, cmpopts.IgnoreUnexported(RoutingContext{})); r != "" {
   403  							t.Error(r)
   404  						}
   405  					} else {
   406  						t.Error("unexpected failure in receiving published routing result for testcase", tc)
   407  					}
   408  				case <-time.After(100 * time.Millisecond):
   409  					t.Error("unexpected failure in receiving published routing result", tc)
   410  				}
   411  			}
   412  			return nil
   413  		}
   414  
   415  		if err := testSimple(); err != nil {
   416  			errCh <- err
   417  		}
   418  		if err := testOptions(); err != nil {
   419  			errCh <- err
   420  		}
   421  		errCh <- nil // Client passed all tests successfully
   422  	}()
   423  
   424  	// Wait for goroutines to complete
   425  	select {
   426  	case <-time.After(2 * time.Second):
   427  		t.Fatal("Test timeout after 2s")
   428  	case err := <-errCh:
   429  		if err != nil {
   430  			t.Fatal(err)
   431  		}
   432  	}
   433  }