github.com/xmplusdev/xray-core@v1.8.10/transport/internet/udp/dispatcher_test.go (about)

     1  package udp_test
     2  
     3  import (
     4  	"context"
     5  	"sync/atomic"
     6  	"testing"
     7  	"time"
     8  
     9  	"github.com/xmplusdev/xray-core/common"
    10  	"github.com/xmplusdev/xray-core/common/buf"
    11  	"github.com/xmplusdev/xray-core/common/net"
    12  	"github.com/xmplusdev/xray-core/common/protocol/udp"
    13  	"github.com/xmplusdev/xray-core/features/routing"
    14  	"github.com/xmplusdev/xray-core/transport"
    15  	. "github.com/xmplusdev/xray-core/transport/internet/udp"
    16  	"github.com/xmplusdev/xray-core/transport/pipe"
    17  )
    18  
    19  type TestDispatcher struct {
    20  	OnDispatch func(ctx context.Context, dest net.Destination) (*transport.Link, error)
    21  }
    22  
    23  func (d *TestDispatcher) Dispatch(ctx context.Context, dest net.Destination) (*transport.Link, error) {
    24  	return d.OnDispatch(ctx, dest)
    25  }
    26  
    27  func (d *TestDispatcher) DispatchLink(ctx context.Context, destination net.Destination, outbound *transport.Link) error {
    28  	return nil
    29  }
    30  
    31  func (d *TestDispatcher) Start() error {
    32  	return nil
    33  }
    34  
    35  func (d *TestDispatcher) Close() error {
    36  	return nil
    37  }
    38  
    39  func (*TestDispatcher) Type() interface{} {
    40  	return routing.DispatcherType()
    41  }
    42  
    43  func TestSameDestinationDispatching(t *testing.T) {
    44  	ctx, cancel := context.WithCancel(context.Background())
    45  	uplinkReader, uplinkWriter := pipe.New(pipe.WithSizeLimit(1024))
    46  	downlinkReader, downlinkWriter := pipe.New(pipe.WithSizeLimit(1024))
    47  
    48  	go func() {
    49  		for {
    50  			data, err := uplinkReader.ReadMultiBuffer()
    51  			if err != nil {
    52  				break
    53  			}
    54  			err = downlinkWriter.WriteMultiBuffer(data)
    55  			common.Must(err)
    56  		}
    57  	}()
    58  
    59  	var count uint32
    60  	td := &TestDispatcher{
    61  		OnDispatch: func(ctx context.Context, dest net.Destination) (*transport.Link, error) {
    62  			atomic.AddUint32(&count, 1)
    63  			return &transport.Link{Reader: downlinkReader, Writer: uplinkWriter}, nil
    64  		},
    65  	}
    66  	dest := net.UDPDestination(net.LocalHostIP, 53)
    67  
    68  	b := buf.New()
    69  	b.WriteString("abcd")
    70  
    71  	var msgCount uint32
    72  	dispatcher := NewDispatcher(td, func(ctx context.Context, packet *udp.Packet) {
    73  		atomic.AddUint32(&msgCount, 1)
    74  	})
    75  
    76  	dispatcher.Dispatch(ctx, dest, b)
    77  	for i := 0; i < 5; i++ {
    78  		dispatcher.Dispatch(ctx, dest, b)
    79  	}
    80  
    81  	time.Sleep(time.Second)
    82  	cancel()
    83  
    84  	if count != 1 {
    85  		t.Error("count: ", count)
    86  	}
    87  	if v := atomic.LoadUint32(&msgCount); v != 6 {
    88  		t.Error("msgCount: ", v)
    89  	}
    90  }