github.com/bazelbuild/rules_webtesting@v0.2.0/go/wtl/proxy/driverhub/driver_session.go (about)

     1  // Copyright 2016 Google Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package driverhub
    16  
    17  import (
    18  	"bytes"
    19  	"context"
    20  	"encoding/json"
    21  	"fmt"
    22  	"io/ioutil"
    23  	"log"
    24  	"net/http"
    25  	"strconv"
    26  	"strings"
    27  	"sync"
    28  
    29  	"github.com/bazelbuild/rules_webtesting/go/errors"
    30  	"github.com/bazelbuild/rules_webtesting/go/httphelper"
    31  	"github.com/bazelbuild/rules_webtesting/go/metadata"
    32  	"github.com/bazelbuild/rules_webtesting/go/metadata/capabilities"
    33  	"github.com/bazelbuild/rules_webtesting/go/webdriver"
    34  	"github.com/bazelbuild/rules_webtesting/go/wtl/diagnostics"
    35  	"github.com/gorilla/mux"
    36  )
    37  
    38  // WebDriverSession is an http.Handler for forwarding requests to a WebDriver session.
    39  type WebDriverSession struct {
    40  	*mux.Router
    41  	diagnostics.Diagnostics
    42  	WebDriverHub *WebDriverHub
    43  	webdriver.WebDriver
    44  	ID            int
    45  	handler       HandlerFunc
    46  	sessionPath   string
    47  	RequestedCaps *capabilities.Capabilities
    48  	Metadata      *metadata.Metadata
    49  
    50  	mu      sync.RWMutex
    51  	stopped bool
    52  }
    53  
    54  // A handlerProvider wraps another HandlerFunc to create a new HandlerFunc.
    55  // If the second return value is false, then the provider did not construct a new HandlerFunc.
    56  type handlerProvider func(session *WebDriverSession, caps *capabilities.Capabilities, base HandlerFunc) (HandlerFunc, bool)
    57  
    58  // HandlerFunc is a func for handling a request to a WebDriver session.
    59  type HandlerFunc func(context.Context, Request) (Response, error)
    60  
    61  // Request wraps a request to a WebDriver session.
    62  type Request struct {
    63  	// HTTP Method for this request (e.g. http.MethodGet, ...).
    64  	Method string
    65  	// Path of the request after the session id.
    66  	Path []string
    67  	// Any HTTP headers sent with the request.
    68  	Header http.Header
    69  	// The body of the request.
    70  	Body []byte
    71  }
    72  
    73  // Response describes what response should be returned for a request to WebDriver session.
    74  type Response struct {
    75  	// HTTP status code to return (e.g. http.StatusOK, ...).
    76  	Status int
    77  	// Any HTTP Headers that should be included in the response.
    78  	Header http.Header
    79  	// The body of the response.
    80  	Body []byte
    81  }
    82  
    83  var providers = []handlerProvider{}
    84  
    85  // HandlerProviderFunc adds additional handlers that will wrap any previously defined handlers.
    86  //
    87  // It is important to note that later added handlers will wrap earlier added handlers.
    88  // E.g. if you call as follows:
    89  //   HandlerProviderFunc(hp1)
    90  //   HandlerProviderFunc(hp2)
    91  //   HandlerProviderFunc(hp3)
    92  //
    93  // The generated handler will be constructed as follows:
    94  //   hp3(session, caps, hp2(session, caps, hp1(session, caps, base)))
    95  // where base is the a default function that forwards commands to WebDriver unchanged.
    96  func HandlerProviderFunc(provider handlerProvider) {
    97  	providers = append(providers, provider)
    98  }
    99  
   100  func createHandler(session *WebDriverSession, caps *capabilities.Capabilities) HandlerFunc {
   101  	handler := createBaseHandler(session.WebDriver)
   102  
   103  	for _, provider := range providers {
   104  		if h, ok := provider(session, caps, handler); ok {
   105  			handler = h
   106  		}
   107  	}
   108  	return handler
   109  }
   110  
   111  // CreateSession creates a WebDriverSession object.
   112  func CreateSession(id int, hub *WebDriverHub, driver webdriver.WebDriver, caps *capabilities.Capabilities) (*WebDriverSession, error) {
   113  	sessionPath := fmt.Sprintf("/wd/hub/session/%s", driver.SessionID())
   114  	session := &WebDriverSession{
   115  		ID:            id,
   116  		Diagnostics:   hub.Diagnostics,
   117  		WebDriverHub:  hub,
   118  		WebDriver:     driver,
   119  		sessionPath:   sessionPath,
   120  		Router:        mux.NewRouter(),
   121  		RequestedCaps: caps,
   122  		Metadata:      hub.Metadata,
   123  	}
   124  
   125  	session.handler = createHandler(session, caps)
   126  	// Route for commands for this session.
   127  	session.PathPrefix(sessionPath).HandlerFunc(session.defaultHandler)
   128  	// Route for commands for some other session. If this happens, the hub has
   129  	// done something wrong.
   130  	session.PathPrefix("/wd/hub/session/{sessionID}").HandlerFunc(session.wrongSession)
   131  	// Route for all other paths that aren't WebDriver commands. This also implies
   132  	// that the hub has done something wrong.
   133  	session.PathPrefix("/").HandlerFunc(session.unknownCommand)
   134  
   135  	return session, nil
   136  }
   137  
   138  // Name is the name of the component used in error messages.
   139  func (s *WebDriverSession) Name() string {
   140  	return "WebDriver Session Handler"
   141  }
   142  
   143  func (s *WebDriverSession) wrongSession(w http.ResponseWriter, r *http.Request) {
   144  	vars := mux.Vars(r)
   145  	s.Severe(errors.New(s.Name(), "request routed to wrong session handler"))
   146  	unknownError(w, fmt.Errorf("request for session %q was routed to handler for %q", vars["sessionID"], s.SessionID()))
   147  }
   148  
   149  func (s *WebDriverSession) unknownCommand(w http.ResponseWriter, r *http.Request) {
   150  	s.Severe(errors.New(s.Name(), "unknown command routed to session handler"))
   151  	unknownCommand(w, r)
   152  }
   153  
   154  // Quit can be called by handlers to quit this session.
   155  func (s *WebDriverSession) Quit(ctx context.Context, _ Request) (Response, error) {
   156  	if err := s.quit(ctx, capabilities.CanReuseSession(s.RequestedCaps)); err != nil {
   157  		return ResponseFromError(err)
   158  	}
   159  
   160  	return Response{
   161  		Status: http.StatusOK,
   162  		Body:   []byte(`{"status": 0}`),
   163  	}, nil
   164  }
   165  
   166  // Quit can be called by handlers to quit this session.
   167  func (s *WebDriverSession) quit(ctx context.Context, reusable bool) error {
   168  	s.mu.Lock()
   169  	defer s.mu.Unlock()
   170  
   171  	s.stopped = true
   172  
   173  	var wdErr error
   174  
   175  	if !reusable {
   176  		wdErr = s.WebDriver.Quit(ctx)
   177  		if wdErr != nil {
   178  			s.Warning(wdErr)
   179  		}
   180  	}
   181  
   182  	envErr := s.WebDriverHub.Env.StopSession(ctx, s.ID)
   183  	if envErr != nil {
   184  		s.Warning(envErr)
   185  	}
   186  
   187  	s.WebDriverHub.RemoveSession(s.SessionID())
   188  
   189  	if wdErr != nil {
   190  		return wdErr
   191  	}
   192  	if envErr != nil {
   193  		return envErr
   194  	}
   195  
   196  	if reusable {
   197  		s.WebDriverHub.AddReusableSession(s)
   198  	}
   199  
   200  	return nil
   201  }
   202  
   203  func (s *WebDriverSession) commandPathTokens(path string) []string {
   204  	commandPath := strings.Trim(strings.TrimPrefix(path, s.sessionPath), "/")
   205  	if commandPath == "" {
   206  		return []string{}
   207  	}
   208  	return strings.Split(commandPath, "/")
   209  }
   210  
   211  // Unpause makes the session usable again and associates it with the given session id.
   212  func (s *WebDriverSession) Unpause(id int) {
   213  	s.mu.Lock()
   214  	s.stopped = false
   215  	s.ID = id
   216  	s.mu.Unlock()
   217  }
   218  
   219  func (s *WebDriverSession) defaultHandler(w http.ResponseWriter, r *http.Request) {
   220  	ctx := r.Context()
   221  	vars := mux.Vars(r)
   222  	pathTokens := s.commandPathTokens(r.URL.Path)
   223  
   224  	s.mu.Lock()
   225  	stopped := s.stopped
   226  	s.mu.Unlock()
   227  
   228  	if stopped {
   229  		invalidSessionID(w, vars["sessionID"])
   230  		return
   231  	}
   232  
   233  	body, err := ioutil.ReadAll(r.Body)
   234  	if err != nil {
   235  		unknownError(w, err)
   236  		return
   237  	}
   238  
   239  	req := Request{
   240  		Method: r.Method,
   241  		Path:   pathTokens,
   242  		Header: r.Header,
   243  		Body:   body,
   244  	}
   245  	resp, err := s.handler(ctx, req)
   246  	if err != nil {
   247  		if ctx.Err() == context.Canceled {
   248  			log.Printf("[%s] request %+v was canceled.", s.Name(), req)
   249  			return
   250  		}
   251  		if ctx.Err() == context.DeadlineExceeded {
   252  			s.Warning(errors.New(s.Name(), fmt.Errorf("request %+v exceeded deadline", req)))
   253  			timeout(w, r.URL.Path)
   254  			return
   255  		}
   256  		s.Severe(errors.New(s.Name(), err))
   257  		unknownError(w, err)
   258  		return
   259  	}
   260  
   261  	if len(resp.Body) != 0 {
   262  		w.Header().Set("Content-Type", contentType)
   263  	}
   264  	if resp.Header != nil {
   265  		// Copy response headers from resp to w
   266  		for k, vs := range resp.Header {
   267  			w.Header().Del(k)
   268  			for _, v := range vs {
   269  				w.Header().Add(k, v)
   270  			}
   271  		}
   272  	}
   273  
   274  	// TODO(fisherii): needed to play nice with Dart Sync WebDriver. Delete when Dart Sync WebDriver is deleted.
   275  	w.Header().Set("Transfer-Encoding", "identity")
   276  	w.Header().Set("Content-Length", strconv.Itoa(len(resp.Body)))
   277  
   278  	httphelper.SetDefaultResponseHeaders(w.Header())
   279  
   280  	// Copy status code from resp to w
   281  	w.WriteHeader(resp.Status)
   282  
   283  	// Write body from resp to w
   284  	w.Write(resp.Body)
   285  }
   286  
   287  func createBaseHandler(driver webdriver.WebDriver) HandlerFunc {
   288  	client := &http.Client{}
   289  
   290  	return func(ctx context.Context, rq Request) (Response, error) {
   291  		url, err := driver.CommandURL(rq.Path...)
   292  		if err != nil {
   293  			return Response{}, err
   294  		}
   295  
   296  		req, err := http.NewRequest(rq.Method, url.String(), bytes.NewReader(rq.Body))
   297  		if err != nil {
   298  			return Response{}, err
   299  		}
   300  		req = req.WithContext(ctx)
   301  		for k, v := range rq.Header {
   302  			if !strings.HasPrefix(k, "x-google-") {
   303  				req.Header[k] = v
   304  			}
   305  		}
   306  
   307  		resp, err := client.Do(req)
   308  		if err != nil {
   309  			return Response{}, err
   310  		}
   311  		defer resp.Body.Close()
   312  		body, err := ioutil.ReadAll(resp.Body)
   313  		if err != nil {
   314  			return Response{}, err
   315  		}
   316  		return Response{resp.StatusCode, resp.Header, body}, nil
   317  	}
   318  }
   319  
   320  // ResponseFromError generates a Response object for err.
   321  func ResponseFromError(err error) (Response, error) {
   322  	body, e := webdriver.MarshalError(err)
   323  	return Response{
   324  		Status: webdriver.ErrorHTTPStatus(err),
   325  		Body:   body,
   326  	}, e
   327  }
   328  
   329  // SuccessfulResponse generate a response object indicating success.
   330  func SuccessfulResponse(value interface{}) (Response, error) {
   331  	body := map[string]interface{}{
   332  		"status": 0,
   333  	}
   334  
   335  	if value != nil {
   336  		body["value"] = value
   337  	}
   338  
   339  	bytes, err := json.Marshal(body)
   340  	return Response{
   341  		Status: http.StatusOK,
   342  		Body:   bytes,
   343  	}, err
   344  }