github.com/EagleQL/Xray-core@v1.4.3/transport/internet/websocket/ws_test.go (about)

     1  package websocket_test
     2  
     3  import (
     4  	"context"
     5  	"runtime"
     6  	"testing"
     7  	"time"
     8  
     9  	"github.com/xtls/xray-core/common"
    10  	"github.com/xtls/xray-core/common/net"
    11  	"github.com/xtls/xray-core/common/protocol/tls/cert"
    12  	"github.com/xtls/xray-core/transport/internet"
    13  	"github.com/xtls/xray-core/transport/internet/tls"
    14  	. "github.com/xtls/xray-core/transport/internet/websocket"
    15  )
    16  
    17  func Test_listenWSAndDial(t *testing.T) {
    18  	listen, err := ListenWS(context.Background(), net.LocalHostIP, 13146, &internet.MemoryStreamConfig{
    19  		ProtocolName: "websocket",
    20  		ProtocolSettings: &Config{
    21  			Path: "ws",
    22  		},
    23  	}, func(conn internet.Connection) {
    24  		go func(c internet.Connection) {
    25  			defer c.Close()
    26  
    27  			var b [1024]byte
    28  			_, err := c.Read(b[:])
    29  			if err != nil {
    30  				return
    31  			}
    32  
    33  			common.Must2(c.Write([]byte("Response")))
    34  		}(conn)
    35  	})
    36  	common.Must(err)
    37  
    38  	ctx := context.Background()
    39  	streamSettings := &internet.MemoryStreamConfig{
    40  		ProtocolName:     "websocket",
    41  		ProtocolSettings: &Config{Path: "ws"},
    42  	}
    43  	conn, err := Dial(ctx, net.TCPDestination(net.DomainAddress("localhost"), 13146), streamSettings)
    44  
    45  	common.Must(err)
    46  	_, err = conn.Write([]byte("Test connection 1"))
    47  	common.Must(err)
    48  
    49  	var b [1024]byte
    50  	n, err := conn.Read(b[:])
    51  	common.Must(err)
    52  	if string(b[:n]) != "Response" {
    53  		t.Error("response: ", string(b[:n]))
    54  	}
    55  
    56  	common.Must(conn.Close())
    57  	<-time.After(time.Second * 5)
    58  	conn, err = Dial(ctx, net.TCPDestination(net.DomainAddress("localhost"), 13146), streamSettings)
    59  	common.Must(err)
    60  	_, err = conn.Write([]byte("Test connection 2"))
    61  	common.Must(err)
    62  	n, err = conn.Read(b[:])
    63  	common.Must(err)
    64  	if string(b[:n]) != "Response" {
    65  		t.Error("response: ", string(b[:n]))
    66  	}
    67  	common.Must(conn.Close())
    68  
    69  	common.Must(listen.Close())
    70  }
    71  
    72  func TestDialWithRemoteAddr(t *testing.T) {
    73  	listen, err := ListenWS(context.Background(), net.LocalHostIP, 13148, &internet.MemoryStreamConfig{
    74  		ProtocolName: "websocket",
    75  		ProtocolSettings: &Config{
    76  			Path: "ws",
    77  		},
    78  	}, func(conn internet.Connection) {
    79  		go func(c internet.Connection) {
    80  			defer c.Close()
    81  
    82  			var b [1024]byte
    83  			_, err := c.Read(b[:])
    84  			// common.Must(err)
    85  			if err != nil {
    86  				return
    87  			}
    88  
    89  			_, err = c.Write([]byte("Response"))
    90  			common.Must(err)
    91  		}(conn)
    92  	})
    93  	common.Must(err)
    94  
    95  	conn, err := Dial(context.Background(), net.TCPDestination(net.DomainAddress("localhost"), 13148), &internet.MemoryStreamConfig{
    96  		ProtocolName:     "websocket",
    97  		ProtocolSettings: &Config{Path: "ws", Header: []*Header{{Key: "X-Forwarded-For", Value: "1.1.1.1"}}},
    98  	})
    99  
   100  	common.Must(err)
   101  	_, err = conn.Write([]byte("Test connection 1"))
   102  	common.Must(err)
   103  
   104  	var b [1024]byte
   105  	n, err := conn.Read(b[:])
   106  	common.Must(err)
   107  	if string(b[:n]) != "Response" {
   108  		t.Error("response: ", string(b[:n]))
   109  	}
   110  
   111  	common.Must(listen.Close())
   112  }
   113  
   114  func Test_listenWSAndDial_TLS(t *testing.T) {
   115  	if runtime.GOARCH == "arm64" {
   116  		return
   117  	}
   118  
   119  	start := time.Now()
   120  
   121  	streamSettings := &internet.MemoryStreamConfig{
   122  		ProtocolName: "websocket",
   123  		ProtocolSettings: &Config{
   124  			Path: "wss",
   125  		},
   126  		SecurityType: "tls",
   127  		SecuritySettings: &tls.Config{
   128  			AllowInsecure: true,
   129  			Certificate:   []*tls.Certificate{tls.ParseCertificate(cert.MustGenerate(nil, cert.CommonName("localhost")))},
   130  		},
   131  	}
   132  	listen, err := ListenWS(context.Background(), net.LocalHostIP, 13143, streamSettings, func(conn internet.Connection) {
   133  		go func() {
   134  			_ = conn.Close()
   135  		}()
   136  	})
   137  	common.Must(err)
   138  	defer listen.Close()
   139  
   140  	conn, err := Dial(context.Background(), net.TCPDestination(net.DomainAddress("localhost"), 13143), streamSettings)
   141  	common.Must(err)
   142  	_ = conn.Close()
   143  
   144  	end := time.Now()
   145  	if !end.Before(start.Add(time.Second * 5)) {
   146  		t.Error("end: ", end, " start: ", start)
   147  	}
   148  }