github.com/xmplusdev/xmcore@v1.8.11-0.20240412132628-5518b55526af/transport/internet/websocket/ws_test.go (about)

     1  package websocket_test
     2  
     3  import (
     4  	"context"
     5  	"runtime"
     6  	"testing"
     7  	"time"
     8  
     9  	"github.com/xmplusdev/xmcore/common"
    10  	"github.com/xmplusdev/xmcore/common/net"
    11  	"github.com/xmplusdev/xmcore/common/protocol/tls/cert"
    12  	"github.com/xmplusdev/xmcore/transport/internet"
    13  	"github.com/xmplusdev/xmcore/transport/internet/stat"
    14  	"github.com/xmplusdev/xmcore/transport/internet/tls"
    15  	. "github.com/xmplusdev/xmcore/transport/internet/websocket"
    16  )
    17  
    18  func Test_listenWSAndDial(t *testing.T) {
    19  	listen, err := ListenWS(context.Background(), net.LocalHostIP, 13146, &internet.MemoryStreamConfig{
    20  		ProtocolName: "websocket",
    21  		ProtocolSettings: &Config{
    22  			Path: "ws",
    23  		},
    24  	}, func(conn stat.Connection) {
    25  		go func(c stat.Connection) {
    26  			defer c.Close()
    27  
    28  			var b [1024]byte
    29  			_, err := c.Read(b[:])
    30  			if err != nil {
    31  				return
    32  			}
    33  
    34  			common.Must2(c.Write([]byte("Response")))
    35  		}(conn)
    36  	})
    37  	common.Must(err)
    38  
    39  	ctx := context.Background()
    40  	streamSettings := &internet.MemoryStreamConfig{
    41  		ProtocolName:     "websocket",
    42  		ProtocolSettings: &Config{Path: "ws"},
    43  	}
    44  	conn, err := Dial(ctx, net.TCPDestination(net.DomainAddress("localhost"), 13146), streamSettings)
    45  
    46  	common.Must(err)
    47  	_, err = conn.Write([]byte("Test connection 1"))
    48  	common.Must(err)
    49  
    50  	var b [1024]byte
    51  	n, err := conn.Read(b[:])
    52  	common.Must(err)
    53  	if string(b[:n]) != "Response" {
    54  		t.Error("response: ", string(b[:n]))
    55  	}
    56  
    57  	common.Must(conn.Close())
    58  	<-time.After(time.Second * 5)
    59  	conn, err = Dial(ctx, net.TCPDestination(net.DomainAddress("localhost"), 13146), streamSettings)
    60  	common.Must(err)
    61  	_, err = conn.Write([]byte("Test connection 2"))
    62  	common.Must(err)
    63  	n, err = conn.Read(b[:])
    64  	common.Must(err)
    65  	if string(b[:n]) != "Response" {
    66  		t.Error("response: ", string(b[:n]))
    67  	}
    68  	common.Must(conn.Close())
    69  
    70  	common.Must(listen.Close())
    71  }
    72  
    73  func TestDialWithRemoteAddr(t *testing.T) {
    74  	listen, err := ListenWS(context.Background(), net.LocalHostIP, 13148, &internet.MemoryStreamConfig{
    75  		ProtocolName: "websocket",
    76  		ProtocolSettings: &Config{
    77  			Path: "ws",
    78  		},
    79  	}, func(conn stat.Connection) {
    80  		go func(c stat.Connection) {
    81  			defer c.Close()
    82  
    83  			var b [1024]byte
    84  			_, err := c.Read(b[:])
    85  			// common.Must(err)
    86  			if err != nil {
    87  				return
    88  			}
    89  
    90  			_, err = c.Write([]byte("Response"))
    91  			common.Must(err)
    92  		}(conn)
    93  	})
    94  	common.Must(err)
    95  
    96  	conn, err := Dial(context.Background(), net.TCPDestination(net.DomainAddress("localhost"), 13148), &internet.MemoryStreamConfig{
    97  		ProtocolName:     "websocket",
    98  		ProtocolSettings: &Config{Path: "ws", Header: map[string]string{"X-Forwarded-For": "1.1.1.1"}},
    99  	})
   100  
   101  	common.Must(err)
   102  	_, err = conn.Write([]byte("Test connection 1"))
   103  	common.Must(err)
   104  
   105  	var b [1024]byte
   106  	n, err := conn.Read(b[:])
   107  	common.Must(err)
   108  	if string(b[:n]) != "Response" {
   109  		t.Error("response: ", string(b[:n]))
   110  	}
   111  
   112  	common.Must(listen.Close())
   113  }
   114  
   115  func Test_listenWSAndDial_TLS(t *testing.T) {
   116  	if runtime.GOARCH == "arm64" {
   117  		return
   118  	}
   119  
   120  	start := time.Now()
   121  
   122  	streamSettings := &internet.MemoryStreamConfig{
   123  		ProtocolName: "websocket",
   124  		ProtocolSettings: &Config{
   125  			Path: "wss",
   126  		},
   127  		SecurityType: "tls",
   128  		SecuritySettings: &tls.Config{
   129  			AllowInsecure: true,
   130  			Certificate:   []*tls.Certificate{tls.ParseCertificate(cert.MustGenerate(nil, cert.CommonName("localhost")))},
   131  		},
   132  	}
   133  	listen, err := ListenWS(context.Background(), net.LocalHostIP, 13143, streamSettings, func(conn stat.Connection) {
   134  		go func() {
   135  			_ = conn.Close()
   136  		}()
   137  	})
   138  	common.Must(err)
   139  	defer listen.Close()
   140  
   141  	conn, err := Dial(context.Background(), net.TCPDestination(net.DomainAddress("localhost"), 13143), streamSettings)
   142  	common.Must(err)
   143  	_ = conn.Close()
   144  
   145  	end := time.Now()
   146  	if !end.Before(start.Add(time.Second * 5)) {
   147  		t.Error("end: ", end, " start: ", start)
   148  	}
   149  }