github.com/bazelbuild/rules_webtesting@v0.2.0/go/wtl/proxy/driverhub/driver_hub.go (about)

     1  // Copyright 2016 Google Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  // Package driverhub provides a handler for proxying connections to a Selenium server.
    16  package driverhub
    17  
    18  import (
    19  	"context"
    20  	"encoding/json"
    21  	"fmt"
    22  	"io/ioutil"
    23  	"log"
    24  	"net/http"
    25  	"reflect"
    26  	"sync"
    27  	"time"
    28  
    29  	"github.com/bazelbuild/rules_webtesting/go/errors"
    30  	"github.com/bazelbuild/rules_webtesting/go/healthreporter"
    31  	"github.com/bazelbuild/rules_webtesting/go/httphelper"
    32  	"github.com/bazelbuild/rules_webtesting/go/metadata"
    33  	"github.com/bazelbuild/rules_webtesting/go/metadata/capabilities"
    34  	"github.com/bazelbuild/rules_webtesting/go/webdriver"
    35  	"github.com/bazelbuild/rules_webtesting/go/wtl/diagnostics"
    36  	"github.com/bazelbuild/rules_webtesting/go/wtl/environment"
    37  	"github.com/bazelbuild/rules_webtesting/go/wtl/proxy"
    38  	"github.com/bazelbuild/rules_webtesting/go/wtl/proxy/driverhub/debugger"
    39  	"github.com/gorilla/mux"
    40  )
    41  
    42  const envTimeout = 5 * time.Minute // some environments such as Android take a long time to start up.
    43  
    44  // WebDriverHub routes message to the various WebDriver sessions.
    45  type WebDriverHub struct {
    46  	*mux.Router
    47  	environment.Env
    48  	*metadata.Metadata
    49  	*http.Client
    50  	diagnostics.Diagnostics
    51  	Proxy    *proxy.Proxy
    52  	Debugger *debugger.Debugger
    53  
    54  	healthyOnce sync.Once
    55  
    56  	mu               sync.RWMutex
    57  	sessions         map[string]*WebDriverSession
    58  	reusableSessions []*WebDriverSession
    59  	nextID           int
    60  }
    61  
    62  // NewHandler creates a handler for /wd/hub paths that delegates to a WebDriver server instance provided by env.
    63  func HTTPHandlerProvider(p *proxy.Proxy) (proxy.HTTPHandler, error) {
    64  	var d *debugger.Debugger
    65  	if p.Metadata.DebuggerPort != 0 {
    66  		d = debugger.New(p.Metadata.DebuggerPort)
    67  	}
    68  	h := &WebDriverHub{
    69  		Router:      mux.NewRouter(),
    70  		Env:         p.Env,
    71  		sessions:    map[string]*WebDriverSession{},
    72  		Client:      &http.Client{},
    73  		Diagnostics: p.Diagnostics,
    74  		Metadata:    p.Metadata,
    75  		Proxy:       p,
    76  		Debugger:    d,
    77  	}
    78  
    79  	h.Path("/wd/hub/session").Methods("POST").HandlerFunc(h.createSession)
    80  	h.Path("/wd/hub/session").HandlerFunc(unknownMethod)
    81  	h.PathPrefix("/wd/hub/session/{sessionID}").HandlerFunc(h.routeToSession)
    82  	h.PathPrefix("/wd/hub/{command}").HandlerFunc(h.defaultForward)
    83  	h.PathPrefix("/").HandlerFunc(unknownCommand)
    84  
    85  	return h, nil
    86  }
    87  
    88  func (h *WebDriverHub) ServeHTTP(w http.ResponseWriter, r *http.Request) {
    89  	if h.Debugger != nil {
    90  		// allow debugger to pause for breakpoint, log to debugger front-end.
    91  		h.Debugger.Request(r)
    92  	}
    93  	// TODO(DrMarcII) add support for breakpointing on responses.
    94  	h.Router.ServeHTTP(w, r)
    95  }
    96  
    97  // Name is the name of the component used in error messages.
    98  func (h *WebDriverHub) Name() string {
    99  	return "WebDriver Hub"
   100  }
   101  
   102  // Healthy returns nil if the WebDriverHub is ready for use, and an error otherwise.
   103  func (h *WebDriverHub) Healthy(ctx context.Context) error {
   104  	if h.Debugger != nil {
   105  		return h.Debugger.Healthy(ctx)
   106  	}
   107  	return nil
   108  }
   109  
   110  // AddSession adds a session to WebDriverHub.
   111  func (h *WebDriverHub) AddSession(id string, session *WebDriverSession) {
   112  	h.mu.Lock()
   113  	defer h.mu.Unlock()
   114  	if h.sessions == nil {
   115  		h.sessions = map[string]*WebDriverSession{}
   116  	}
   117  	h.sessions[id] = session
   118  }
   119  
   120  // RemoveSession removes a session from WebDriverHub.
   121  func (h *WebDriverHub) RemoveSession(id string) {
   122  	h.mu.Lock()
   123  	defer h.mu.Unlock()
   124  	if h.sessions == nil {
   125  		return
   126  	}
   127  	delete(h.sessions, id)
   128  }
   129  
   130  // GetSession gets the session for a given WebDriver session id..
   131  func (h *WebDriverHub) GetSession(id string) *WebDriverSession {
   132  	h.mu.RLock()
   133  	defer h.mu.RUnlock()
   134  	return h.sessions[id]
   135  }
   136  
   137  // NextID gets the next available internal id for a session.
   138  func (h *WebDriverHub) NextID() int {
   139  	h.mu.Lock()
   140  	defer h.mu.Unlock()
   141  	id := h.nextID
   142  	h.nextID++
   143  	return id
   144  }
   145  
   146  // GetActiveSessions returns the ids for all currently active sessions.
   147  func (h *WebDriverHub) GetActiveSessions() []string {
   148  	result := []string{}
   149  	h.mu.RLock()
   150  	defer h.mu.RUnlock()
   151  	for id := range h.sessions {
   152  		result = append(result, id)
   153  	}
   154  	return result
   155  }
   156  
   157  // Shutdown  shuts down any running sessions.
   158  func (h *WebDriverHub) Shutdown(ctx context.Context) error {
   159  	for _, id := range h.GetActiveSessions() {
   160  		session := h.GetSession(id)
   161  		session.quit(ctx, false)
   162  	}
   163  	for _, session := range h.reusableSessions {
   164  		session.quit(ctx, false)
   165  	}
   166  	return nil
   167  }
   168  
   169  // GetReusableSession grabs a reusable session if one is available that matches caps.
   170  func (h *WebDriverHub) GetReusableSession(ctx context.Context, caps *capabilities.Capabilities) (*WebDriverSession, bool) {
   171  	if !capabilities.CanReuseSession(caps) {
   172  		return nil, false
   173  	}
   174  
   175  	h.mu.Lock()
   176  	defer h.mu.Unlock()
   177  	for i, session := range h.reusableSessions {
   178  		if reflect.DeepEqual(caps, session.RequestedCaps) {
   179  			h.reusableSessions = append(h.reusableSessions[:i], h.reusableSessions[i+1:]...)
   180  			if err := session.WebDriver.Healthy(ctx); err == nil {
   181  				return session, true
   182  			}
   183  			return session, true
   184  		}
   185  	}
   186  	return nil, false
   187  }
   188  
   189  // AddReusableSession adds a session that can be reused.
   190  func (h *WebDriverHub) AddReusableSession(session *WebDriverSession) error {
   191  	if !capabilities.CanReuseSession(session.RequestedCaps) {
   192  		return errors.New(h.Name(), "session is not reusable.")
   193  	}
   194  	h.reusableSessions = append(h.reusableSessions, session)
   195  	return nil
   196  }
   197  
   198  func (h *WebDriverHub) routeToSession(w http.ResponseWriter, r *http.Request) {
   199  	sid := mux.Vars(r)["sessionID"]
   200  	session := h.GetSession(sid)
   201  
   202  	if session == nil {
   203  		invalidSessionID(w, sid)
   204  		return
   205  	}
   206  	session.ServeHTTP(w, r)
   207  }
   208  
   209  func (h *WebDriverHub) createSession(w http.ResponseWriter, r *http.Request) {
   210  	ctx := r.Context()
   211  	log.Print("Creating session\n\n")
   212  
   213  	if err := h.waitForHealthyEnv(ctx); err != nil {
   214  		sessionNotCreated(w, err)
   215  		return
   216  	}
   217  
   218  	body, err := ioutil.ReadAll(r.Body)
   219  	if err != nil {
   220  		sessionNotCreated(w, err)
   221  		return
   222  	}
   223  
   224  	j := map[string]interface{}{}
   225  
   226  	if err := json.Unmarshal(body, &j); err != nil {
   227  		sessionNotCreated(w, err)
   228  		return
   229  	}
   230  
   231  	requestedCaps, err := capabilities.FromNewSessionArgs(j)
   232  	if err != nil {
   233  		sessionNotCreated(w, err)
   234  		return
   235  	}
   236  
   237  	id := h.NextID()
   238  
   239  	caps, err := h.Env.StartSession(ctx, id, requestedCaps)
   240  	if err != nil {
   241  		sessionNotCreated(w, err)
   242  		return
   243  	}
   244  
   245  	log.Printf("Caps: %+v", caps)
   246  
   247  	var session *WebDriverSession
   248  
   249  	if reusable, ok := h.GetReusableSession(ctx, caps); ok {
   250  		reusable.Unpause(id)
   251  		session = reusable
   252  	} else {
   253  		// TODO(DrMarcII) parameterize attempts based on browser metadata
   254  		driver, err := webdriver.CreateSession(ctx, h.Env.WDAddress(ctx), 3, caps.Strip("google:canReuseSession"))
   255  		if err != nil {
   256  			if err2 := h.Env.StopSession(ctx, id); err2 != nil {
   257  				log.Printf("error stopping session after failing to launch webdriver: %v", err2)
   258  			}
   259  			sessionNotCreated(w, err)
   260  			return
   261  		}
   262  
   263  		s, err := CreateSession(id, h, driver, caps)
   264  		if err != nil {
   265  			sessionNotCreated(w, err)
   266  			return
   267  		}
   268  		session = s
   269  	}
   270  
   271  	h.AddSession(session.WebDriver.SessionID(), session)
   272  
   273  	var respJSON map[string]interface{}
   274  
   275  	if session.WebDriver.W3C() {
   276  		respJSON = map[string]interface{}{
   277  			"value": map[string]interface{}{
   278  				"capabilities": session.WebDriver.Capabilities(),
   279  				"sessionId":    session.WebDriver.SessionID(),
   280  			},
   281  		}
   282  	} else {
   283  		respJSON = map[string]interface{}{
   284  			"value":     session.WebDriver.Capabilities(),
   285  			"sessionId": session.WebDriver.SessionID(),
   286  			"status":    0,
   287  		}
   288  	}
   289  
   290  	bytes, err := json.Marshal(respJSON)
   291  	if err != nil {
   292  		unknownError(w, err)
   293  		return
   294  	}
   295  
   296  	w.Header().Set("Content-Type", contentType)
   297  	httphelper.SetDefaultResponseHeaders(w.Header())
   298  	w.WriteHeader(http.StatusOK)
   299  	w.Write(bytes)
   300  }
   301  
   302  func (h *WebDriverHub) defaultForward(w http.ResponseWriter, r *http.Request) {
   303  	ctx := r.Context()
   304  	if err := h.waitForHealthyEnv(ctx); err != nil {
   305  		unknownError(w, err)
   306  		return
   307  	}
   308  
   309  	if err := httphelper.Forward(r.Context(), h.Env.WDAddress(ctx), "/wd/hub/", w, r); err != nil {
   310  		unknownError(w, err)
   311  	}
   312  }
   313  
   314  func (h *WebDriverHub) waitForHealthyEnv(ctx context.Context) error {
   315  	h.healthyOnce.Do(func() {
   316  		healthyCtx, cancel := context.WithTimeout(ctx, envTimeout)
   317  		defer cancel()
   318  		// ignore error here as we will call and return Healthy below.
   319  		healthreporter.WaitForHealthy(healthyCtx, h.Env)
   320  	})
   321  	err := h.Env.Healthy(ctx)
   322  	if err != nil {
   323  		err = errors.New(h.Name(), fmt.Sprintf("environment is unhealthy: %v", err))
   324  	}
   325  	return err
   326  }