github.com/eagleql/xray-core@v1.4.4/testing/scenarios/common.go (about)

     1  package scenarios
     2  
     3  import (
     4  	"bytes"
     5  	"crypto/rand"
     6  	"fmt"
     7  	"io"
     8  	"io/ioutil"
     9  	"os/exec"
    10  	"path/filepath"
    11  	"runtime"
    12  	"sync"
    13  	"syscall"
    14  	"time"
    15  
    16  	"github.com/eagleql/xray-core/app/dispatcher"
    17  	"github.com/eagleql/xray-core/app/proxyman"
    18  	"github.com/eagleql/xray-core/common"
    19  	"github.com/eagleql/xray-core/common/errors"
    20  	"github.com/eagleql/xray-core/common/log"
    21  	"github.com/eagleql/xray-core/common/net"
    22  	"github.com/eagleql/xray-core/common/retry"
    23  	"github.com/eagleql/xray-core/common/serial"
    24  	core "github.com/eagleql/xray-core/core"
    25  	"github.com/golang/protobuf/proto"
    26  )
    27  
    28  func xor(b []byte) []byte {
    29  	r := make([]byte, len(b))
    30  	for i, v := range b {
    31  		r[i] = v ^ 'c'
    32  	}
    33  	return r
    34  }
    35  
    36  func readFrom(conn net.Conn, timeout time.Duration, length int) []byte {
    37  	b := make([]byte, length)
    38  	deadline := time.Now().Add(timeout)
    39  	conn.SetReadDeadline(deadline)
    40  	n, err := io.ReadFull(conn, b[:length])
    41  	if err != nil {
    42  		fmt.Println("Unexpected error from readFrom:", err)
    43  	}
    44  	return b[:n]
    45  }
    46  
    47  func readFrom2(conn net.Conn, timeout time.Duration, length int) ([]byte, error) {
    48  	b := make([]byte, length)
    49  	deadline := time.Now().Add(timeout)
    50  	conn.SetReadDeadline(deadline)
    51  	n, err := io.ReadFull(conn, b[:length])
    52  	if err != nil {
    53  		return nil, err
    54  	}
    55  	return b[:n], nil
    56  }
    57  
    58  func InitializeServerConfigs(configs ...*core.Config) ([]*exec.Cmd, error) {
    59  	servers := make([]*exec.Cmd, 0, 10)
    60  
    61  	for _, config := range configs {
    62  		server, err := InitializeServerConfig(config)
    63  		if err != nil {
    64  			CloseAllServers(servers)
    65  			return nil, err
    66  		}
    67  		servers = append(servers, server)
    68  	}
    69  
    70  	time.Sleep(time.Second * 2)
    71  
    72  	return servers, nil
    73  }
    74  
    75  func InitializeServerConfig(config *core.Config) (*exec.Cmd, error) {
    76  	err := BuildXray()
    77  	if err != nil {
    78  		return nil, err
    79  	}
    80  
    81  	config = withDefaultApps(config)
    82  	configBytes, err := proto.Marshal(config)
    83  	if err != nil {
    84  		return nil, err
    85  	}
    86  	proc := RunXrayProtobuf(configBytes)
    87  
    88  	if err := proc.Start(); err != nil {
    89  		return nil, err
    90  	}
    91  
    92  	return proc, nil
    93  }
    94  
    95  var (
    96  	testBinaryPath    string
    97  	testBinaryPathGen sync.Once
    98  )
    99  
   100  func genTestBinaryPath() {
   101  	testBinaryPathGen.Do(func() {
   102  		var tempDir string
   103  		common.Must(retry.Timed(5, 100).On(func() error {
   104  			dir, err := ioutil.TempDir("", "xray")
   105  			if err != nil {
   106  				return err
   107  			}
   108  			tempDir = dir
   109  			return nil
   110  		}))
   111  		file := filepath.Join(tempDir, "xray.test")
   112  		if runtime.GOOS == "windows" {
   113  			file += ".exe"
   114  		}
   115  		testBinaryPath = file
   116  		fmt.Printf("Generated binary path: %s\n", file)
   117  	})
   118  }
   119  
   120  func GetSourcePath() string {
   121  	return filepath.Join("github.com", "xtls", "xray-core", "main")
   122  }
   123  
   124  func CloseAllServers(servers []*exec.Cmd) {
   125  	log.Record(&log.GeneralMessage{
   126  		Severity: log.Severity_Info,
   127  		Content:  "Closing all servers.",
   128  	})
   129  	for _, server := range servers {
   130  		if runtime.GOOS == "windows" {
   131  			server.Process.Kill()
   132  		} else {
   133  			server.Process.Signal(syscall.SIGTERM)
   134  		}
   135  	}
   136  	for _, server := range servers {
   137  		server.Process.Wait()
   138  	}
   139  	log.Record(&log.GeneralMessage{
   140  		Severity: log.Severity_Info,
   141  		Content:  "All server closed.",
   142  	})
   143  }
   144  
   145  func withDefaultApps(config *core.Config) *core.Config {
   146  	config.App = append(config.App, serial.ToTypedMessage(&dispatcher.Config{}))
   147  	config.App = append(config.App, serial.ToTypedMessage(&proxyman.InboundConfig{}))
   148  	config.App = append(config.App, serial.ToTypedMessage(&proxyman.OutboundConfig{}))
   149  	return config
   150  }
   151  
   152  func testTCPConn(port net.Port, payloadSize int, timeout time.Duration) func() error {
   153  	return func() error {
   154  		conn, err := net.DialTCP("tcp", nil, &net.TCPAddr{
   155  			IP:   []byte{127, 0, 0, 1},
   156  			Port: int(port),
   157  		})
   158  		if err != nil {
   159  			return err
   160  		}
   161  		defer conn.Close()
   162  
   163  		return testTCPConn2(conn, payloadSize, timeout)()
   164  	}
   165  }
   166  
   167  func testUDPConn(port net.Port, payloadSize int, timeout time.Duration) func() error {
   168  	return func() error {
   169  		conn, err := net.DialUDP("udp", nil, &net.UDPAddr{
   170  			IP:   []byte{127, 0, 0, 1},
   171  			Port: int(port),
   172  		})
   173  		if err != nil {
   174  			return err
   175  		}
   176  		defer conn.Close()
   177  
   178  		return testTCPConn2(conn, payloadSize, timeout)()
   179  	}
   180  }
   181  
   182  func testTCPConn2(conn net.Conn, payloadSize int, timeout time.Duration) func() error {
   183  	return func() error {
   184  		payload := make([]byte, payloadSize)
   185  		common.Must2(rand.Read(payload))
   186  
   187  		nBytes, err := conn.Write(payload)
   188  		if err != nil {
   189  			return err
   190  		}
   191  		if nBytes != len(payload) {
   192  			return errors.New("expect ", len(payload), " written, but actually ", nBytes)
   193  		}
   194  
   195  		response, err := readFrom2(conn, timeout, payloadSize)
   196  		if err != nil {
   197  			return err
   198  		}
   199  		_ = response
   200  
   201  		if r := bytes.Compare(response, xor(payload)); r != 0 {
   202  			return errors.New(r)
   203  		}
   204  
   205  		return nil
   206  	}
   207  }