github.com/xmplusdev/xmcore@v1.8.11-0.20240412132628-5518b55526af/testing/scenarios/common.go (about)

     1  package scenarios
     2  
     3  import (
     4  	"bytes"
     5  	"crypto/rand"
     6  	"fmt"
     7  	"io"
     8  	"os"
     9  	"os/exec"
    10  	"path/filepath"
    11  	"runtime"
    12  	"sync"
    13  	"syscall"
    14  	"testing"
    15  	"time"
    16  
    17  	"github.com/xmplusdev/xmcore/app/dispatcher"
    18  	"github.com/xmplusdev/xmcore/app/proxyman"
    19  	"github.com/xmplusdev/xmcore/common"
    20  	"github.com/xmplusdev/xmcore/common/errors"
    21  	"github.com/xmplusdev/xmcore/common/log"
    22  	"github.com/xmplusdev/xmcore/common/net"
    23  	"github.com/xmplusdev/xmcore/common/retry"
    24  	"github.com/xmplusdev/xmcore/common/serial"
    25  	"github.com/xmplusdev/xmcore/common/units"
    26  	core "github.com/xmplusdev/xmcore/core"
    27  	"google.golang.org/protobuf/proto"
    28  )
    29  
    30  func xor(b []byte) []byte {
    31  	r := make([]byte, len(b))
    32  	for i, v := range b {
    33  		r[i] = v ^ 'c'
    34  	}
    35  	return r
    36  }
    37  
    38  func readFrom(conn net.Conn, timeout time.Duration, length int) []byte {
    39  	b := make([]byte, length)
    40  	deadline := time.Now().Add(timeout)
    41  	conn.SetReadDeadline(deadline)
    42  	n, err := io.ReadFull(conn, b[:length])
    43  	if err != nil {
    44  		fmt.Println("Unexpected error from readFrom:", err)
    45  	}
    46  	return b[:n]
    47  }
    48  
    49  func readFrom2(conn net.Conn, timeout time.Duration, length int) ([]byte, error) {
    50  	b := make([]byte, length)
    51  	deadline := time.Now().Add(timeout)
    52  	conn.SetReadDeadline(deadline)
    53  	n, err := io.ReadFull(conn, b[:length])
    54  	if err != nil {
    55  		return nil, err
    56  	}
    57  	return b[:n], nil
    58  }
    59  
    60  func InitializeServerConfigs(configs ...*core.Config) ([]*exec.Cmd, error) {
    61  	servers := make([]*exec.Cmd, 0, 10)
    62  
    63  	for _, config := range configs {
    64  		server, err := InitializeServerConfig(config)
    65  		if err != nil {
    66  			CloseAllServers(servers)
    67  			return nil, err
    68  		}
    69  		servers = append(servers, server)
    70  	}
    71  
    72  	time.Sleep(time.Second * 2)
    73  
    74  	return servers, nil
    75  }
    76  
    77  func InitializeServerConfig(config *core.Config) (*exec.Cmd, error) {
    78  	err := BuildXray()
    79  	if err != nil {
    80  		return nil, err
    81  	}
    82  
    83  	config = withDefaultApps(config)
    84  	configBytes, err := proto.Marshal(config)
    85  	if err != nil {
    86  		return nil, err
    87  	}
    88  	proc := RunXrayProtobuf(configBytes)
    89  
    90  	if err := proc.Start(); err != nil {
    91  		return nil, err
    92  	}
    93  
    94  	return proc, nil
    95  }
    96  
    97  var (
    98  	testBinaryPath    string
    99  	testBinaryPathGen sync.Once
   100  )
   101  
   102  func genTestBinaryPath() {
   103  	testBinaryPathGen.Do(func() {
   104  		var tempDir string
   105  		common.Must(retry.Timed(5, 100).On(func() error {
   106  			dir, err := os.MkdirTemp("", "xray")
   107  			if err != nil {
   108  				return err
   109  			}
   110  			tempDir = dir
   111  			return nil
   112  		}))
   113  		file := filepath.Join(tempDir, "xray.test")
   114  		if runtime.GOOS == "windows" {
   115  			file += ".exe"
   116  		}
   117  		testBinaryPath = file
   118  		fmt.Printf("Generated binary path: %s\n", file)
   119  	})
   120  }
   121  
   122  func GetSourcePath() string {
   123  	return filepath.Join("github.com", "xtls", "xray-core", "main")
   124  }
   125  
   126  func CloseAllServers(servers []*exec.Cmd) {
   127  	log.Record(&log.GeneralMessage{
   128  		Severity: log.Severity_Info,
   129  		Content:  "Closing all servers.",
   130  	})
   131  	for _, server := range servers {
   132  		if runtime.GOOS == "windows" {
   133  			server.Process.Kill()
   134  		} else {
   135  			server.Process.Signal(syscall.SIGTERM)
   136  		}
   137  	}
   138  	for _, server := range servers {
   139  		server.Process.Wait()
   140  	}
   141  	log.Record(&log.GeneralMessage{
   142  		Severity: log.Severity_Info,
   143  		Content:  "All server closed.",
   144  	})
   145  }
   146  
   147  func CloseServer(server *exec.Cmd) {
   148  	log.Record(&log.GeneralMessage{
   149  		Severity: log.Severity_Info,
   150  		Content:  "Closing server.",
   151  	})
   152  	if runtime.GOOS == "windows" {
   153  		server.Process.Kill()
   154  	} else {
   155  		server.Process.Signal(syscall.SIGTERM)
   156  	}
   157  	server.Process.Wait()
   158  	log.Record(&log.GeneralMessage{
   159  		Severity: log.Severity_Info,
   160  		Content:  "Server closed.",
   161  	})
   162  }
   163  
   164  func withDefaultApps(config *core.Config) *core.Config {
   165  	config.App = append(config.App, serial.ToTypedMessage(&dispatcher.Config{}))
   166  	config.App = append(config.App, serial.ToTypedMessage(&proxyman.InboundConfig{}))
   167  	config.App = append(config.App, serial.ToTypedMessage(&proxyman.OutboundConfig{}))
   168  	return config
   169  }
   170  
   171  func testTCPConn(port net.Port, payloadSize int, timeout time.Duration) func() error {
   172  	return func() error {
   173  		conn, err := net.DialTCP("tcp", nil, &net.TCPAddr{
   174  			IP:   []byte{127, 0, 0, 1},
   175  			Port: int(port),
   176  		})
   177  		if err != nil {
   178  			return err
   179  		}
   180  		defer conn.Close()
   181  
   182  		return testTCPConn2(conn, payloadSize, timeout)()
   183  	}
   184  }
   185  
   186  func testUDPConn(port net.Port, payloadSize int, timeout time.Duration) func() error {
   187  	return func() error {
   188  		conn, err := net.DialUDP("udp", nil, &net.UDPAddr{
   189  			IP:   []byte{127, 0, 0, 1},
   190  			Port: int(port),
   191  		})
   192  		if err != nil {
   193  			return err
   194  		}
   195  		defer conn.Close()
   196  
   197  		return testTCPConn2(conn, payloadSize, timeout)()
   198  	}
   199  }
   200  
   201  func testTCPConn2(conn net.Conn, payloadSize int, timeout time.Duration) func() error {
   202  	return func() (err1 error) {
   203  		start := time.Now()
   204  		defer func() {
   205  			var m runtime.MemStats
   206  			runtime.ReadMemStats(&m)
   207  			// For info on each, see: https://golang.org/pkg/runtime/#MemStats
   208  			fmt.Println("testConn finishes:", time.Since(start).Milliseconds(), "ms\t",
   209  				err1, "\tAlloc =", units.ByteSize(m.Alloc).String(),
   210  				"\tTotalAlloc =", units.ByteSize(m.TotalAlloc).String(),
   211  				"\tSys =", units.ByteSize(m.Sys).String(),
   212  				"\tNumGC =", m.NumGC)
   213  		}()
   214  		payload := make([]byte, payloadSize)
   215  		common.Must2(rand.Read(payload))
   216  
   217  		nBytes, err := conn.Write(payload)
   218  		if err != nil {
   219  			return err
   220  		}
   221  		if nBytes != len(payload) {
   222  			return errors.New("expect ", len(payload), " written, but actually ", nBytes)
   223  		}
   224  
   225  		response, err := readFrom2(conn, timeout, payloadSize)
   226  		if err != nil {
   227  			return err
   228  		}
   229  		_ = response
   230  
   231  		if r := bytes.Compare(response, xor(payload)); r != 0 {
   232  			return errors.New(r)
   233  		}
   234  
   235  		return nil
   236  	}
   237  }
   238  
   239  func WaitConnAvailableWithTest(t *testing.T, testFunc func() error) bool {
   240  	for i := 1; ; i++ {
   241  		if i > 10 {
   242  			t.Log("All attempts failed to test tcp conn")
   243  			return false
   244  		}
   245  		time.Sleep(time.Millisecond * 10)
   246  		if err := testFunc(); err != nil {
   247  			t.Log("err ", err)
   248  		} else {
   249  			t.Log("success with", i, "attempts")
   250  			break
   251  		}
   252  	}
   253  	return true
   254  }