istio.io/istio@v0.0.0-20240520182934-d79c90f27776/pkg/test/echo/server/forwarder/tcp.go (about)

     1  // Copyright Istio Authors
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package forwarder
    16  
    17  import (
    18  	"bufio"
    19  	"bytes"
    20  	"context"
    21  	"fmt"
    22  	"io"
    23  	"net"
    24  	"net/http"
    25  	"strings"
    26  
    27  	proxyproto "github.com/pires/go-proxyproto"
    28  
    29  	"istio.io/istio/pkg/hbone"
    30  	"istio.io/istio/pkg/test/echo"
    31  	"istio.io/istio/pkg/test/echo/common"
    32  	"istio.io/istio/pkg/test/echo/proto"
    33  )
    34  
    35  var _ protocol = &tcpProtocol{}
    36  
    37  type tcpProtocol struct {
    38  	e *executor
    39  }
    40  
    41  func newTCPProtocol(e *executor) protocol {
    42  	return &tcpProtocol{e: e}
    43  }
    44  
    45  func (c *tcpProtocol) ForwardEcho(ctx context.Context, cfg *Config) (*proto.ForwardEchoResponse, error) {
    46  	return doForward(ctx, cfg, c.e, c.makeRequest)
    47  }
    48  
    49  func (c *tcpProtocol) makeRequest(ctx context.Context, cfg *Config, requestID int) (string, error) {
    50  	conn, err := newTCPConnection(cfg)
    51  	if err != nil {
    52  		return "", err
    53  	}
    54  	defer func() { _ = conn.Close() }()
    55  
    56  	msgBuilder := strings.Builder{}
    57  	// If we have been asked to do TCP comms with a PROXY protocol header,
    58  	// determine which version, and send the header.
    59  	// https://www.haproxy.org/download/1.8/doc/proxy-protocol.txt
    60  	//
    61  	// PROXY protocol is only for L4 TCP traffic, and the magic string/bytes MUST
    62  	// be written at the BEGINNING of the TCP connection if communicating with a PROXY-protocol enabled server.
    63  	if cfg.proxyProtocolVersion != 0 {
    64  		fwLog.Infof("TCP forwarder using PROXY protocol version %d", cfg.proxyProtocolVersion)
    65  		header := proxyproto.HeaderProxyFromAddrs(byte(cfg.proxyProtocolVersion), conn.LocalAddr(), conn.RemoteAddr())
    66  		// After the connection is created, write the proxy headers first
    67  		if _, err := header.WriteTo(conn); err != nil {
    68  			fwLog.Warnf("TCP Proxy protocol header write failed: %v", err)
    69  			return msgBuilder.String(), err
    70  		}
    71  	}
    72  
    73  	echo.ForwarderURLField.WriteForRequest(&msgBuilder, requestID, cfg.Request.Url)
    74  
    75  	if cfg.Request.Message != "" {
    76  		echo.ForwarderMessageField.WriteForRequest(&msgBuilder, requestID, cfg.Request.Message)
    77  	}
    78  
    79  	// Apply per-request timeout to calculate deadline for reads/writes.
    80  	ctx, cancel := context.WithTimeout(ctx, cfg.timeout)
    81  	defer cancel()
    82  
    83  	// Apply the deadline to the connection.
    84  	deadline, _ := ctx.Deadline()
    85  	if err := conn.SetWriteDeadline(deadline); err != nil {
    86  		return msgBuilder.String(), err
    87  	}
    88  	if err := conn.SetReadDeadline(deadline); err != nil {
    89  		return msgBuilder.String(), err
    90  	}
    91  
    92  	// For server first protocol, we expect the server to send us the magic string first
    93  	if cfg.Request.ServerFirst {
    94  		readBytes, err := bufio.NewReader(conn).ReadBytes('\n')
    95  		if err != nil {
    96  			fwLog.Warnf("server first TCP read failed: %v", err)
    97  			return "", err
    98  		}
    99  		if string(readBytes) != common.ServerFirstMagicString {
   100  			return "", fmt.Errorf("did not receive magic string. Want %q, got %q", common.ServerFirstMagicString, string(readBytes))
   101  		}
   102  	}
   103  
   104  	// Make sure the client writes something to the buffer
   105  	message := "HelloWorld"
   106  	if cfg.Request.Message != "" {
   107  		message = cfg.Request.Message
   108  	}
   109  
   110  	if _, err := conn.Write([]byte(message + "\n")); err != nil {
   111  		fwLog.Warnf("TCP write failed: %v", err)
   112  		return msgBuilder.String(), err
   113  	}
   114  	var resBuffer bytes.Buffer
   115  	buf := make([]byte, 1024+len(message))
   116  	for {
   117  		n, err := conn.Read(buf)
   118  		if err != nil && err != io.EOF {
   119  			fwLog.Warnf("TCP read failed (already read %d bytes): %v", len(resBuffer.String()), err)
   120  			return msgBuilder.String(), err
   121  		}
   122  		resBuffer.Write(buf[:n])
   123  		// the message is sent last - when we get the whole message we can stop reading
   124  		if err == io.EOF || strings.Contains(resBuffer.String(), message) {
   125  			break
   126  		}
   127  	}
   128  
   129  	// format the output for forwarder response
   130  	for _, line := range strings.Split(resBuffer.String(), "\n") {
   131  		if line != "" {
   132  			echo.WriteBodyLine(&msgBuilder, requestID, line)
   133  		}
   134  	}
   135  
   136  	msg := msgBuilder.String()
   137  	expected := fmt.Sprintf("%s=%d", string(echo.StatusCodeField), http.StatusOK)
   138  	if cfg.Request.ExpectedResponse != nil {
   139  		expected = cfg.Request.ExpectedResponse.GetValue()
   140  	}
   141  	if !strings.Contains(msg, expected) {
   142  		return msg, fmt.Errorf("expect to recv message with %s, got %s. Return EOF", expected, msg)
   143  	}
   144  	return msg, nil
   145  }
   146  
   147  func (c *tcpProtocol) Close() error {
   148  	return nil
   149  }
   150  
   151  func newTCPConnection(cfg *Config) (net.Conn, error) {
   152  	address := cfg.Request.Url[len(cfg.scheme+"://"):]
   153  
   154  	if cfg.secure {
   155  		return hbone.TLSDialWithDialer(newDialer(cfg), "tcp", address, cfg.tlsConfig)
   156  	}
   157  
   158  	ctx, cancel := context.WithTimeout(context.Background(), common.ConnectionTimeout)
   159  	defer cancel()
   160  
   161  	return newDialer(cfg).DialContext(ctx, "tcp", address)
   162  }