// Copyright 2018 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Package driver launches a WebDriver driver endpoint binary (e.g. ChromeDriver, SafariDriver, etc.)
// based on a google:wslConfig capability.
package driver

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"log"
	"net"
	"net/http"
	"os"
	"os/exec"
	"strconv"
	"sync"
	"syscall"
	"time"

	"github.com/bazelbuild/rules_webtesting/go/cmdhelper"
	"github.com/bazelbuild/rules_webtesting/go/httphelper"
	"github.com/bazelbuild/rules_webtesting/go/metadata/capabilities"
	"github.com/bazelbuild/rules_webtesting/go/webdriver"
)

// compName names this component. NOTE(review): it is not referenced in the
// code visible here; presumably used elsewhere in the package for logging or
// error reporting — confirm.
const compName = "WSL Driver"

// Driver is a wrapper around a running WebDriver endpoint binary.
type Driver struct {
	// Address is the base URL of the endpoint, of the form "http://host:port".
	Address string
	// caps holds the parsed google:wslConfig capability.
	caps *wslCaps
	// stopped receives the launched binary's exit result (from exec.Cmd.Wait).
	// Shutdown closes it instead when no binary was launched.
	stopped chan error
	// cmd is the launched binary; nil when google:wslConfig has no "binary".
	cmd *exec.Cmd

	// Mutex to prevent overlapping commands to remote end.
	mu sync.Mutex
}

// wslCaps is the parsed form of the google:wslConfig capability.
type wslCaps struct {
	// binary is the executable to launch; empty means nothing is launched and
	// New only probes the endpoint on port.
	binary string
	// args are passed to binary; only allowed when binary is set.
	args []string
	// port is the endpoint's port; required (must be non-zero).
	port int
	// timeout bounds how long New waits for the endpoint to become ready.
	// Defaults to 1s; given as a number of seconds or a Go duration string.
	timeout time.Duration
	// env holds extra environment variables for binary; only allowed when
	// binary is set.
	env map[string]string
	// shutdown (default true) makes Shutdown request GET /shutdown instead of
	// signaling the process.
	shutdown bool
	// status (default true) makes New check the /status response body for
	// readiness; when false, any successful request to /status is enough.
	status bool
	// stdout is a file path for the binary's stdout; empty inherits os.Stdout.
	stdout string
	// stderr is a file path for the binary's stderr; empty inherits os.Stderr.
	// When equal to stdout, both streams share one file.
	stderr string
}

