golang.org/x/build@v0.0.0-20240506185731-218518f32b70/buildlet/buildletclient_test.go (about)

     1  // Copyright 2020 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package buildlet
     6  
     7  import (
     8  	"context"
     9  	"crypto/tls"
    10  	"encoding/json"
    11  	"errors"
    12  	"net"
    13  	"net/http"
    14  	"net/http/httptest"
    15  	"net/url"
    16  	"strings"
    17  	"testing"
    18  )
    19  
    20  func TestConnectSSHTLS(t *testing.T) {
    21  	testCases := []struct {
    22  		desc         string
    23  		authUser     string
    24  		dialer       func(context.Context) (net.Conn, error)
    25  		key          string
    26  		keyPair      KeyPair
    27  		password     string
    28  		user         string
    29  		wantAuthUser string
    30  	}{
    31  		{
    32  			desc:         "tls-without-authuser",
    33  			authUser:     "",
    34  			key:          "key-foo",
    35  			keyPair:      createKeyPair(t),
    36  			password:     "foo",
    37  			user:         "kate",
    38  			wantAuthUser: "gomote",
    39  		},
    40  		{
    41  			desc:         "tls-with-authuser",
    42  			authUser:     "george",
    43  			key:          "key-foo",
    44  			keyPair:      createKeyPair(t),
    45  			password:     "foo",
    46  			user:         "kate",
    47  			wantAuthUser: "george",
    48  		},
    49  		{
    50  			desc:         "tls-with-configured-dialer",
    51  			authUser:     "",
    52  			dialer:       func(_ context.Context) (net.Conn, error) { return nil, errors.New("test error") },
    53  			key:          "key-foo",
    54  			keyPair:      createKeyPair(t),
    55  			password:     "foo",
    56  			user:         "kate",
    57  			wantAuthUser: "gomote",
    58  		},
    59  	}
    60  	for _, tc := range testCases {
    61  		t.Run(tc.desc, func(t *testing.T) {
    62  			ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    63  				if gotUser := r.Header.Get("X-Go-Ssh-User"); gotUser != tc.user {
    64  					t.Errorf("r.Header.Get(X-Go-Ssh-User) = %q; want %q", gotUser, tc.user)
    65  				}
    66  				if gotKey := r.Header.Get("X-Go-Authorized-Key"); gotKey != tc.key {
    67  					t.Errorf("r.Header.Get(X-Go-Authorized-Key) = %q; want %q", gotKey, tc.key)
    68  				}
    69  				if gotAuthUser, gotAuthPass, gotOk := r.BasicAuth(); !gotOk || gotAuthUser != tc.wantAuthUser || gotAuthPass != tc.password {
    70  					t.Errorf("Request.BasicAuth() = %q, %q, %t; want %q, %q, true", gotAuthUser, gotAuthPass, gotOk, tc.wantAuthUser, tc.password)
    71  				}
    72  				w.WriteHeader(http.StatusSwitchingProtocols)
    73  			}))
    74  			cert, err := tls.X509KeyPair([]byte(tc.keyPair.CertPEM), []byte(tc.keyPair.KeyPEM))
    75  			if err != nil {
    76  				t.Fatalf("tls.X509KeyPair([]byte(%q), []byte(%q)) = %v, %q; want no error", tc.keyPair.CertPEM, tc.keyPair.KeyPEM, cert, err)
    77  			}
    78  			ts.TLS = &tls.Config{
    79  				Certificates: []tls.Certificate{cert},
    80  			}
    81  			ts.StartTLS()
    82  			defer ts.Close()
    83  			c := client{
    84  				ipPort:   strings.TrimPrefix(ts.URL, "https://"),
    85  				tls:      tc.keyPair,
    86  				password: tc.password,
    87  				authUser: tc.authUser,
    88  				dialer:   tc.dialer,
    89  			}
    90  			gotConn, gotErr := c.ConnectSSH(tc.user, tc.key)
    91  			if gotErr != nil {
    92  				t.Fatalf("Client.ConnectSSH(%s, %s) = %v, %v; want no error", tc.user, tc.key, gotConn, gotErr)
    93  			}
    94  		})
    95  	}
    96  }
    97  
    98  func TestConnectSSHNonTLS(t *testing.T) {
    99  	testCases := []struct {
   100  		desc      string
   101  		authUser  string
   102  		basicAuth bool
   103  		dialer    func(context.Context) (net.Conn, error)
   104  		key       string
   105  		password  string
   106  		user      string
   107  		wantErr   bool
   108  	}{
   109  		{
   110  			desc:      "non-tls-without-authuser",
   111  			authUser:  "gomote",
   112  			basicAuth: false,
   113  			key:       "key-foo",
   114  			password:  "foo",
   115  			user:      "kate",
   116  			wantErr:   false,
   117  		},
   118  		{
   119  			desc:      "non-tls--with-authuser",
   120  			authUser:  "gomote",
   121  			basicAuth: true,
   122  			key:       "key-foo",
   123  			password:  "foo",
   124  			user:      "kate",
   125  			wantErr:   false,
   126  		},
   127  		{
   128  			desc:      "non-tls-with-configured-dialer",
   129  			authUser:  "gomote",
   130  			basicAuth: true,
   131  			dialer: func(context.Context) (net.Conn, error) {
   132  				return nil, errors.New("test error")
   133  			},
   134  			key:      "key-foo",
   135  			password: "foo",
   136  			user:     "kate",
   137  			wantErr:  true,
   138  		},
   139  	}
   140  	for _, tc := range testCases {
   141  		t.Run(tc.desc, func(t *testing.T) {
   142  			ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   143  				if gotUser := r.Header.Get("X-Go-Ssh-User"); gotUser != tc.user {
   144  					t.Errorf("r.Header.Get(X-Go-Ssh-User) = %q; want %q", gotUser, tc.user)
   145  				}
   146  				if gotKey := r.Header.Get("X-Go-Authorized-Key"); gotKey != tc.key {
   147  					t.Errorf("r.Header.Get(X-Go-Authorized-Key) = %q; want %q", gotKey, tc.key)
   148  				}
   149  				if gotAuthUser, gotAuthPass, gotOk := r.BasicAuth(); gotOk || gotAuthUser != "" || gotAuthPass != "" {
   150  					t.Errorf("Request.BasicAuth() = %q, %q, %t; want %q, %q, %t", gotAuthUser, gotAuthPass, gotOk, tc.user, tc.password, tc.basicAuth)
   151  				}
   152  				w.WriteHeader(http.StatusSwitchingProtocols)
   153  			}))
   154  			defer ts.Close()
   155  			c := client{
   156  				ipPort:   strings.TrimPrefix(ts.URL, "http://"),
   157  				password: tc.password,
   158  				authUser: tc.authUser,
   159  				dialer:   tc.dialer,
   160  			}
   161  			gotConn, gotErr := c.ConnectSSH(tc.user, tc.key)
   162  			if (gotErr != nil) != tc.wantErr {
   163  				t.Fatalf("Client.ConnectSSH(%q, %q) = %v, %v; want net.Conn, error=%t", tc.user, tc.key, gotConn, gotErr, tc.wantErr)
   164  			}
   165  		})
   166  	}
   167  }
   168  
   169  func createKeyPair(t *testing.T) KeyPair {
   170  	kp, err := NewKeyPair()
   171  	if err != nil {
   172  		t.Fatalf("NewKeyPair() = %v, %s; want no error", kp, err)
   173  	}
   174  	return kp
   175  }
   176  
   177  // Test that Exec returns ErrTimeout upon reaching the context timeout
   178  // during command execution, as its documentation promises.
   179  func TestExecTimeoutError(t *testing.T) {
   180  	mux := http.NewServeMux()
   181  	mux.HandleFunc("/status", func(w http.ResponseWriter, req *http.Request) {
   182  		json.NewEncoder(w).Encode(Status{})
   183  	})
   184  	mux.HandleFunc("/exec", func(w http.ResponseWriter, req *http.Request) {
   185  		w.Write([]byte("."))
   186  		w.(http.Flusher).Flush() // /exec needs to flush headers right away.
   187  		<-req.Context().Done()   // Simulate that execution hangs, so no more output.
   188  	})
   189  	ts := httptest.NewServer(mux)
   190  	defer ts.Close()
   191  	u, err := url.Parse(ts.URL)
   192  	if err != nil {
   193  		t.Fatalf("unable to parse http server url %s", err)
   194  	}
   195  	cl := NewClient(u.Host, NoKeyPair)
   196  	defer cl.Close()
   197  
   198  	// Use a custom context that reports context.DeadlineExceeded
   199  	// after Exec starts command execution. (context.WithTimeout
   200  	// requires us to select an arbitrary duration, which might
   201  	// not be long enough or will make the test take too long.)
   202  	ctx := deadlineOnDemandContext{
   203  		Context: context.Background(),
   204  		done:    make(chan struct{}),
   205  	}
   206  	_, execErr := cl.Exec(ctx, "./bin/test", ExecOpts{
   207  		OnStartExec: func() { close(ctx.done) },
   208  	})
   209  	if execErr != ErrTimeout {
   210  		t.Errorf("cl.Exec error = %v; want %v", execErr, ErrTimeout)
   211  	}
   212  }
   213  
   214  type deadlineOnDemandContext struct {
   215  	context.Context
   216  	done chan struct{}
   217  }
   218  
   219  func (c deadlineOnDemandContext) Done() <-chan struct{} { return c.done }
   220  func (c deadlineOnDemandContext) Err() error {
   221  	select {
   222  	default:
   223  		return nil
   224  	case <-c.done:
   225  		return context.DeadlineExceeded
   226  	}
   227  }