github.com/v2fly/v2ray-core/v4@v4.45.2/testing/scenarios/common.go (about)

     1  package scenarios
     2  
     3  import (
     4  	"bytes"
     5  	"crypto/rand"
     6  	"fmt"
     7  	"io"
     8  	"os"
     9  	"os/exec"
    10  	"path/filepath"
    11  	"runtime"
    12  	"sync"
    13  	"syscall"
    14  	"time"
    15  
    16  	"github.com/golang/protobuf/proto"
    17  
    18  	core "github.com/v2fly/v2ray-core/v4"
    19  	"github.com/v2fly/v2ray-core/v4/app/dispatcher"
    20  	"github.com/v2fly/v2ray-core/v4/app/proxyman"
    21  	"github.com/v2fly/v2ray-core/v4/common"
    22  	"github.com/v2fly/v2ray-core/v4/common/errors"
    23  	"github.com/v2fly/v2ray-core/v4/common/log"
    24  	"github.com/v2fly/v2ray-core/v4/common/net"
    25  	"github.com/v2fly/v2ray-core/v4/common/retry"
    26  	"github.com/v2fly/v2ray-core/v4/common/serial"
    27  )
    28  
    29  func xor(b []byte) []byte {
    30  	r := make([]byte, len(b))
    31  	for i, v := range b {
    32  		r[i] = v ^ 'c'
    33  	}
    34  	return r
    35  }
    36  
    37  func readFrom(conn net.Conn, timeout time.Duration, length int) []byte {
    38  	b := make([]byte, length)
    39  	deadline := time.Now().Add(timeout)
    40  	conn.SetReadDeadline(deadline)
    41  	n, err := io.ReadFull(conn, b[:length])
    42  	if err != nil {
    43  		fmt.Println("Unexpected error from readFrom:", err)
    44  	}
    45  	return b[:n]
    46  }
    47  
    48  func readFrom2(conn net.Conn, timeout time.Duration, length int) ([]byte, error) {
    49  	b := make([]byte, length)
    50  	deadline := time.Now().Add(timeout)
    51  	conn.SetReadDeadline(deadline)
    52  	n, err := io.ReadFull(conn, b[:length])
    53  	if err != nil {
    54  		return nil, err
    55  	}
    56  	return b[:n], nil
    57  }
    58  
    59  func InitializeServerConfigs(configs ...*core.Config) ([]*exec.Cmd, error) {
    60  	servers := make([]*exec.Cmd, 0, 10)
    61  
    62  	for _, config := range configs {
    63  		server, err := InitializeServerConfig(config)
    64  		if err != nil {
    65  			CloseAllServers(servers)
    66  			return nil, err
    67  		}
    68  		servers = append(servers, server)
    69  	}
    70  
    71  	time.Sleep(time.Second * 2)
    72  
    73  	return servers, nil
    74  }
    75  
    76  func InitializeServerConfig(config *core.Config) (*exec.Cmd, error) {
    77  	err := BuildV2Ray()
    78  	if err != nil {
    79  		return nil, err
    80  	}
    81  
    82  	config = withDefaultApps(config)
    83  	configBytes, err := proto.Marshal(config)
    84  	if err != nil {
    85  		return nil, err
    86  	}
    87  	proc := RunV2RayProtobuf(configBytes)
    88  
    89  	if err := proc.Start(); err != nil {
    90  		return nil, err
    91  	}
    92  
    93  	return proc, nil
    94  }
    95  
    96  var (
    97  	testBinaryPath    string
    98  	testBinaryPathGen sync.Once
    99  )
   100  
   101  func genTestBinaryPath() {
   102  	testBinaryPathGen.Do(func() {
   103  		var tempDir string
   104  		common.Must(retry.Timed(5, 100).On(func() error {
   105  			dir, err := os.MkdirTemp("", "v2ray")
   106  			if err != nil {
   107  				return err
   108  			}
   109  			tempDir = dir
   110  			return nil
   111  		}))
   112  		file := filepath.Join(tempDir, "v2ray.test")
   113  		if runtime.GOOS == "windows" {
   114  			file += ".exe"
   115  		}
   116  		testBinaryPath = file
   117  		fmt.Printf("Generated binary path: %s\n", file)
   118  	})
   119  }
   120  
   121  func GetSourcePath() string {
   122  	return filepath.Join("github.com", "v2fly", "v2ray-core", "v4", "main")
   123  }
   124  
   125  func CloseAllServers(servers []*exec.Cmd) {
   126  	log.Record(&log.GeneralMessage{
   127  		Severity: log.Severity_Info,
   128  		Content:  "Closing all servers.",
   129  	})
   130  	for _, server := range servers {
   131  		if runtime.GOOS == "windows" {
   132  			server.Process.Kill()
   133  		} else {
   134  			server.Process.Signal(syscall.SIGTERM)
   135  		}
   136  	}
   137  	for _, server := range servers {
   138  		server.Process.Wait()
   139  	}
   140  	log.Record(&log.GeneralMessage{
   141  		Severity: log.Severity_Info,
   142  		Content:  "All server closed.",
   143  	})
   144  }
   145  
   146  func CloseServer(server *exec.Cmd) {
   147  	log.Record(&log.GeneralMessage{
   148  		Severity: log.Severity_Info,
   149  		Content:  "Closing server.",
   150  	})
   151  	if runtime.GOOS == "windows" {
   152  		server.Process.Kill()
   153  	} else {
   154  		server.Process.Signal(syscall.SIGTERM)
   155  	}
   156  	server.Process.Wait()
   157  	log.Record(&log.GeneralMessage{
   158  		Severity: log.Severity_Info,
   159  		Content:  "Server closed.",
   160  	})
   161  }
   162  
   163  func withDefaultApps(config *core.Config) *core.Config {
   164  	config.App = append(config.App, serial.ToTypedMessage(&dispatcher.Config{}))
   165  	config.App = append(config.App, serial.ToTypedMessage(&proxyman.InboundConfig{}))
   166  	config.App = append(config.App, serial.ToTypedMessage(&proxyman.OutboundConfig{}))
   167  	return config
   168  }
   169  
   170  func testTCPConn(port net.Port, payloadSize int, timeout time.Duration) func() error {
   171  	return func() error {
   172  		conn, err := net.DialTCP("tcp", nil, &net.TCPAddr{
   173  			IP:   []byte{127, 0, 0, 1},
   174  			Port: int(port),
   175  		})
   176  		if err != nil {
   177  			return err
   178  		}
   179  		defer conn.Close()
   180  
   181  		return testTCPConn2(conn, payloadSize, timeout)()
   182  	}
   183  }
   184  
   185  func testUDPConn(port net.Port, payloadSize int, timeout time.Duration) func() error { // nolint: unparam
   186  	return func() error {
   187  		conn, err := net.DialUDP("udp", nil, &net.UDPAddr{
   188  			IP:   []byte{127, 0, 0, 1},
   189  			Port: int(port),
   190  		})
   191  		if err != nil {
   192  			return err
   193  		}
   194  		defer conn.Close()
   195  
   196  		return testTCPConn2(conn, payloadSize, timeout)()
   197  	}
   198  }
   199  
   200  func testTCPConn2(conn net.Conn, payloadSize int, timeout time.Duration) func() error {
   201  	return func() error {
   202  		payload := make([]byte, payloadSize)
   203  		common.Must2(rand.Read(payload))
   204  
   205  		nBytes, err := conn.Write(payload)
   206  		if err != nil {
   207  			return err
   208  		}
   209  		if nBytes != len(payload) {
   210  			return errors.New("expect ", len(payload), " written, but actually ", nBytes)
   211  		}
   212  
   213  		response, err := readFrom2(conn, timeout, payloadSize)
   214  		if err != nil {
   215  			return err
   216  		}
   217  		_ = response
   218  
   219  		if r := bytes.Compare(response, xor(payload)); r != 0 {
   220  			return errors.New(r)
   221  		}
   222  
   223  		return nil
   224  	}
   225  }