github.com/EagleQL/Xray-core@v1.4.3/testing/servers/tcp/tcp.go (about)

     1  package tcp
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"io"
     7  
     8  	"github.com/xtls/xray-core/common/buf"
     9  	"github.com/xtls/xray-core/common/net"
    10  	"github.com/xtls/xray-core/common/task"
    11  	"github.com/xtls/xray-core/transport/internet"
    12  	"github.com/xtls/xray-core/transport/pipe"
    13  )
    14  
// Server is a TCP server for tests. Each connection it accepts first receives
// SendFirst (when non-empty), then has every chunk it sends passed through
// MsgProcessor and written back to it.
type Server struct {
	// Port is the port to listen on. StartContext overwrites it with the port
	// the listener actually bound.
	Port         net.Port
	// MsgProcessor transforms each chunk read from a client before the result
	// is written back to that client.
	MsgProcessor func(msg []byte) []byte
	// ShouldClose is not read by any code in this file.
	// NOTE(review): confirm whether other code in the package relies on it.
	ShouldClose  bool
	// SendFirst, when non-empty, is written to each accepted connection before
	// any data is read from it.
	SendFirst    []byte
	// Listen is the address to bind. StartContext uses net.LocalHostIP when it
	// is nil.
	Listen       net.Address
	// listener is the listener opened by StartContext and closed by Close.
	listener     net.Listener
}
    23  
    24  func (server *Server) Start() (net.Destination, error) {
    25  	return server.StartContext(context.Background(), nil)
    26  }
    27  
    28  func (server *Server) StartContext(ctx context.Context, sockopt *internet.SocketConfig) (net.Destination, error) {
    29  	listenerAddr := server.Listen
    30  	if listenerAddr == nil {
    31  		listenerAddr = net.LocalHostIP
    32  	}
    33  	listener, err := internet.ListenSystem(ctx, &net.TCPAddr{
    34  		IP:   listenerAddr.IP(),
    35  		Port: int(server.Port),
    36  	}, sockopt)
    37  	if err != nil {
    38  		return net.Destination{}, err
    39  	}
    40  
    41  	localAddr := listener.Addr().(*net.TCPAddr)
    42  	server.Port = net.Port(localAddr.Port)
    43  	server.listener = listener
    44  	go server.acceptConnections(listener.(*net.TCPListener))
    45  
    46  	return net.TCPDestination(net.IPAddress(localAddr.IP), net.Port(localAddr.Port)), nil
    47  }
    48  
    49  func (server *Server) acceptConnections(listener *net.TCPListener) {
    50  	for {
    51  		conn, err := listener.Accept()
    52  		if err != nil {
    53  			fmt.Printf("Failed accept TCP connection: %v\n", err)
    54  			return
    55  		}
    56  
    57  		go server.handleConnection(conn)
    58  	}
    59  }
    60  
    61  func (server *Server) handleConnection(conn net.Conn) {
    62  	if len(server.SendFirst) > 0 {
    63  		conn.Write(server.SendFirst)
    64  	}
    65  
    66  	pReader, pWriter := pipe.New(pipe.WithoutSizeLimit())
    67  	err := task.Run(context.Background(), func() error {
    68  		defer pWriter.Close()
    69  
    70  		for {
    71  			b := buf.New()
    72  			if _, err := b.ReadFrom(conn); err != nil {
    73  				if err == io.EOF {
    74  					return nil
    75  				}
    76  				return err
    77  			}
    78  			copy(b.Bytes(), server.MsgProcessor(b.Bytes()))
    79  			if err := pWriter.WriteMultiBuffer(buf.MultiBuffer{b}); err != nil {
    80  				return err
    81  			}
    82  		}
    83  	}, func() error {
    84  		defer pReader.Interrupt()
    85  
    86  		w := buf.NewWriter(conn)
    87  		for {
    88  			mb, err := pReader.ReadMultiBuffer()
    89  			if err != nil {
    90  				if err == io.EOF {
    91  					return nil
    92  				}
    93  				return err
    94  			}
    95  			if err := w.WriteMultiBuffer(mb); err != nil {
    96  				return err
    97  			}
    98  		}
    99  	})
   100  
   101  	if err != nil {
   102  		fmt.Println("failed to transfer data: ", err.Error())
   103  	}
   104  
   105  	conn.Close()
   106  }
   107  
   108  func (server *Server) Close() error {
   109  	return server.listener.Close()
   110  }