github.com/ipfans/trojan-go@v0.11.0/tunnel/websocket/websocket_test.go (about)

     1  package websocket
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"net"
     7  	"strings"
     8  	"sync"
     9  	"testing"
    10  	"time"
    11  
    12  	"golang.org/x/net/websocket"
    13  
    14  	"github.com/ipfans/trojan-go/common"
    15  	"github.com/ipfans/trojan-go/config"
    16  	"github.com/ipfans/trojan-go/test/util"
    17  	"github.com/ipfans/trojan-go/tunnel"
    18  	"github.com/ipfans/trojan-go/tunnel/freedom"
    19  	"github.com/ipfans/trojan-go/tunnel/transport"
    20  )
    21  
    22  func TestWebsocket(t *testing.T) {
    23  	cfg := &Config{
    24  		Websocket: WebsocketConfig{
    25  			Enabled: true,
    26  			Host:    "localhost",
    27  			Path:    "/ws",
    28  		},
    29  	}
    30  
    31  	ctx := config.WithConfig(context.Background(), Name, cfg)
    32  
    33  	port := common.PickPort("tcp", "127.0.0.1")
    34  	transportConfig := &transport.Config{
    35  		LocalHost:  "127.0.0.1",
    36  		LocalPort:  port,
    37  		RemoteHost: "127.0.0.1",
    38  		RemotePort: port,
    39  	}
    40  	freedomCfg := &freedom.Config{}
    41  	ctx = config.WithConfig(ctx, transport.Name, transportConfig)
    42  	ctx = config.WithConfig(ctx, freedom.Name, freedomCfg)
    43  	tcpClient, err := transport.NewClient(ctx, nil)
    44  	common.Must(err)
    45  	tcpServer, err := transport.NewServer(ctx, nil)
    46  	common.Must(err)
    47  
    48  	c, err := NewClient(ctx, tcpClient)
    49  	common.Must(err)
    50  	s, err := NewServer(ctx, tcpServer)
    51  	var conn2 tunnel.Conn
    52  	wg := sync.WaitGroup{}
    53  	wg.Add(1)
    54  	go func() {
    55  		conn2, err = s.AcceptConn(nil)
    56  		common.Must(err)
    57  		wg.Done()
    58  	}()
    59  	time.Sleep(time.Second)
    60  	conn1, err := c.DialConn(nil, nil)
    61  	common.Must(err)
    62  	wg.Wait()
    63  	if !util.CheckConn(conn1, conn2) {
    64  		t.Fail()
    65  	}
    66  
    67  	if strings.HasPrefix(conn1.RemoteAddr().String(), "ws") {
    68  		t.Fail()
    69  	}
    70  	if strings.HasPrefix(conn2.RemoteAddr().String(), "ws") {
    71  		t.Fail()
    72  	}
    73  
    74  	conn1.Close()
    75  	conn2.Close()
    76  	s.Close()
    77  	c.Close()
    78  }
    79  
    80  func TestRedirect(t *testing.T) {
    81  	cfg := &Config{
    82  		RemoteHost: "127.0.0.1",
    83  		Websocket: WebsocketConfig{
    84  			Enabled: true,
    85  			Host:    "localhost",
    86  			Path:    "/ws",
    87  		},
    88  	}
    89  	fmt.Sscanf(util.HTTPPort, "%d", &cfg.RemotePort)
    90  	ctx := config.WithConfig(context.Background(), Name, cfg)
    91  
    92  	port := common.PickPort("tcp", "127.0.0.1")
    93  	transportConfig := &transport.Config{
    94  		LocalHost: "127.0.0.1",
    95  		LocalPort: port,
    96  	}
    97  	ctx = config.WithConfig(ctx, transport.Name, transportConfig)
    98  	tcpServer, err := transport.NewServer(ctx, nil)
    99  	common.Must(err)
   100  
   101  	s, err := NewServer(ctx, tcpServer)
   102  	common.Must(err)
   103  
   104  	go func() {
   105  		_, err := s.AcceptConn(nil)
   106  		if err == nil {
   107  			t.Fail()
   108  		}
   109  	}()
   110  	time.Sleep(time.Second)
   111  	conn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", port))
   112  	common.Must(err)
   113  	url := "wss://localhost/wrong-path"
   114  	origin := "https://localhost"
   115  	wsConfig, err := websocket.NewConfig(url, origin)
   116  	common.Must(err)
   117  	_, err = websocket.NewClient(wsConfig, conn)
   118  	if err == nil {
   119  		t.Fail()
   120  	}
   121  	conn.Close()
   122  
   123  	s.Close()
   124  }