github.com/bazelbuild/rules_webtesting@v0.2.0/go/wsl/driver/driver.go (about)

     1  // Copyright 2018 Google Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
// Package driver launches a WebDriver endpoint binary (e.g. ChromeDriver or SafariDriver)
// based on a google:wslConfig capability.
    17  package driver
    18  
    19  import (
    20  	"context"
    21  	"encoding/json"
    22  	"errors"
    23  	"fmt"
    24  	"log"
    25  	"net"
    26  	"net/http"
    27  	"os"
    28  	"os/exec"
    29  	"strconv"
    30  	"sync"
    31  	"syscall"
    32  	"time"
    33  
    34  	"github.com/bazelbuild/rules_webtesting/go/cmdhelper"
    35  	"github.com/bazelbuild/rules_webtesting/go/httphelper"
    36  	"github.com/bazelbuild/rules_webtesting/go/metadata/capabilities"
    37  	"github.com/bazelbuild/rules_webtesting/go/webdriver"
    38  )
    39  
    40  const compName = "WSL Driver"
    41  
    42  // Driver is wrapper around a running WebDriver endpoint binary.
    43  type Driver struct {
    44  	Address string
    45  	caps    *wslCaps
    46  	stopped chan error
    47  	cmd     *exec.Cmd
    48  
    49  	// Mutex to prevent overlapping commands to remote end.
    50  	mu sync.Mutex
    51  }
    52  
    53  type wslCaps struct {
    54  	binary   string
    55  	args     []string
    56  	port     int
    57  	timeout  time.Duration
    58  	env      map[string]string
    59  	shutdown bool
    60  	status   bool
    61  	stdout   string
    62  	stderr   string
    63  }
    64  
    65  // New creates starts a WebDriver endpoint binary based on caps. Argument caps should just be
    66  // the google:wslConfig capability extracted from the capabilities passed into a new session request.
    67  func New(ctx context.Context, localHost, sessionID string, caps map[string]interface{}) (*Driver, error) {
    68  	wslCaps, err := extractWSLCaps(sessionID, caps)
    69  	if err != nil {
    70  		return nil, err
    71  	}
    72  	hostPort := net.JoinHostPort(localHost, strconv.Itoa(wslCaps.port))
    73  
    74  	d := &Driver{
    75  		Address: fmt.Sprintf("http://%s", hostPort),
    76  		caps:    wslCaps,
    77  		stopped: make(chan error),
    78  	}
    79  
    80  	errChan, err := d.startDriver()
    81  	if err != nil {
    82  		return nil, err
    83  	}
    84  
    85  	deadline, cancel := context.WithTimeout(ctx, d.caps.timeout)
    86  	defer cancel()
    87  
    88  	statusURL := fmt.Sprintf("http://%s/status", hostPort)
    89  
    90  	for {
    91  		select {
    92  		case err := <-errChan:
    93  			return nil, err
    94  		default:
    95  		}
    96  
    97  		if response, err := httphelper.Get(deadline, statusURL); err == nil {
    98  			if !d.caps.status {
    99  				// just wait for successful connection because status endpoint doesn't work.
   100  				break
   101  			}
   102  			if response.StatusCode == http.StatusOK {
   103  				respJSON := map[string]interface{}{}
   104  				if err := json.NewDecoder(response.Body).Decode(&respJSON); err == nil {
   105  					log.Printf("Response: %+v", respJSON)
   106  					if status, ok := respJSON["status"].(float64); ok {
   107  						if int(status) == 0 {
   108  							break
   109  						}
   110  					} else if value, ok := respJSON["value"].(map[string]interface{}); ok {
   111  						if ready, _ := value["ready"].(bool); ready {
   112  							break
   113  						}
   114  					}
   115  				}
   116  			}
   117  		}
   118  
   119  		if deadline.Err() != nil {
   120  			if d.cmd != nil {
   121  				go d.cmd.Process.Kill()
   122  			}
   123  			return nil, deadline.Err()
   124  		}
   125  
   126  		time.Sleep(10 * time.Millisecond)
   127  	}
   128  
   129  	return d, nil
   130  }
   131  
   132  func extractWSLCaps(sessionID string, caps map[string]interface{}) (*wslCaps, error) {
   133  	binary := ""
   134  	if b, ok := caps["binary"]; ok {
   135  		bs, ok := b.(string)
   136  		if !ok {
   137  			return nil, fmt.Errorf("binary %#v is not a string", b)
   138  		}
   139  		binary = bs
   140  	}
   141  
   142  	port := 0
   143  	if p, ok := caps["port"]; ok {
   144  		switch pt := p.(type) {
   145  		case float64:
   146  			port = int(pt)
   147  		case string:
   148  			pi, err := strconv.Atoi(pt)
   149  			if err != nil {
   150  				return nil, err
   151  			}
   152  			port = pi
   153  		default:
   154  			return nil, fmt.Errorf("port %#v is not a number or string", p)
   155  		}
   156  	}
   157  
   158  	if port == 0 {
   159  		return nil, errors.New(`port must be set (use "%WSLPORT:WSL%" if you don't care what port is used)`)
   160  	}
   161  
   162  	var args []string
   163  	if a, ok := caps["args"]; ok {
   164  		if binary == "" {
   165  			return nil, fmt.Errorf("args set to %#v when binary is not set", a)
   166  		}
   167  
   168  		argsInterface, ok := a.([]interface{})
   169  		if !ok {
   170  			return nil, fmt.Errorf("args %#v is not a list", a)
   171  		}
   172  
   173  		for _, argInterface := range argsInterface {
   174  			arg, ok := argInterface.(string)
   175  			if !ok {
   176  				return nil, fmt.Errorf("element %#v in args is not a string", argInterface)
   177  			}
   178  			args = append(args, arg)
   179  		}
   180  	}
   181  
   182  	timeout := 1 * time.Second
   183  	if t, ok := caps["timeout"]; ok {
   184  		switch tt := t.(type) {
   185  		case float64:
   186  			// Incoming value is in seconds.
   187  			to, err := time.ParseDuration(fmt.Sprintf("%fs", tt))
   188  			if err != nil {
   189  				return nil, err
   190  			}
   191  			timeout = to
   192  		case string:
   193  			to, err := time.ParseDuration(tt)
   194  			if err != nil {
   195  				return nil, err
   196  			}
   197  			timeout = to
   198  		default:
   199  			return nil, fmt.Errorf("timeout %#v is not a number or string", t)
   200  		}
   201  	}
   202  
   203  	env := map[string]string{}
   204  	if e, ok := caps["env"]; ok {
   205  		if binary == "" {
   206  			return nil, fmt.Errorf("env set to %#v when binary is not set", e)
   207  		}
   208  		em, ok := e.(map[string]interface{})
   209  		if !ok {
   210  			return nil, fmt.Errorf("env %#v is not a map", e)
   211  		}
   212  		for k, v := range em {
   213  			vs, ok := v.(string)
   214  			if !ok {
   215  				return nil, fmt.Errorf("value %#v for key %q in env is not a string", v, k)
   216  			}
   217  			env[k] = vs
   218  		}
   219  	}
   220  
   221  	shutdown := true
   222  	if s, ok := caps["shutdown"]; ok {
   223  		sb, ok := s.(bool)
   224  		if !ok {
   225  			return nil, fmt.Errorf("shutdown %#v is not a boolean", s)
   226  		}
   227  		shutdown = sb
   228  	}
   229  
   230  	status := true
   231  	if s, ok := caps["status"]; ok {
   232  		sb, ok := s.(bool)
   233  		if !ok {
   234  			return nil, fmt.Errorf("status %#v is not a boolean", s)
   235  		}
   236  		status = sb
   237  	}
   238  
   239  	stdout := ""
   240  	if s, ok := caps["stdout"]; ok {
   241  		if binary == "" {
   242  			return nil, fmt.Errorf("stdout set to %#v when binary is not set", s)
   243  		}
   244  		sb, ok := s.(string)
   245  		if !ok {
   246  			return nil, fmt.Errorf("stdout %#v is not a string", s)
   247  		}
   248  		stdout = sb
   249  	}
   250  
   251  	stderr := ""
   252  	if s, ok := caps["stderr"]; ok {
   253  		if binary == "" {
   254  			return nil, fmt.Errorf("stderr set to %#v when binary is not set", s)
   255  		}
   256  		sb, ok := s.(string)
   257  		if !ok {
   258  			return nil, fmt.Errorf("stderr %#v is not a string", s)
   259  		}
   260  		stderr = sb
   261  	}
   262  
   263  	return &wslCaps{
   264  		binary:   binary,
   265  		args:     args,
   266  		port:     port,
   267  		timeout:  timeout,
   268  		env:      env,
   269  		shutdown: shutdown,
   270  		status:   status,
   271  		stdout:   stdout,
   272  		stderr:   stderr,
   273  	}, nil
   274  }
   275  
   276  func (d *Driver) startDriver() (chan error, error) {
   277  	errChan := make(chan error)
   278  	if d.caps.binary == "" {
   279  		return errChan, nil
   280  	}
   281  
   282  	cmd := exec.CommandContext(context.Background(), d.caps.binary, d.caps.args...)
   283  
   284  	cmd.Env = cmdhelper.BulkUpdateEnv(os.Environ(), d.caps.env)
   285  
   286  	stdout := os.Stdout
   287  
   288  	if d.caps.stdout != "" {
   289  		s, err := os.Create(d.caps.stdout)
   290  		if err != nil {
   291  			return nil, err
   292  		}
   293  		stdout = s
   294  	}
   295  	cmd.Stdout = stdout
   296  
   297  	stderr := os.Stderr
   298  
   299  	if d.caps.stderr != "" {
   300  		if d.caps.stderr == d.caps.stdout {
   301  			stderr = stdout
   302  		} else {
   303  			s, err := os.Create(d.caps.stderr)
   304  			if err != nil {
   305  				return nil, err
   306  			}
   307  			stderr = s
   308  		}
   309  	}
   310  	cmd.Stderr = stderr
   311  
   312  	if err := cmd.Start(); err != nil {
   313  		return nil, err
   314  	}
   315  
   316  	go func() {
   317  		err := cmd.Wait()
   318  		log.Printf("%s has exited: %v", d.caps.binary, err)
   319  		errChan <- err
   320  		d.stopped <- err
   321  		if stdout != os.Stdout {
   322  			stdout.Close()
   323  		}
   324  		if stderr != os.Stderr {
   325  			stdout.Close()
   326  		}
   327  	}()
   328  
   329  	d.cmd = cmd
   330  
   331  	return errChan, nil
   332  }
   333  
   334  // Forward forwards a request to the WebDriver endpoint managed by this server.
   335  func (d *Driver) Forward(w http.ResponseWriter, r *http.Request) {
   336  	d.mu.Lock()
   337  	defer d.mu.Unlock()
   338  
   339  	if err := httphelper.Forward(r.Context(), d.Address, "", w, r); err != nil {
   340  		errorResponse(w, http.StatusInternalServerError, 13, "unknown error", err.Error())
   341  	}
   342  }
   343  
   344  // NewSessionW3C creates a new session using the W3C protocol.
   345  func (d *Driver) NewSession(ctx context.Context, caps *capabilities.Capabilities, w http.ResponseWriter) (string, error) {
   346  	wd, err := webdriver.CreateSession(ctx, d.Address, 1, caps.Strip("google:wslConfig", "google:sessionId"))
   347  
   348  	if err != nil {
   349  		return "", err
   350  	}
   351  
   352  	if wd.W3C() {
   353  		writeW3CNewSessionResponse(wd, w)
   354  	} else {
   355  		writeJWPNewSessionResponse(wd, w)
   356  	}
   357  
   358  	return wd.SessionID(), nil
   359  }
   360  
   361  func writeW3CNewSessionResponse(wd webdriver.WebDriver, w http.ResponseWriter) {
   362  	w.Header().Set("Content-Type", "application/json; charset=utf-8")
   363  	httphelper.SetDefaultResponseHeaders(w.Header())
   364  	w.WriteHeader(http.StatusOK)
   365  
   366  	respJSON := map[string]interface{}{
   367  		"value": map[string]interface{}{
   368  			"capabilities": wd.Capabilities(),
   369  			"sessionId":    wd.SessionID(),
   370  		},
   371  	}
   372  
   373  	json.NewEncoder(w).Encode(respJSON)
   374  }
   375  
   376  func writeJWPNewSessionResponse(wd webdriver.WebDriver, w http.ResponseWriter) {
   377  	w.Header().Set("Content-Type", "application/json; charset=utf-8")
   378  	httphelper.SetDefaultResponseHeaders(w.Header())
   379  	w.WriteHeader(http.StatusOK)
   380  
   381  	respJSON := map[string]interface{}{
   382  		"value":     wd.Capabilities(),
   383  		"sessionId": wd.SessionID(),
   384  		"status":    0,
   385  	}
   386  
   387  	json.NewEncoder(w).Encode(respJSON)
   388  }
   389  
   390  // Wait waits for the driver binary to exit, and returns an error if the binary exited with an error.
   391  func (d *Driver) Wait(ctx context.Context) error {
   392  	select {
   393  	case err := <-d.stopped:
   394  		return err
   395  	case <-ctx.Done():
   396  		return ctx.Err()
   397  	}
   398  }
   399  
   400  // Kill kills a running WebDriver server.
   401  func (d *Driver) Shutdown(ctx context.Context) error {
   402  	if d.cmd == nil {
   403  		close(d.stopped)
   404  		return nil
   405  	}
   406  	if d.caps.shutdown {
   407  		httphelper.Get(ctx, d.Address+"/shutdown")
   408  	} else if err := d.cmd.Process.Signal(syscall.SIGTERM); err != nil {
   409  		if err := d.cmd.Process.Signal(os.Interrupt); err != nil {
   410  			d.cmd.Process.Kill()
   411  		}
   412  	}
   413  
   414  	if err := d.Wait(ctx); err != nil {
   415  		return d.cmd.Process.Kill()
   416  	}
   417  	return nil
   418  }
   419  
   420  func errorResponse(w http.ResponseWriter, httpStatus, status int, err, message string) {
   421  	w.Header().Set("Content-Type", "application/json; charset=utf-8")
   422  	httphelper.SetDefaultResponseHeaders(w.Header())
   423  	w.WriteHeader(httpStatus)
   424  
   425  	respJSON := map[string]interface{}{
   426  		"status": status,
   427  		"value": map[string]interface{}{
   428  			"error":   err,
   429  			"message": message,
   430  		},
   431  	}
   432  
   433  	json.NewEncoder(w).Encode(respJSON)
   434  }