// New creates and starts a WebDriver endpoint binary based on caps, then polls its
// /status endpoint until it reports ready or the configured timeout elapses.
// Argument caps should just be the google:wslConfig capability extracted from the
// capabilities passed into a new session request.
67 func New(ctx context.Context, localHost, sessionID string, caps map[string]interface{}) (*Driver, error) { 68 wslCaps, err := extractWSLCaps(sessionID, caps) 69 if err != nil { 70 return nil, err 71 } 72 hostPort := net.JoinHostPort(localHost, strconv.Itoa(wslCaps.port)) 73 74 d := &Driver{ 75 Address: fmt.Sprintf("http://%s", hostPort), 76 caps: wslCaps, 77 stopped: make(chan error), 78 } 79 80 errChan, err := d.startDriver() 81 if err != nil { 82 return nil, err 83 } 84 85 deadline, cancel := context.WithTimeout(ctx, d.caps.timeout) 86 defer cancel() 87 88 statusURL := fmt.Sprintf("http://%s/status", hostPort) 89 90 for { 91 select { 92 case err := <-errChan: 93 return nil, err 94 default: 95 } 96 97 if response, err := httphelper.Get(deadline, statusURL); err == nil { 98 if !d.caps.status { 99 // just wait for successful connection because status endpoint doesn't work. 100 break 101 } 102 if response.StatusCode == http.StatusOK { 103 respJSON := map[string]interface{}{} 104 if err := json.NewDecoder(response.Body).Decode(&respJSON); err == nil { 105 log.Printf("Response: %+v", respJSON) 106 if status, ok := respJSON["status"].(float64); ok { 107 if int(status) == 0 { 108 break 109 } 110 } else if value, ok := respJSON["value"].(map[string]interface{}); ok { 111 if ready, _ := value["ready"].(bool); ready { 112 break 113 } 114 } 115 } 116 } 117 } 118 119 if deadline.Err() != nil { 120 if d.cmd != nil { 121 go d.cmd.Process.Kill() 122 } 123 return nil, deadline.Err() 124 } 125 126 time.Sleep(10 * time.Millisecond) 127 } 128 129 return d, nil 130 } 131 132 func extractWSLCaps(sessionID string, caps map[string]interface{}) (*wslCaps, error) { 133 binary := "" 134 if b, ok := caps["binary"]; ok { 135 bs, ok := b.(string) 136 if !ok { 137 return nil, fmt.Errorf("binary %#v is not a string", b) 138 } 139 binary = bs 140 } 141 142 port := 0 143 if p, ok := caps["port"]; ok { 144 switch pt := p.(type) { 145 case float64: 146 port = int(pt) 147 case string: 148 pi, err 
:= strconv.Atoi(pt) 149 if err != nil { 150 return nil, err 151 } 152 port = pi 153 default: 154 return nil, fmt.Errorf("port %#v is not a number or string", p) 155 } 156 } 157 158 if port == 0 { 159 return nil, errors.New(`port must be set (use "%WSLPORT:WSL%" if you don't care what port is used)`) 160 } 161 162 var args []string 163 if a, ok := caps["args"]; ok { 164 if binary == "" { 165 return nil, fmt.Errorf("args set to %#v when binary is not set", a) 166 } 167 168 argsInterface, ok := a.([]interface{}) 169 if !ok { 170 return nil, fmt.Errorf("args %#v is not a list", a) 171 } 172 173 for _, argInterface := range argsInterface { 174 arg, ok := argInterface.(string) 175 if !ok { 176 return nil, fmt.Errorf("element %#v in args is not a string", argInterface) 177 } 178 args = append(args, arg) 179 } 180 } 181 182 timeout := 1 * time.Second 183 if t, ok := caps["timeout"]; ok { 184 switch tt := t.(type) { 185 case float64: 186 // Incoming value is in seconds. 187 to, err := time.ParseDuration(fmt.Sprintf("%fs", tt)) 188 if err != nil { 189 return nil, err 190 } 191 timeout = to 192 case string: 193 to, err := time.ParseDuration(tt) 194 if err != nil { 195 return nil, err 196 } 197 timeout = to 198 default: 199 return nil, fmt.Errorf("timeout %#v is not a number or string", t) 200 } 201 } 202 203 env := map[string]string{} 204 if e, ok := caps["env"]; ok { 205 if binary == "" { 206 return nil, fmt.Errorf("env set to %#v when binary is not set", e) 207 } 208 em, ok := e.(map[string]interface{}) 209 if !ok { 210 return nil, fmt.Errorf("env %#v is not a map", e) 211 } 212 for k, v := range em { 213 vs, ok := v.(string) 214 if !ok { 215 return nil, fmt.Errorf("value %#v for key %q in env is not a string", v, k) 216 } 217 env[k] = vs 218 } 219 } 220 221 shutdown := true 222 if s, ok := caps["shutdown"]; ok { 223 sb, ok := s.(bool) 224 if !ok { 225 return nil, fmt.Errorf("shutdown %#v is not a boolean", s) 226 } 227 shutdown = sb 228 } 229 230 status := true 231 if s, 
ok := caps["status"]; ok { 232 sb, ok := s.(bool) 233 if !ok { 234 return nil, fmt.Errorf("status %#v is not a boolean", s) 235 } 236 status = sb 237 } 238 239 stdout := "" 240 if s, ok := caps["stdout"]; ok { 241 if binary == "" { 242 return nil, fmt.Errorf("stdout set to %#v when binary is not set", s) 243 } 244 sb, ok := s.(string) 245 if !ok { 246 return nil, fmt.Errorf("stdout %#v is not a string", s) 247 } 248 stdout = sb 249 } 250 251 stderr := "" 252 if s, ok := caps["stderr"]; ok { 253 if binary == "" { 254 return nil, fmt.Errorf("stderr set to %#v when binary is not set", s) 255 } 256 sb, ok := s.(string) 257 if !ok { 258 return nil, fmt.Errorf("stderr %#v is not a string", s) 259 } 260 stderr = sb 261 } 262 263 return &wslCaps{ 264 binary: binary, 265 args: args, 266 port: port, 267 timeout: timeout, 268 env: env, 269 shutdown: shutdown, 270 status: status, 271 stdout: stdout, 272 stderr: stderr, 273 }, nil 274 } 275 276 func (d *Driver) startDriver() (chan error, error) { 277 errChan := make(chan error) 278 if d.caps.binary == "" { 279 return errChan, nil 280 } 281 282 cmd := exec.CommandContext(context.Background(), d.caps.binary, d.caps.args...) 
283 284 cmd.Env = cmdhelper.BulkUpdateEnv(os.Environ(), d.caps.env) 285 286 stdout := os.Stdout 287 288 if d.caps.stdout != "" { 289 s, err := os.Create(d.caps.stdout) 290 if err != nil { 291 return nil, err 292 } 293 stdout = s 294 } 295 cmd.Stdout = stdout 296 297 stderr := os.Stderr 298 299 if d.caps.stderr != "" { 300 if d.caps.stderr == d.caps.stdout { 301 stderr = stdout 302 } else { 303 s, err := os.Create(d.caps.stderr) 304 if err != nil { 305 return nil, err 306 } 307 stderr = s 308 } 309 } 310 cmd.Stderr = stderr 311 312 if err := cmd.Start(); err != nil { 313 return nil, err 314 } 315 316 go func() { 317 err := cmd.Wait() 318 log.Printf("%s has exited: %v", d.caps.binary, err) 319 errChan <- err 320 d.stopped <- err 321 if stdout != os.Stdout { 322 stdout.Close() 323 } 324 if stderr != os.Stderr { 325 stdout.Close() 326 } 327 }() 328 329 d.cmd = cmd 330 331 return errChan, nil 332 } 333 334 // Forward forwards a request to the WebDriver endpoint managed by this server. 335 func (d *Driver) Forward(w http.ResponseWriter, r *http.Request) { 336 d.mu.Lock() 337 defer d.mu.Unlock() 338 339 if err := httphelper.Forward(r.Context(), d.Address, "", w, r); err != nil { 340 errorResponse(w, http.StatusInternalServerError, 13, "unknown error", err.Error()) 341 } 342 } 343 344 // NewSessionW3C creates a new session using the W3C protocol. 
345 func (d *Driver) NewSession(ctx context.Context, caps *capabilities.Capabilities, w http.ResponseWriter) (string, error) { 346 wd, err := webdriver.CreateSession(ctx, d.Address, 1, caps.Strip("google:wslConfig", "google:sessionId")) 347 348 if err != nil { 349 return "", err 350 } 351 352 if wd.W3C() { 353 writeW3CNewSessionResponse(wd, w) 354 } else { 355 writeJWPNewSessionResponse(wd, w) 356 } 357 358 return wd.SessionID(), nil 359 } 360 361 func writeW3CNewSessionResponse(wd webdriver.WebDriver, w http.ResponseWriter) { 362 w.Header().Set("Content-Type", "application/json; charset=utf-8") 363 httphelper.SetDefaultResponseHeaders(w.Header()) 364 w.WriteHeader(http.StatusOK) 365 366 respJSON := map[string]interface{}{ 367 "value": map[string]interface{}{ 368 "capabilities": wd.Capabilities(), 369 "sessionId": wd.SessionID(), 370 }, 371 } 372 373 json.NewEncoder(w).Encode(respJSON) 374 } 375 376 func writeJWPNewSessionResponse(wd webdriver.WebDriver, w http.ResponseWriter) { 377 w.Header().Set("Content-Type", "application/json; charset=utf-8") 378 httphelper.SetDefaultResponseHeaders(w.Header()) 379 w.WriteHeader(http.StatusOK) 380 381 respJSON := map[string]interface{}{ 382 "value": wd.Capabilities(), 383 "sessionId": wd.SessionID(), 384 "status": 0, 385 } 386 387 json.NewEncoder(w).Encode(respJSON) 388 } 389 390 // Wait waits for the driver binary to exit, and returns an error if the binary exited with an error. 391 func (d *Driver) Wait(ctx context.Context) error { 392 select { 393 case err := <-d.stopped: 394 return err 395 case <-ctx.Done(): 396 return ctx.Err() 397 } 398 } 399 400 // Kill kills a running WebDriver server. 
401 func (d *Driver) Shutdown(ctx context.Context) error { 402 if d.cmd == nil { 403 close(d.stopped) 404 return nil 405 } 406 if d.caps.shutdown { 407 httphelper.Get(ctx, d.Address+"/shutdown") 408 } else if err := d.cmd.Process.Signal(syscall.SIGTERM); err != nil { 409 if err := d.cmd.Process.Signal(os.Interrupt); err != nil { 410 d.cmd.Process.Kill() 411 } 412 } 413 414 if err := d.Wait(ctx); err != nil { 415 return d.cmd.Process.Kill() 416 } 417 return nil 418 } 419 420 func errorResponse(w http.ResponseWriter, httpStatus, status int, err, message string) { 421 w.Header().Set("Content-Type", "application/json; charset=utf-8") 422 httphelper.SetDefaultResponseHeaders(w.Header()) 423 w.WriteHeader(httpStatus) 424 425 respJSON := map[string]interface{}{ 426 "status": status, 427 "value": map[string]interface{}{ 428 "error": err, 429 "message": message, 430 }, 431 } 432 433 json.NewEncoder(w).Encode(respJSON) 434 }