github.com/eagleql/xray-core@v1.4.4/app/router/command/command_test.go (about) 1 package command_test 2 3 import ( 4 "context" 5 "testing" 6 "time" 7 8 "github.com/eagleql/xray-core/app/router" 9 . "github.com/eagleql/xray-core/app/router/command" 10 "github.com/eagleql/xray-core/app/stats" 11 "github.com/eagleql/xray-core/common" 12 "github.com/eagleql/xray-core/common/net" 13 "github.com/eagleql/xray-core/features/routing" 14 "github.com/eagleql/xray-core/testing/mocks" 15 "github.com/golang/mock/gomock" 16 "github.com/google/go-cmp/cmp" 17 "github.com/google/go-cmp/cmp/cmpopts" 18 "google.golang.org/grpc" 19 "google.golang.org/grpc/test/bufconn" 20 ) 21 22 func TestServiceSubscribeRoutingStats(t *testing.T) { 23 c := stats.NewChannel(&stats.ChannelConfig{ 24 SubscriberLimit: 1, 25 BufferSize: 0, 26 Blocking: true, 27 }) 28 common.Must(c.Start()) 29 defer c.Close() 30 31 lis := bufconn.Listen(1024 * 1024) 32 bufDialer := func(context.Context, string) (net.Conn, error) { 33 return lis.Dial() 34 } 35 36 testCases := []*RoutingContext{ 37 {InboundTag: "in", OutboundTag: "out"}, 38 {TargetIPs: [][]byte{{1, 2, 3, 4}}, TargetPort: 8080, OutboundTag: "out"}, 39 {TargetDomain: "example.com", TargetPort: 443, OutboundTag: "out"}, 40 {SourcePort: 9999, TargetPort: 9999, OutboundTag: "out"}, 41 {Network: net.Network_UDP, OutboundGroupTags: []string{"outergroup", "innergroup"}, OutboundTag: "out"}, 42 {Protocol: "bittorrent", OutboundTag: "blocked"}, 43 {User: "example@example.com", OutboundTag: "out"}, 44 {SourceIPs: [][]byte{{127, 0, 0, 1}}, Attributes: map[string]string{"attr": "value"}, OutboundTag: "out"}, 45 } 46 errCh := make(chan error) 47 nextPub := make(chan struct{}) 48 49 // Server goroutine 50 go func() { 51 server := grpc.NewServer() 52 RegisterRoutingServiceServer(server, NewRoutingServer(nil, c)) 53 errCh <- server.Serve(lis) 54 }() 55 56 // Publisher goroutine 57 go func() { 58 publishTestCases := func() error { 59 ctx, cancel := context.WithTimeout(context.Background(), time.Second) 60 defer cancel() 61 for { // Wait until there's one subscriber in routing stats channel 62 if len(c.Subscribers()) > 0 { 63 break 64 } 65 if ctx.Err() != nil { 66 return ctx.Err() 67 } 68 } 69 for _, tc := range testCases { 70 c.Publish(context.Background(), AsRoutingRoute(tc)) 71 time.Sleep(time.Millisecond) 72 } 73 return nil 74 } 75 76 if err := publishTestCases(); err != nil { 77 errCh <- err 78 } 79 80 // Wait for next round of publishing 81 <-nextPub 82 83 if err := publishTestCases(); err != nil { 84 errCh <- err 85 } 86 }() 87 88 // Client goroutine 89 go func() { 90 defer lis.Close() 91 conn, err := grpc.DialContext(context.Background(), "bufnet", grpc.WithContextDialer(bufDialer), grpc.WithInsecure()) 92 if err != nil { 93 errCh <- err 94 return 95 } 96 defer conn.Close() 97 client := NewRoutingServiceClient(conn) 98 99 // Test retrieving all fields 100 testRetrievingAllFields := func() error { 101 streamCtx, streamClose := context.WithCancel(context.Background()) 102 103 // Test the unsubscription of stream works well 104 defer func() { 105 streamClose() 106 timeOutCtx, timeout := context.WithTimeout(context.Background(), time.Second) 107 defer timeout() 108 for { // Wait until there's no subscriber in routing stats channel 109 if len(c.Subscribers()) == 0 { 110 break 111 } 112 if timeOutCtx.Err() != nil { 113 t.Error("unexpected subscribers not decreased in channel", timeOutCtx.Err()) 114 } 115 } 116 }() 117 118 stream, err := client.SubscribeRoutingStats(streamCtx, &SubscribeRoutingStatsRequest{}) 119 if err != nil { 120 return err 121 } 122 123 for _, tc := range testCases { 124 msg, err := stream.Recv() 125 if err != nil { 126 return err 127 } 128 if r := cmp.Diff(msg, tc, cmpopts.IgnoreUnexported(RoutingContext{})); r != "" { 129 t.Error(r) 130 } 131 } 132 133 // Test that double subscription will fail 134 errStream, err := client.SubscribeRoutingStats(context.Background(), &SubscribeRoutingStatsRequest{ 135 FieldSelectors: []string{"ip", "port", "domain", "outbound"}, 136 }) 137 if err != nil { 138 return err 139 } 140 if _, err := errStream.Recv(); err == nil { 141 t.Error("unexpected successful subscription") 142 } 143 144 return nil 145 } 146 147 // Test retrieving only a subset of fields 148 testRetrievingSubsetOfFields := func() error { 149 streamCtx, streamClose := context.WithCancel(context.Background()) 150 defer streamClose() 151 stream, err := client.SubscribeRoutingStats(streamCtx, &SubscribeRoutingStatsRequest{ 152 FieldSelectors: []string{"ip", "port", "domain", "outbound"}, 153 }) 154 if err != nil { 155 return err 156 } 157 158 // Send nextPub signal to start next round of publishing 159 close(nextPub) 160 161 for _, tc := range testCases { 162 msg, err := stream.Recv() 163 if err != nil { 164 return err 165 } 166 stat := &RoutingContext{ // Only a subset of stats is retrieved 167 SourceIPs: tc.SourceIPs, 168 TargetIPs: tc.TargetIPs, 169 SourcePort: tc.SourcePort, 170 TargetPort: tc.TargetPort, 171 TargetDomain: tc.TargetDomain, 172 OutboundGroupTags: tc.OutboundGroupTags, 173 OutboundTag: tc.OutboundTag, 174 } 175 if r := cmp.Diff(msg, stat, cmpopts.IgnoreUnexported(RoutingContext{})); r != "" { 176 t.Error(r) 177 } 178 } 179 180 return nil 181 } 182 183 if err := testRetrievingAllFields(); err != nil { 184 errCh <- err 185 } 186 if err := testRetrievingSubsetOfFields(); err != nil { 187 errCh <- err 188 } 189 errCh <- nil // Client passed all tests successfully 190 }() 191 192 // Wait for goroutines to complete 193 select { 194 case <-time.After(2 * time.Second): 195 t.Fatal("Test timeout after 2s") 196 case err := <-errCh: 197 if err != nil { 198 t.Fatal(err) 199 } 200 } 201 } 202 203 func TestSerivceTestRoute(t *testing.T) { 204 c := stats.NewChannel(&stats.ChannelConfig{ 205 SubscriberLimit: 1, 206 BufferSize: 16, 207 Blocking: true, 208 }) 209 common.Must(c.Start()) 210 defer c.Close() 211 212 r := new(router.Router) 213 mockCtl := gomock.NewController(t) 214 defer mockCtl.Finish() 215 common.Must(r.Init(&router.Config{ 216 Rule: []*router.RoutingRule{ 217 { 218 InboundTag: []string{"in"}, 219 TargetTag: &router.RoutingRule_Tag{Tag: "out"}, 220 }, 221 { 222 Protocol: []string{"bittorrent"}, 223 TargetTag: &router.RoutingRule_Tag{Tag: "blocked"}, 224 }, 225 { 226 PortList: &net.PortList{Range: []*net.PortRange{{From: 8080, To: 8080}}}, 227 TargetTag: &router.RoutingRule_Tag{Tag: "out"}, 228 }, 229 { 230 SourcePortList: &net.PortList{Range: []*net.PortRange{{From: 9999, To: 9999}}}, 231 TargetTag: &router.RoutingRule_Tag{Tag: "out"}, 232 }, 233 { 234 Domain: []*router.Domain{{Type: router.Domain_Domain, Value: "com"}}, 235 TargetTag: &router.RoutingRule_Tag{Tag: "out"}, 236 }, 237 { 238 SourceGeoip: []*router.GeoIP{{CountryCode: "private", Cidr: []*router.CIDR{{Ip: []byte{127, 0, 0, 0}, Prefix: 8}}}}, 239 TargetTag: &router.RoutingRule_Tag{Tag: "out"}, 240 }, 241 { 242 UserEmail: []string{"example@example.com"}, 243 TargetTag: &router.RoutingRule_Tag{Tag: "out"}, 244 }, 245 { 246 Networks: []net.Network{net.Network_UDP, net.Network_TCP}, 247 TargetTag: &router.RoutingRule_Tag{Tag: "out"}, 248 }, 249 }, 250 }, mocks.NewDNSClient(mockCtl), mocks.NewOutboundManager(mockCtl))) 251 252 lis := bufconn.Listen(1024 * 1024) 253 bufDialer := func(context.Context, string) (net.Conn, error) { 254 return lis.Dial() 255 } 256 257 errCh := make(chan error) 258 259 // Server goroutine 260 go func() { 261 server := grpc.NewServer() 262 RegisterRoutingServiceServer(server, NewRoutingServer(r, c)) 263 errCh <- server.Serve(lis) 264 }() 265 266 // Client goroutine 267 go func() { 268 defer lis.Close() 269 conn, err := grpc.DialContext(context.Background(), "bufnet", grpc.WithContextDialer(bufDialer), grpc.WithInsecure()) 270 if err != nil { 271 errCh <- err 272 } 273 defer conn.Close() 274 client := NewRoutingServiceClient(conn) 275 276 testCases := []*RoutingContext{ 277 {InboundTag: "in", OutboundTag: "out"}, 278 {TargetIPs: [][]byte{{1, 2, 3, 4}}, TargetPort: 8080, OutboundTag: "out"}, 279 {TargetDomain: "example.com", TargetPort: 443, OutboundTag: "out"}, 280 {SourcePort: 9999, TargetPort: 9999, OutboundTag: "out"}, 281 {Network: net.Network_UDP, Protocol: "bittorrent", OutboundTag: "blocked"}, 282 {User: "example@example.com", OutboundTag: "out"}, 283 {SourceIPs: [][]byte{{127, 0, 0, 1}}, Attributes: map[string]string{"attr": "value"}, OutboundTag: "out"}, 284 } 285 286 // Test simple TestRoute 287 testSimple := func() error { 288 for _, tc := range testCases { 289 route, err := client.TestRoute(context.Background(), &TestRouteRequest{RoutingContext: tc}) 290 if err != nil { 291 return err 292 } 293 if r := cmp.Diff(route, tc, cmpopts.IgnoreUnexported(RoutingContext{})); r != "" { 294 t.Error(r) 295 } 296 } 297 return nil 298 } 299 300 // Test TestRoute with special options 301 testOptions := func() error { 302 sub, err := c.Subscribe() 303 if err != nil { 304 return err 305 } 306 for _, tc := range testCases { 307 route, err := client.TestRoute(context.Background(), &TestRouteRequest{ 308 RoutingContext: tc, 309 FieldSelectors: []string{"ip", "port", "domain", "outbound"}, 310 PublishResult: true, 311 }) 312 if err != nil { 313 return err 314 } 315 stat := &RoutingContext{ // Only a subset of stats is retrieved 316 SourceIPs: tc.SourceIPs, 317 TargetIPs: tc.TargetIPs, 318 SourcePort: tc.SourcePort, 319 TargetPort: tc.TargetPort, 320 TargetDomain: tc.TargetDomain, 321 OutboundGroupTags: tc.OutboundGroupTags, 322 OutboundTag: tc.OutboundTag, 323 } 324 if r := cmp.Diff(route, stat, cmpopts.IgnoreUnexported(RoutingContext{})); r != "" { 325 t.Error(r) 326 } 327 select { // Check that routing result has been published to statistics channel 328 case msg, received := <-sub: 329 if route, ok := msg.(routing.Route); received && ok { 330 if r := cmp.Diff(AsProtobufMessage(nil)(route), tc, cmpopts.IgnoreUnexported(RoutingContext{})); r != "" { 331 t.Error(r) 332 } 333 } else { 334 t.Error("unexpected failure in receiving published routing result for testcase", tc) 335 } 336 case <-time.After(100 * time.Millisecond): 337 t.Error("unexpected failure in receiving published routing result", tc) 338 } 339 } 340 return nil 341 } 342 343 if err := testSimple(); err != nil { 344 errCh <- err 345 } 346 if err := testOptions(); err != nil { 347 errCh <- err 348 } 349 errCh <- nil // Client passed all tests successfully 350 }() 351 352 // Wait for goroutines to complete 353 select { 354 case <-time.After(2 * time.Second): 355 t.Fatal("Test timeout after 2s") 356 case err := <-errCh: 357 if err != nil { 358 t.Fatal(err) 359 } 360 } 361 }