golang.org/x/build@v0.0.0-20240506185731-218518f32b70/buildlet/buildletclient_test.go (about) 1 // Copyright 2020 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package buildlet 6 7 import ( 8 "context" 9 "crypto/tls" 10 "encoding/json" 11 "errors" 12 "net" 13 "net/http" 14 "net/http/httptest" 15 "net/url" 16 "strings" 17 "testing" 18 ) 19 20 func TestConnectSSHTLS(t *testing.T) { 21 testCases := []struct { 22 desc string 23 authUser string 24 dialer func(context.Context) (net.Conn, error) 25 key string 26 keyPair KeyPair 27 password string 28 user string 29 wantAuthUser string 30 }{ 31 { 32 desc: "tls-without-authuser", 33 authUser: "", 34 key: "key-foo", 35 keyPair: createKeyPair(t), 36 password: "foo", 37 user: "kate", 38 wantAuthUser: "gomote", 39 }, 40 { 41 desc: "tls-with-authuser", 42 authUser: "george", 43 key: "key-foo", 44 keyPair: createKeyPair(t), 45 password: "foo", 46 user: "kate", 47 wantAuthUser: "george", 48 }, 49 { 50 desc: "tls-with-configured-dialer", 51 authUser: "", 52 dialer: func(_ context.Context) (net.Conn, error) { return nil, errors.New("test error") }, 53 key: "key-foo", 54 keyPair: createKeyPair(t), 55 password: "foo", 56 user: "kate", 57 wantAuthUser: "gomote", 58 }, 59 } 60 for _, tc := range testCases { 61 t.Run(tc.desc, func(t *testing.T) { 62 ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 63 if gotUser := r.Header.Get("X-Go-Ssh-User"); gotUser != tc.user { 64 t.Errorf("r.Header.Get(X-Go-Ssh-User) = %q; want %q", gotUser, tc.user) 65 } 66 if gotKey := r.Header.Get("X-Go-Authorized-Key"); gotKey != tc.key { 67 t.Errorf("r.Header.Get(X-Go-Authorized-Key) = %q; want %q", gotKey, tc.key) 68 } 69 if gotAuthUser, gotAuthPass, gotOk := r.BasicAuth(); !gotOk || gotAuthUser != tc.wantAuthUser || gotAuthPass != tc.password { 70 t.Errorf("Request.BasicAuth() = %q, %q, %t; want %q, %q, true", gotAuthUser, gotAuthPass, gotOk, tc.wantAuthUser, tc.password) 71 } 72 w.WriteHeader(http.StatusSwitchingProtocols) 73 })) 74 cert, err := tls.X509KeyPair([]byte(tc.keyPair.CertPEM), []byte(tc.keyPair.KeyPEM)) 75 if err != nil { 76 t.Fatalf("tls.X509KeyPair([]byte(%q), []byte(%q)) = %v, %q; want no error", tc.keyPair.CertPEM, tc.keyPair.KeyPEM, cert, err) 77 } 78 ts.TLS = &tls.Config{ 79 Certificates: []tls.Certificate{cert}, 80 } 81 ts.StartTLS() 82 defer ts.Close() 83 c := client{ 84 ipPort: strings.TrimPrefix(ts.URL, "https://"), 85 tls: tc.keyPair, 86 password: tc.password, 87 authUser: tc.authUser, 88 dialer: tc.dialer, 89 } 90 gotConn, gotErr := c.ConnectSSH(tc.user, tc.key) 91 if gotErr != nil { 92 t.Fatalf("Client.ConnectSSH(%s, %s) = %v, %v; want no error", tc.user, tc.key, gotConn, gotErr) 93 } 94 }) 95 } 96 } 97 98 func TestConnectSSHNonTLS(t *testing.T) { 99 testCases := []struct { 100 desc string 101 authUser string 102 basicAuth bool 103 dialer func(context.Context) (net.Conn, error) 104 key string 105 password string 106 user string 107 wantErr bool 108 }{ 109 { 110 desc: "non-tls-without-authuser", 111 authUser: "gomote", 112 basicAuth: false, 113 key: "key-foo", 114 password: "foo", 115 user: "kate", 116 wantErr: false, 117 }, 118 { 119 desc: "non-tls--with-authuser", 120 authUser: "gomote", 121 basicAuth: true, 122 key: "key-foo", 123 password: "foo", 124 user: "kate", 125 wantErr: false, 126 }, 127 { 128 desc: "non-tls-with-configured-dialer", 129 authUser: "gomote", 130 basicAuth: true, 131 dialer: func(context.Context) (net.Conn, error) { 132 return nil, errors.New("test error") 133 }, 134 key: "key-foo", 135 password: "foo", 136 user: "kate", 137 wantErr: true, 138 }, 139 } 140 for _, tc := range testCases { 141 t.Run(tc.desc, func(t *testing.T) { 142 ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 143 if gotUser := r.Header.Get("X-Go-Ssh-User"); gotUser != tc.user { 144 t.Errorf("r.Header.Get(X-Go-Ssh-User) = %q; want %q", gotUser, tc.user) 145 } 146 if gotKey := r.Header.Get("X-Go-Authorized-Key"); gotKey != tc.key { 147 t.Errorf("r.Header.Get(X-Go-Authorized-Key) = %q; want %q", gotKey, tc.key) 148 } 149 if gotAuthUser, gotAuthPass, gotOk := r.BasicAuth(); gotOk || gotAuthUser != "" || gotAuthPass != "" { 150 t.Errorf("Request.BasicAuth() = %q, %q, %t; want %q, %q, %t", gotAuthUser, gotAuthPass, gotOk, tc.user, tc.password, tc.basicAuth) 151 } 152 w.WriteHeader(http.StatusSwitchingProtocols) 153 })) 154 defer ts.Close() 155 c := client{ 156 ipPort: strings.TrimPrefix(ts.URL, "http://"), 157 password: tc.password, 158 authUser: tc.authUser, 159 dialer: tc.dialer, 160 } 161 gotConn, gotErr := c.ConnectSSH(tc.user, tc.key) 162 if (gotErr != nil) != tc.wantErr { 163 t.Fatalf("Client.ConnectSSH(%q, %q) = %v, %v; want net.Conn, error=%t", tc.user, tc.key, gotConn, gotErr, tc.wantErr) 164 } 165 }) 166 } 167 } 168 169 func createKeyPair(t *testing.T) KeyPair { 170 kp, err := NewKeyPair() 171 if err != nil { 172 t.Fatalf("NewKeyPair() = %v, %s; want no error", kp, err) 173 } 174 return kp 175 } 176 177 // Test that Exec returns ErrTimeout upon reaching the context timeout 178 // during command execution, as its documentation promises. 179 func TestExecTimeoutError(t *testing.T) { 180 mux := http.NewServeMux() 181 mux.HandleFunc("/status", func(w http.ResponseWriter, req *http.Request) { 182 json.NewEncoder(w).Encode(Status{}) 183 }) 184 mux.HandleFunc("/exec", func(w http.ResponseWriter, req *http.Request) { 185 w.Write([]byte(".")) 186 w.(http.Flusher).Flush() // /exec needs to flush headers right away. 187 <-req.Context().Done() // Simulate that execution hangs, so no more output. 188 }) 189 ts := httptest.NewServer(mux) 190 defer ts.Close() 191 u, err := url.Parse(ts.URL) 192 if err != nil { 193 t.Fatalf("unable to parse http server url %s", err) 194 } 195 cl := NewClient(u.Host, NoKeyPair) 196 defer cl.Close() 197 198 // Use a custom context that reports context.DeadlineExceeded 199 // after Exec starts command execution. (context.WithTimeout 200 // requires us to select an arbitrary duration, which might 201 // not be long enough or will make the test take too long.) 202 ctx := deadlineOnDemandContext{ 203 Context: context.Background(), 204 done: make(chan struct{}), 205 } 206 _, execErr := cl.Exec(ctx, "./bin/test", ExecOpts{ 207 OnStartExec: func() { close(ctx.done) }, 208 }) 209 if execErr != ErrTimeout { 210 t.Errorf("cl.Exec error = %v; want %v", execErr, ErrTimeout) 211 } 212 } 213 214 type deadlineOnDemandContext struct { 215 context.Context 216 done chan struct{} 217 } 218 219 func (c deadlineOnDemandContext) Done() <-chan struct{} { return c.done } 220 func (c deadlineOnDemandContext) Err() error { 221 select { 222 default: 223 return nil 224 case <-c.done: 225 return context.DeadlineExceeded 226 } 227 }