github.com/xraypb/Xray-core@v1.8.1/testing/scenarios/common.go (about) 1 package scenarios 2 3 import ( 4 "bytes" 5 "crypto/rand" 6 "fmt" 7 "io" 8 "os" 9 "os/exec" 10 "path/filepath" 11 "runtime" 12 "sync" 13 "syscall" 14 "testing" 15 "time" 16 17 "github.com/golang/protobuf/proto" 18 "github.com/xraypb/Xray-core/app/dispatcher" 19 "github.com/xraypb/Xray-core/app/proxyman" 20 "github.com/xraypb/Xray-core/common" 21 "github.com/xraypb/Xray-core/common/errors" 22 "github.com/xraypb/Xray-core/common/log" 23 "github.com/xraypb/Xray-core/common/net" 24 "github.com/xraypb/Xray-core/common/retry" 25 "github.com/xraypb/Xray-core/common/serial" 26 "github.com/xraypb/Xray-core/common/units" 27 core "github.com/xraypb/Xray-core/core" 28 ) 29 30 func xor(b []byte) []byte { 31 r := make([]byte, len(b)) 32 for i, v := range b { 33 r[i] = v ^ 'c' 34 } 35 return r 36 } 37 38 func readFrom(conn net.Conn, timeout time.Duration, length int) []byte { 39 b := make([]byte, length) 40 deadline := time.Now().Add(timeout) 41 conn.SetReadDeadline(deadline) 42 n, err := io.ReadFull(conn, b[:length]) 43 if err != nil { 44 fmt.Println("Unexpected error from readFrom:", err) 45 } 46 return b[:n] 47 } 48 49 func readFrom2(conn net.Conn, timeout time.Duration, length int) ([]byte, error) { 50 b := make([]byte, length) 51 deadline := time.Now().Add(timeout) 52 conn.SetReadDeadline(deadline) 53 n, err := io.ReadFull(conn, b[:length]) 54 if err != nil { 55 return nil, err 56 } 57 return b[:n], nil 58 } 59 60 func InitializeServerConfigs(configs ...*core.Config) ([]*exec.Cmd, error) { 61 servers := make([]*exec.Cmd, 0, 10) 62 63 for _, config := range configs { 64 server, err := InitializeServerConfig(config) 65 if err != nil { 66 CloseAllServers(servers) 67 return nil, err 68 } 69 servers = append(servers, server) 70 } 71 72 time.Sleep(time.Second * 2) 73 74 return servers, nil 75 } 76 77 func InitializeServerConfig(config *core.Config) (*exec.Cmd, error) { 78 err := BuildXray() 79 if err != nil { 80 return nil, err 81 } 82 83 config = withDefaultApps(config) 84 configBytes, err := proto.Marshal(config) 85 if err != nil { 86 return nil, err 87 } 88 proc := RunXrayProtobuf(configBytes) 89 90 if err := proc.Start(); err != nil { 91 return nil, err 92 } 93 94 return proc, nil 95 } 96 97 var ( 98 testBinaryPath string 99 testBinaryPathGen sync.Once 100 ) 101 102 func genTestBinaryPath() { 103 testBinaryPathGen.Do(func() { 104 var tempDir string 105 common.Must(retry.Timed(5, 100).On(func() error { 106 dir, err := os.MkdirTemp("", "xray") 107 if err != nil { 108 return err 109 } 110 tempDir = dir 111 return nil 112 })) 113 file := filepath.Join(tempDir, "xray.test") 114 if runtime.GOOS == "windows" { 115 file += ".exe" 116 } 117 testBinaryPath = file 118 fmt.Printf("Generated binary path: %s\n", file) 119 }) 120 } 121 122 func GetSourcePath() string { 123 return filepath.Join("github.com", "xtls", "xray-core", "main") 124 } 125 126 func CloseAllServers(servers []*exec.Cmd) { 127 log.Record(&log.GeneralMessage{ 128 Severity: log.Severity_Info, 129 Content: "Closing all servers.", 130 }) 131 for _, server := range servers { 132 if runtime.GOOS == "windows" { 133 server.Process.Kill() 134 } else { 135 server.Process.Signal(syscall.SIGTERM) 136 } 137 } 138 for _, server := range servers { 139 server.Process.Wait() 140 } 141 log.Record(&log.GeneralMessage{ 142 Severity: log.Severity_Info, 143 Content: "All server closed.", 144 }) 145 } 146 147 func CloseServer(server *exec.Cmd) { 148 log.Record(&log.GeneralMessage{ 149 Severity: log.Severity_Info, 150 Content: "Closing server.", 151 }) 152 if runtime.GOOS == "windows" { 153 server.Process.Kill() 154 } else { 155 server.Process.Signal(syscall.SIGTERM) 156 } 157 server.Process.Wait() 158 log.Record(&log.GeneralMessage{ 159 Severity: log.Severity_Info, 160 Content: "Server closed.", 161 }) 162 } 163 164 func withDefaultApps(config *core.Config) *core.Config { 165 config.App = append(config.App, serial.ToTypedMessage(&dispatcher.Config{})) 166 config.App = append(config.App, serial.ToTypedMessage(&proxyman.InboundConfig{})) 167 config.App = append(config.App, serial.ToTypedMessage(&proxyman.OutboundConfig{})) 168 return config 169 } 170 171 func testTCPConn(port net.Port, payloadSize int, timeout time.Duration) func() error { 172 return func() error { 173 conn, err := net.DialTCP("tcp", nil, &net.TCPAddr{ 174 IP: []byte{127, 0, 0, 1}, 175 Port: int(port), 176 }) 177 if err != nil { 178 return err 179 } 180 defer conn.Close() 181 182 return testTCPConn2(conn, payloadSize, timeout)() 183 } 184 } 185 186 func testUDPConn(port net.Port, payloadSize int, timeout time.Duration) func() error { 187 return func() error { 188 conn, err := net.DialUDP("udp", nil, &net.UDPAddr{ 189 IP: []byte{127, 0, 0, 1}, 190 Port: int(port), 191 }) 192 if err != nil { 193 return err 194 } 195 defer conn.Close() 196 197 return testTCPConn2(conn, payloadSize, timeout)() 198 } 199 } 200 201 func testTCPConn2(conn net.Conn, payloadSize int, timeout time.Duration) func() error { 202 return func() (err1 error) { 203 start := time.Now() 204 defer func() { 205 var m runtime.MemStats 206 runtime.ReadMemStats(&m) 207 // For info on each, see: https://golang.org/pkg/runtime/#MemStats 208 fmt.Println("testConn finishes:", time.Since(start).Milliseconds(), "ms\t", 209 err1, "\tAlloc =", units.ByteSize(m.Alloc).String(), 210 "\tTotalAlloc =", units.ByteSize(m.TotalAlloc).String(), 211 "\tSys =", units.ByteSize(m.Sys).String(), 212 "\tNumGC =", m.NumGC) 213 }() 214 payload := make([]byte, payloadSize) 215 common.Must2(rand.Read(payload)) 216 217 nBytes, err := conn.Write(payload) 218 if err != nil { 219 return err 220 } 221 if nBytes != len(payload) { 222 return errors.New("expect ", len(payload), " written, but actually ", nBytes) 223 } 224 225 response, err := readFrom2(conn, timeout, payloadSize) 226 if err != nil { 227 return err 228 } 229 _ = response 230 231 if r := bytes.Compare(response, xor(payload)); r != 0 { 232 return errors.New(r) 233 } 234 235 return nil 236 } 237 } 238 239 func WaitConnAvailableWithTest(t *testing.T, testFunc func() error) bool { 240 for i := 1; ; i++ { 241 if i > 10 { 242 t.Log("All attempts failed to test tcp conn") 243 return false 244 } 245 time.Sleep(time.Millisecond * 10) 246 if err := testFunc(); err != nil { 247 t.Log("err ", err) 248 } else { 249 t.Log("success with", i, "attempts") 250 break 251 } 252 } 253 return true 254 }