k8s.io/kubernetes@v1.31.0-alpha.0.0.20240520171757-56147500dadc/test/images/agnhost/connect/connect.go (about)

     1  /*
     2  Copyright 2019 The Kubernetes Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package connect
    18  
    19  import (
    20  	"fmt"
    21  	"net"
    22  	"os"
    23  	"strings"
    24  	"syscall"
    25  	"time"
    26  
    27  	"github.com/ishidawataru/sctp"
    28  	"github.com/spf13/cobra"
    29  )
    30  
    31  // CmdConnect is used by agnhost Cobra.
    32  var CmdConnect = &cobra.Command{
    33  	Use:   "connect [host:port]",
    34  	Short: "Attempts a TCP, UDP or SCTP connection and returns useful errors",
    35  	Long: `Tries to open a TCP, UDP or SCTP connection to the given host and port. On error it prints an error message prefixed with a specific fixed string that test cases can check for:
    36  
    37  * UNKNOWN - Generic/unknown (non-network) error (eg, bad arguments)
    38  * TIMEOUT - The connection attempt timed out
    39  * DNS - An error in DNS resolution
    40  * REFUSED - Connection refused
    41  * OTHER - Other networking error (eg, "no route to host")`,
    42  	Args: cobra.ExactArgs(1),
    43  	Run:  main,
    44  }
    45  
    46  var (
    47  	timeout  time.Duration
    48  	protocol string
    49  	udpData  string
    50  	sctpData string
    51  )
    52  
    53  func init() {
    54  	CmdConnect.Flags().DurationVar(&timeout, "timeout", time.Duration(0), "Maximum time before returning an error")
    55  	CmdConnect.Flags().StringVar(&protocol, "protocol", "tcp", "The protocol to use to perform the connection, can be tcp, udp or sctp")
    56  	CmdConnect.Flags().StringVar(&udpData, "udp-data", "hostname", "The UDP payload send to the server")
    57  	CmdConnect.Flags().StringVar(&sctpData, "sctp-data", "hostname", "The SCTP payload send to the server")
    58  }
    59  
    60  func main(cmd *cobra.Command, args []string) {
    61  	dest := args[0]
    62  	switch protocol {
    63  	case "", "tcp":
    64  		connectTCP(dest, timeout)
    65  	case "udp":
    66  		connectUDP(dest, timeout, udpData)
    67  	case "sctp":
    68  		connectSCTP(dest, timeout, sctpData)
    69  	default:
    70  		fmt.Fprint(os.Stderr, "Unsupported protocol\n", protocol)
    71  		os.Exit(1)
    72  	}
    73  }
    74  
    75  func connectTCP(dest string, timeout time.Duration) {
    76  	// Redundantly parse and resolve the destination so we can return the correct
    77  	// errors if there's a problem.
    78  	if _, _, err := net.SplitHostPort(dest); err != nil {
    79  		fmt.Fprintf(os.Stderr, "UNKNOWN: %v\n", err)
    80  		os.Exit(1)
    81  	}
    82  	if _, err := net.ResolveTCPAddr("tcp", dest); err != nil {
    83  		fmt.Fprintf(os.Stderr, "DNS: %v\n", err)
    84  		os.Exit(1)
    85  	}
    86  
    87  	conn, err := net.DialTimeout("tcp", dest, timeout)
    88  	if err == nil {
    89  		conn.Close()
    90  		os.Exit(0)
    91  	}
    92  	if opErr, ok := err.(*net.OpError); ok {
    93  		if opErr.Timeout() {
    94  			fmt.Fprintf(os.Stderr, "TIMEOUT\n")
    95  			os.Exit(1)
    96  		} else if syscallErr, ok := opErr.Err.(*os.SyscallError); ok {
    97  			if syscallErr.Err == syscall.ECONNREFUSED {
    98  				fmt.Fprintf(os.Stderr, "REFUSED\n")
    99  				os.Exit(1)
   100  			}
   101  		}
   102  	}
   103  
   104  	fmt.Fprintf(os.Stderr, "OTHER: %v\n", err)
   105  	os.Exit(1)
   106  }
   107  
   108  func connectSCTP(dest string, timeout time.Duration, data string) {
   109  	var (
   110  		buf  = make([]byte, 1024)
   111  		conn *sctp.SCTPConn
   112  	)
   113  	addr, err := sctp.ResolveSCTPAddr("sctp", dest)
   114  	if err != nil {
   115  		fmt.Fprintf(os.Stderr, "DNS: %v\n", err)
   116  		os.Exit(1)
   117  	}
   118  
   119  	timeoutCh := time.After(timeout)
   120  	errCh := make(chan error)
   121  
   122  	go func() {
   123  		conn, err = sctp.DialSCTP("sctp", nil, addr)
   124  		if err != nil {
   125  			errCh <- err
   126  			return
   127  		}
   128  		defer func() {
   129  			errCh <- conn.Close()
   130  		}()
   131  
   132  		if _, err = conn.Write([]byte(fmt.Sprintf("%s\n", data))); err != nil {
   133  			errCh <- err
   134  			return
   135  		}
   136  
   137  		if _, err = conn.Read(buf); err != nil {
   138  			errCh <- err
   139  			return
   140  		}
   141  	}()
   142  
   143  	select {
   144  	case err := <-errCh:
   145  		if err != nil {
   146  			fmt.Fprintf(os.Stderr, "OTHER: %v\n", err)
   147  			os.Exit(1)
   148  		}
   149  	case <-timeoutCh:
   150  		fmt.Fprint(os.Stderr, "TIMEOUT\n")
   151  		os.Exit(1)
   152  	}
   153  }
   154  
   155  func connectUDP(dest string, timeout time.Duration, data string) {
   156  	var (
   157  		readBytes int
   158  		buf       = make([]byte, 1024)
   159  	)
   160  
   161  	if _, err := net.ResolveUDPAddr("udp", dest); err != nil {
   162  		fmt.Fprintf(os.Stderr, "DNS: %v\n", err)
   163  		os.Exit(1)
   164  	}
   165  
   166  	conn, err := net.Dial("udp", dest)
   167  	if err != nil {
   168  		fmt.Fprintf(os.Stderr, "OTHER: %v\n", err)
   169  		os.Exit(1)
   170  	}
   171  
   172  	if timeout > 0 {
   173  		if err = conn.SetDeadline(time.Now().Add(timeout)); err != nil {
   174  			fmt.Fprintf(os.Stderr, "OTHER: %v\n", err)
   175  			os.Exit(1)
   176  		}
   177  	}
   178  
   179  	if _, err = conn.Write([]byte(fmt.Sprintf("%s\n", data))); err != nil {
   180  		parseUDPErrorAndExit(err)
   181  	}
   182  
   183  	if readBytes, err = conn.Read(buf); err != nil {
   184  		parseUDPErrorAndExit(err)
   185  	}
   186  
   187  	// ensure the response from UDP server
   188  	if readBytes == 0 {
   189  		fmt.Fprintf(os.Stderr, "OTHER: No data received from the server. Cannot guarantee the server received the request.\n")
   190  		os.Exit(1)
   191  	}
   192  }
   193  
   194  func parseUDPErrorAndExit(err error) {
   195  	neterr, ok := err.(net.Error)
   196  	if ok && neterr.Timeout() {
   197  		fmt.Fprintf(os.Stderr, "TIMEOUT: %v\n", err)
   198  	} else if strings.Contains(err.Error(), "connection refused") {
   199  		fmt.Fprintf(os.Stderr, "REFUSED: %v\n", err)
   200  	} else {
   201  		fmt.Fprintf(os.Stderr, "UNKNOWN: %v\n", err)
   202  	}
   203  	os.Exit(1)
   204  }