github.com/imannamdari/v2ray-core/v5@v5.0.5/app/router/command/command_test.go (about)

     1  package command_test
     2  
     3  import (
     4  	"context"
     5  	"testing"
     6  	"time"
     7  
     8  	"github.com/golang/mock/gomock"
     9  	"github.com/google/go-cmp/cmp"
    10  	"github.com/google/go-cmp/cmp/cmpopts"
    11  	"google.golang.org/grpc"
    12  	"google.golang.org/grpc/test/bufconn"
    13  
    14  	"github.com/imannamdari/v2ray-core/v5/app/router"
    15  	. "github.com/imannamdari/v2ray-core/v5/app/router/command"
    16  	"github.com/imannamdari/v2ray-core/v5/app/router/routercommon"
    17  	"github.com/imannamdari/v2ray-core/v5/app/stats"
    18  	"github.com/imannamdari/v2ray-core/v5/common"
    19  	"github.com/imannamdari/v2ray-core/v5/common/net"
    20  	"github.com/imannamdari/v2ray-core/v5/features/routing"
    21  	"github.com/imannamdari/v2ray-core/v5/testing/mocks"
    22  )
    23  
    24  func TestServiceSubscribeRoutingStats(t *testing.T) {
    25  	c := stats.NewChannel(&stats.ChannelConfig{
    26  		SubscriberLimit: 1,
    27  		BufferSize:      0,
    28  		Blocking:        true,
    29  	})
    30  	common.Must(c.Start())
    31  	defer c.Close()
    32  
    33  	lis := bufconn.Listen(1024 * 1024)
    34  	bufDialer := func(context.Context, string) (net.Conn, error) {
    35  		return lis.Dial()
    36  	}
    37  
    38  	testCases := []*RoutingContext{
    39  		{InboundTag: "in", OutboundTag: "out"},
    40  		{TargetIPs: [][]byte{{1, 2, 3, 4}}, TargetPort: 8080, OutboundTag: "out"},
    41  		{TargetDomain: "example.com", TargetPort: 443, OutboundTag: "out"},
    42  		{SourcePort: 9999, TargetPort: 9999, OutboundTag: "out"},
    43  		{Network: net.Network_UDP, OutboundGroupTags: []string{"outergroup", "innergroup"}, OutboundTag: "out"},
    44  		{Protocol: "bittorrent", OutboundTag: "blocked"},
    45  		{User: "example@v2fly.org", OutboundTag: "out"},
    46  		{SourceIPs: [][]byte{{127, 0, 0, 1}}, Attributes: map[string]string{"attr": "value"}, OutboundTag: "out"},
    47  	}
    48  	errCh := make(chan error)
    49  	nextPub := make(chan struct{})
    50  
    51  	// Server goroutine
    52  	go func() {
    53  		server := grpc.NewServer()
    54  		RegisterRoutingServiceServer(server, NewRoutingServer(nil, c))
    55  		errCh <- server.Serve(lis)
    56  	}()
    57  
    58  	// Publisher goroutine
    59  	go func() {
    60  		publishTestCases := func() error {
    61  			ctx, cancel := context.WithTimeout(context.Background(), time.Second)
    62  			defer cancel()
    63  			for { // Wait until there's one subscriber in routing stats channel
    64  				if len(c.Subscribers()) > 0 {
    65  					break
    66  				}
    67  				if ctx.Err() != nil {
    68  					return ctx.Err()
    69  				}
    70  			}
    71  			for _, tc := range testCases {
    72  				c.Publish(context.Background(), AsRoutingRoute(tc))
    73  				time.Sleep(time.Millisecond)
    74  			}
    75  			return nil
    76  		}
    77  
    78  		if err := publishTestCases(); err != nil {
    79  			errCh <- err
    80  		}
    81  
    82  		// Wait for next round of publishing
    83  		<-nextPub
    84  
    85  		if err := publishTestCases(); err != nil {
    86  			errCh <- err
    87  		}
    88  	}()
    89  
    90  	// Client goroutine
    91  	go func() {
    92  		defer lis.Close()
    93  		conn, err := grpc.DialContext(context.Background(), "bufnet", grpc.WithContextDialer(bufDialer), grpc.WithInsecure())
    94  		if err != nil {
    95  			errCh <- err
    96  			return
    97  		}
    98  		defer conn.Close()
    99  		client := NewRoutingServiceClient(conn)
   100  
   101  		// Test retrieving all fields
   102  		testRetrievingAllFields := func() error {
   103  			streamCtx, streamClose := context.WithCancel(context.Background())
   104  
   105  			// Test the unsubscription of stream works well
   106  			defer func() {
   107  				streamClose()
   108  				timeOutCtx, timeout := context.WithTimeout(context.Background(), time.Second)
   109  				defer timeout()
   110  				for { // Wait until there's no subscriber in routing stats channel
   111  					if len(c.Subscribers()) == 0 {
   112  						break
   113  					}
   114  					if timeOutCtx.Err() != nil {
   115  						t.Error("unexpected subscribers not decreased in channel", timeOutCtx.Err())
   116  					}
   117  				}
   118  			}()
   119  
   120  			stream, err := client.SubscribeRoutingStats(streamCtx, &SubscribeRoutingStatsRequest{})
   121  			if err != nil {
   122  				return err
   123  			}
   124  
   125  			for _, tc := range testCases {
   126  				msg, err := stream.Recv()
   127  				if err != nil {
   128  					return err
   129  				}
   130  				if r := cmp.Diff(msg, tc, cmpopts.IgnoreUnexported(RoutingContext{})); r != "" {
   131  					t.Error(r)
   132  				}
   133  			}
   134  
   135  			// Test that double subscription will fail
   136  			errStream, err := client.SubscribeRoutingStats(context.Background(), &SubscribeRoutingStatsRequest{
   137  				FieldSelectors: []string{"ip", "port", "domain", "outbound"},
   138  			})
   139  			if err != nil {
   140  				return err
   141  			}
   142  			if _, err := errStream.Recv(); err == nil {
   143  				t.Error("unexpected successful subscription")
   144  			}
   145  
   146  			return nil
   147  		}
   148  
   149  		// Test retrieving only a subset of fields
   150  		testRetrievingSubsetOfFields := func() error {
   151  			streamCtx, streamClose := context.WithCancel(context.Background())
   152  			defer streamClose()
   153  			stream, err := client.SubscribeRoutingStats(streamCtx, &SubscribeRoutingStatsRequest{
   154  				FieldSelectors: []string{"ip", "port", "domain", "outbound"},
   155  			})
   156  			if err != nil {
   157  				return err
   158  			}
   159  
   160  			// Send nextPub signal to start next round of publishing
   161  			close(nextPub)
   162  
   163  			for _, tc := range testCases {
   164  				msg, err := stream.Recv()
   165  				if err != nil {
   166  					return err
   167  				}
   168  				stat := &RoutingContext{ // Only a subset of stats is retrieved
   169  					SourceIPs:         tc.SourceIPs,
   170  					TargetIPs:         tc.TargetIPs,
   171  					SourcePort:        tc.SourcePort,
   172  					TargetPort:        tc.TargetPort,
   173  					TargetDomain:      tc.TargetDomain,
   174  					OutboundGroupTags: tc.OutboundGroupTags,
   175  					OutboundTag:       tc.OutboundTag,
   176  				}
   177  				if r := cmp.Diff(msg, stat, cmpopts.IgnoreUnexported(RoutingContext{})); r != "" {
   178  					t.Error(r)
   179  				}
   180  			}
   181  
   182  			return nil
   183  		}
   184  
   185  		if err := testRetrievingAllFields(); err != nil {
   186  			errCh <- err
   187  		}
   188  		if err := testRetrievingSubsetOfFields(); err != nil {
   189  			errCh <- err
   190  		}
   191  		errCh <- nil // Client passed all tests successfully
   192  	}()
   193  
   194  	// Wait for goroutines to complete
   195  	select {
   196  	case <-time.After(2 * time.Second):
   197  		t.Fatal("Test timeout after 2s")
   198  	case err := <-errCh:
   199  		if err != nil {
   200  			t.Fatal(err)
   201  		}
   202  	}
   203  }
   204  
   205  func TestSerivceTestRoute(t *testing.T) {
   206  	c := stats.NewChannel(&stats.ChannelConfig{
   207  		SubscriberLimit: 1,
   208  		BufferSize:      16,
   209  		Blocking:        true,
   210  	})
   211  	common.Must(c.Start())
   212  	defer c.Close()
   213  
   214  	r := new(router.Router)
   215  	mockCtl := gomock.NewController(t)
   216  	defer mockCtl.Finish()
   217  	common.Must(r.Init(context.TODO(), &router.Config{
   218  		Rule: []*router.RoutingRule{
   219  			{
   220  				InboundTag: []string{"in"},
   221  				TargetTag:  &router.RoutingRule_Tag{Tag: "out"},
   222  			},
   223  			{
   224  				Protocol:  []string{"bittorrent"},
   225  				TargetTag: &router.RoutingRule_Tag{Tag: "blocked"},
   226  			},
   227  			{
   228  				PortList:  &net.PortList{Range: []*net.PortRange{{From: 8080, To: 8080}}},
   229  				TargetTag: &router.RoutingRule_Tag{Tag: "out"},
   230  			},
   231  			{
   232  				SourcePortList: &net.PortList{Range: []*net.PortRange{{From: 9999, To: 9999}}},
   233  				TargetTag:      &router.RoutingRule_Tag{Tag: "out"},
   234  			},
   235  			{
   236  				Domain:    []*routercommon.Domain{{Type: routercommon.Domain_RootDomain, Value: "com"}},
   237  				TargetTag: &router.RoutingRule_Tag{Tag: "out"},
   238  			},
   239  			{
   240  				SourceGeoip: []*routercommon.GeoIP{{CountryCode: "private", Cidr: []*routercommon.CIDR{{Ip: []byte{127, 0, 0, 0}, Prefix: 8}}}},
   241  				TargetTag:   &router.RoutingRule_Tag{Tag: "out"},
   242  			},
   243  			{
   244  				UserEmail: []string{"example@v2fly.org"},
   245  				TargetTag: &router.RoutingRule_Tag{Tag: "out"},
   246  			},
   247  			{
   248  				Networks:  []net.Network{net.Network_UDP, net.Network_TCP},
   249  				TargetTag: &router.RoutingRule_Tag{Tag: "out"},
   250  			},
   251  		},
   252  	}, mocks.NewDNSClient(mockCtl), mocks.NewOutboundManager(mockCtl), nil))
   253  
   254  	lis := bufconn.Listen(1024 * 1024)
   255  	bufDialer := func(context.Context, string) (net.Conn, error) {
   256  		return lis.Dial()
   257  	}
   258  
   259  	errCh := make(chan error)
   260  
   261  	// Server goroutine
   262  	go func() {
   263  		server := grpc.NewServer()
   264  		RegisterRoutingServiceServer(server, NewRoutingServer(r, c))
   265  		errCh <- server.Serve(lis)
   266  	}()
   267  
   268  	// Client goroutine
   269  	go func() {
   270  		defer lis.Close()
   271  		conn, err := grpc.DialContext(context.Background(), "bufnet", grpc.WithContextDialer(bufDialer), grpc.WithInsecure())
   272  		if err != nil {
   273  			errCh <- err
   274  		}
   275  		defer conn.Close()
   276  		client := NewRoutingServiceClient(conn)
   277  
   278  		testCases := []*RoutingContext{
   279  			{InboundTag: "in", OutboundTag: "out"},
   280  			{TargetIPs: [][]byte{{1, 2, 3, 4}}, TargetPort: 8080, OutboundTag: "out"},
   281  			{TargetDomain: "example.com", TargetPort: 443, OutboundTag: "out"},
   282  			{SourcePort: 9999, TargetPort: 9999, OutboundTag: "out"},
   283  			{Network: net.Network_UDP, Protocol: "bittorrent", OutboundTag: "blocked"},
   284  			{User: "example@v2fly.org", OutboundTag: "out"},
   285  			{SourceIPs: [][]byte{{127, 0, 0, 1}}, Attributes: map[string]string{"attr": "value"}, OutboundTag: "out"},
   286  		}
   287  
   288  		// Test simple TestRoute
   289  		testSimple := func() error {
   290  			for _, tc := range testCases {
   291  				route, err := client.TestRoute(context.Background(), &TestRouteRequest{RoutingContext: tc})
   292  				if err != nil {
   293  					return err
   294  				}
   295  				if r := cmp.Diff(route, tc, cmpopts.IgnoreUnexported(RoutingContext{})); r != "" {
   296  					t.Error(r)
   297  				}
   298  			}
   299  			return nil
   300  		}
   301  
   302  		// Test TestRoute with special options
   303  		testOptions := func() error {
   304  			sub, err := c.Subscribe()
   305  			if err != nil {
   306  				return err
   307  			}
   308  			for _, tc := range testCases {
   309  				route, err := client.TestRoute(context.Background(), &TestRouteRequest{
   310  					RoutingContext: tc,
   311  					FieldSelectors: []string{"ip", "port", "domain", "outbound"},
   312  					PublishResult:  true,
   313  				})
   314  				if err != nil {
   315  					return err
   316  				}
   317  				stat := &RoutingContext{ // Only a subset of stats is retrieved
   318  					SourceIPs:         tc.SourceIPs,
   319  					TargetIPs:         tc.TargetIPs,
   320  					SourcePort:        tc.SourcePort,
   321  					TargetPort:        tc.TargetPort,
   322  					TargetDomain:      tc.TargetDomain,
   323  					OutboundGroupTags: tc.OutboundGroupTags,
   324  					OutboundTag:       tc.OutboundTag,
   325  				}
   326  				if r := cmp.Diff(route, stat, cmpopts.IgnoreUnexported(RoutingContext{})); r != "" {
   327  					t.Error(r)
   328  				}
   329  				select { // Check that routing result has been published to statistics channel
   330  				case msg, received := <-sub:
   331  					if route, ok := msg.(routing.Route); received && ok {
   332  						if r := cmp.Diff(AsProtobufMessage(nil)(route), tc, cmpopts.IgnoreUnexported(RoutingContext{})); r != "" {
   333  							t.Error(r)
   334  						}
   335  					} else {
   336  						t.Error("unexpected failure in receiving published routing result for testcase", tc)
   337  					}
   338  				case <-time.After(100 * time.Millisecond):
   339  					t.Error("unexpected failure in receiving published routing result", tc)
   340  				}
   341  			}
   342  			return nil
   343  		}
   344  
   345  		if err := testSimple(); err != nil {
   346  			errCh <- err
   347  		}
   348  		if err := testOptions(); err != nil {
   349  			errCh <- err
   350  		}
   351  		errCh <- nil // Client passed all tests successfully
   352  	}()
   353  
   354  	// Wait for goroutines to complete
   355  	select {
   356  	case <-time.After(2 * time.Second):
   357  		t.Fatal("Test timeout after 2s")
   358  	case err := <-errCh:
   359  		if err != nil {
   360  			t.Fatal(err)
   361  		}
   362  	}
   363  }