github.com/ipfans/trojan-go@v0.11.0/tunnel/websocket/websocket_test.go (about) 1 package websocket 2 3 import ( 4 "context" 5 "fmt" 6 "net" 7 "strings" 8 "sync" 9 "testing" 10 "time" 11 12 "golang.org/x/net/websocket" 13 14 "github.com/ipfans/trojan-go/common" 15 "github.com/ipfans/trojan-go/config" 16 "github.com/ipfans/trojan-go/test/util" 17 "github.com/ipfans/trojan-go/tunnel" 18 "github.com/ipfans/trojan-go/tunnel/freedom" 19 "github.com/ipfans/trojan-go/tunnel/transport" 20 ) 21 22 func TestWebsocket(t *testing.T) { 23 cfg := &Config{ 24 Websocket: WebsocketConfig{ 25 Enabled: true, 26 Host: "localhost", 27 Path: "/ws", 28 }, 29 } 30 31 ctx := config.WithConfig(context.Background(), Name, cfg) 32 33 port := common.PickPort("tcp", "127.0.0.1") 34 transportConfig := &transport.Config{ 35 LocalHost: "127.0.0.1", 36 LocalPort: port, 37 RemoteHost: "127.0.0.1", 38 RemotePort: port, 39 } 40 freedomCfg := &freedom.Config{} 41 ctx = config.WithConfig(ctx, transport.Name, transportConfig) 42 ctx = config.WithConfig(ctx, freedom.Name, freedomCfg) 43 tcpClient, err := transport.NewClient(ctx, nil) 44 common.Must(err) 45 tcpServer, err := transport.NewServer(ctx, nil) 46 common.Must(err) 47 48 c, err := NewClient(ctx, tcpClient) 49 common.Must(err) 50 s, err := NewServer(ctx, tcpServer) 51 var conn2 tunnel.Conn 52 wg := sync.WaitGroup{} 53 wg.Add(1) 54 go func() { 55 conn2, err = s.AcceptConn(nil) 56 common.Must(err) 57 wg.Done() 58 }() 59 time.Sleep(time.Second) 60 conn1, err := c.DialConn(nil, nil) 61 common.Must(err) 62 wg.Wait() 63 if !util.CheckConn(conn1, conn2) { 64 t.Fail() 65 } 66 67 if strings.HasPrefix(conn1.RemoteAddr().String(), "ws") { 68 t.Fail() 69 } 70 if strings.HasPrefix(conn2.RemoteAddr().String(), "ws") { 71 t.Fail() 72 } 73 74 conn1.Close() 75 conn2.Close() 76 s.Close() 77 c.Close() 78 } 79 80 func TestRedirect(t *testing.T) { 81 cfg := &Config{ 82 RemoteHost: "127.0.0.1", 83 Websocket: WebsocketConfig{ 84 Enabled: true, 85 Host: "localhost", 86 Path: "/ws", 87 }, 88 } 89 fmt.Sscanf(util.HTTPPort, "%d", &cfg.RemotePort) 90 ctx := config.WithConfig(context.Background(), Name, cfg) 91 92 port := common.PickPort("tcp", "127.0.0.1") 93 transportConfig := &transport.Config{ 94 LocalHost: "127.0.0.1", 95 LocalPort: port, 96 } 97 ctx = config.WithConfig(ctx, transport.Name, transportConfig) 98 tcpServer, err := transport.NewServer(ctx, nil) 99 common.Must(err) 100 101 s, err := NewServer(ctx, tcpServer) 102 common.Must(err) 103 104 go func() { 105 _, err := s.AcceptConn(nil) 106 if err == nil { 107 t.Fail() 108 } 109 }() 110 time.Sleep(time.Second) 111 conn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", port)) 112 common.Must(err) 113 url := "wss://localhost/wrong-path" 114 origin := "https://localhost" 115 wsConfig, err := websocket.NewConfig(url, origin) 116 common.Must(err) 117 _, err = websocket.NewClient(wsConfig, conn) 118 if err == nil { 119 t.Fail() 120 } 121 conn.Close() 122 123 s.Close() 124 }