github.com/EagleQL/Xray-core@v1.4.3/transport/internet/udp/dispatcher_test.go (about)

     1  package udp_test
     2  
     3  import (
     4  	"context"
     5  	"sync/atomic"
     6  	"testing"
     7  	"time"
     8  
     9  	"github.com/xtls/xray-core/common"
    10  	"github.com/xtls/xray-core/common/buf"
    11  	"github.com/xtls/xray-core/common/net"
    12  	"github.com/xtls/xray-core/common/protocol/udp"
    13  	"github.com/xtls/xray-core/features/routing"
    14  	"github.com/xtls/xray-core/transport"
    15  	. "github.com/xtls/xray-core/transport/internet/udp"
    16  	"github.com/xtls/xray-core/transport/pipe"
    17  )
    18  
    19  type TestDispatcher struct {
    20  	OnDispatch func(ctx context.Context, dest net.Destination) (*transport.Link, error)
    21  }
    22  
    23  func (d *TestDispatcher) Dispatch(ctx context.Context, dest net.Destination) (*transport.Link, error) {
    24  	return d.OnDispatch(ctx, dest)
    25  }
    26  
    27  func (d *TestDispatcher) Start() error {
    28  	return nil
    29  }
    30  
    31  func (d *TestDispatcher) Close() error {
    32  	return nil
    33  }
    34  
    35  func (*TestDispatcher) Type() interface{} {
    36  	return routing.DispatcherType()
    37  }
    38  
    39  func TestSameDestinationDispatching(t *testing.T) {
    40  	ctx, cancel := context.WithCancel(context.Background())
    41  	uplinkReader, uplinkWriter := pipe.New(pipe.WithSizeLimit(1024))
    42  	downlinkReader, downlinkWriter := pipe.New(pipe.WithSizeLimit(1024))
    43  
    44  	go func() {
    45  		for {
    46  			data, err := uplinkReader.ReadMultiBuffer()
    47  			if err != nil {
    48  				break
    49  			}
    50  			err = downlinkWriter.WriteMultiBuffer(data)
    51  			common.Must(err)
    52  		}
    53  	}()
    54  
    55  	var count uint32
    56  	td := &TestDispatcher{
    57  		OnDispatch: func(ctx context.Context, dest net.Destination) (*transport.Link, error) {
    58  			atomic.AddUint32(&count, 1)
    59  			return &transport.Link{Reader: downlinkReader, Writer: uplinkWriter}, nil
    60  		},
    61  	}
    62  	dest := net.UDPDestination(net.LocalHostIP, 53)
    63  
    64  	b := buf.New()
    65  	b.WriteString("abcd")
    66  
    67  	var msgCount uint32
    68  	dispatcher := NewDispatcher(td, func(ctx context.Context, packet *udp.Packet) {
    69  		atomic.AddUint32(&msgCount, 1)
    70  	})
    71  
    72  	dispatcher.Dispatch(ctx, dest, b)
    73  	for i := 0; i < 5; i++ {
    74  		dispatcher.Dispatch(ctx, dest, b)
    75  	}
    76  
    77  	time.Sleep(time.Second)
    78  	cancel()
    79  
    80  	if count != 1 {
    81  		t.Error("count: ", count)
    82  	}
    83  	if v := atomic.LoadUint32(&msgCount); v != 6 {
    84  		t.Error("msgCount: ", v)
    85  	}
    86  }