github.com/bazelbuild/rules_webtesting@v0.2.0/go/wsl/hub/hub.go (about)

     1  // Copyright 2018 Google Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  // Package hub launches WebDriver servers and correctly dispatches requests to the correct server
    16  // based on session id.
    17  package hub
    18  
    19  import (
    20  	"context"
    21  	"encoding/json"
    22  	"errors"
    23  	"fmt"
    24  	"log"
    25  	"net/http"
    26  	"strconv"
    27  	"strings"
    28  	"sync"
    29  	"time"
    30  
    31  	"github.com/bazelbuild/rules_webtesting/go/httphelper"
    32  	"github.com/bazelbuild/rules_webtesting/go/metadata/capabilities"
    33  	"github.com/bazelbuild/rules_webtesting/go/wsl/driver"
    34  	"github.com/bazelbuild/rules_webtesting/go/wsl/resolver"
    35  )
    36  
    37  // A Hub is an HTTP handler that manages incoming WebDriver requests.
    38  type Hub struct {
    39  	// Mutex to protext access to sessions.
    40  	mu       sync.RWMutex
    41  	sessions map[string]*driver.Driver
    42  
    43  	localHost string
    44  	uploader  http.Handler
    45  }
    46  
    47  // New creates a new Hub.
    48  func New(localHost string, uploader http.Handler) *Hub {
    49  	return &Hub{
    50  		sessions:  map[string]*driver.Driver{},
    51  		localHost: localHost,
    52  		uploader:  uploader,
    53  	}
    54  }
    55  
    56  func (h *Hub) ServeHTTP(w http.ResponseWriter, r *http.Request) {
    57  	path := strings.Split(r.URL.Path, "/")[1:]
    58  
    59  	if len(path) < 1 || path[0] != "session" {
    60  		errorResponse(w, http.StatusNotFound, 9, "unknown command", fmt.Sprintf("%q is not a known command", r.URL.Path))
    61  		return
    62  	}
    63  
    64  	if r.Method == http.MethodPost && len(path) == 1 {
    65  		h.newSession(w, r)
    66  		return
    67  	}
    68  
    69  	if len(path) < 2 {
    70  		errorResponse(w, http.StatusMethodNotAllowed, 9, "unknown method", fmt.Sprintf("%s is not a supported method for /session", r.Method))
    71  		return
    72  	}
    73  
    74  	driver := h.driver(path[1])
    75  	if driver == nil {
    76  		errorResponse(w, http.StatusNotFound, 6, "invalid session id", fmt.Sprintf("%q is not an active session", path[1]))
    77  		return
    78  	}
    79  
    80  	if r.Method == http.MethodDelete && len(path) == 2 {
    81  		h.quitSession(path[1], driver, w, r)
    82  		return
    83  	}
    84  
    85  	if len(path) == 3 && path[2] == "file" {
    86  		h.uploader.ServeHTTP(w, r)
    87  		return
    88  	}
    89  
    90  	driver.Forward(w, r)
    91  }
    92  
    93  func (h *Hub) driver(session string) *driver.Driver {
    94  	h.mu.RLock()
    95  	defer h.mu.RUnlock()
    96  	return h.sessions[session]
    97  }
    98  
    99  func (h *Hub) newSession(w http.ResponseWriter, r *http.Request) {
   100  	reqJSON := map[string]interface{}{}
   101  
   102  	if err := json.NewDecoder(r.Body).Decode(&reqJSON); err != nil {
   103  		errorResponse(w, http.StatusBadRequest, 13, "invalid argument", err.Error())
   104  		return
   105  	}
   106  
   107  	caps, err := capabilities.FromNewSessionArgs(reqJSON)
   108  	if err != nil {
   109  		errorResponse(w, http.StatusBadRequest, 13, "invalid argument", err.Error())
   110  		return
   111  	}
   112  
   113  	session, driver, err := h.newSessionFromCaps(r.Context(), caps, w)
   114  	if err != nil {
   115  		errorResponse(w, http.StatusInternalServerError, 33, "session not created", fmt.Sprintf("unable to create session: %v", err))
   116  		log.Printf("Error creating webdriver session: %v", err)
   117  		return
   118  	}
   119  
   120  	h.mu.Lock()
   121  	defer h.mu.Unlock()
   122  	h.sessions[session] = driver
   123  }
   124  
   125  func (h *Hub) newSessionFromCaps(ctx context.Context, caps *capabilities.Capabilities, w http.ResponseWriter) (string, *driver.Driver, error) {
   126  	sessionID := "last"
   127  	if i, ok := caps.AlwaysMatch["google:sessionId"]; ok {
   128  		switch ii := i.(type) {
   129  		case string:
   130  			sessionID = ii
   131  		case float64:
   132  			sessionID = strconv.Itoa(int(ii))
   133  		default:
   134  			return "", nil, fmt.Errorf("google:sessionId %#v is not a string or number", i)
   135  		}
   136  	}
   137  
   138  	caps, err := caps.Resolve(resolver.New(sessionID))
   139  	if err != nil {
   140  		return "", nil, err
   141  	}
   142  
   143  	if wslConfig, ok := caps.AlwaysMatch["google:wslConfig"].(map[string]interface{}); ok {
   144  		d, err := driver.New(ctx, h.localHost, sessionID, wslConfig)
   145  		if err != nil {
   146  			return "", nil, err
   147  		}
   148  
   149  		s, err := d.NewSession(ctx, caps, w)
   150  		if err != nil {
   151  			ctx, cancel := context.WithTimeout(ctx, 3*time.Second)
   152  			defer cancel()
   153  			d.Shutdown(ctx)
   154  			return "", nil, err
   155  		}
   156  
   157  		return s, d, nil
   158  	}
   159  
   160  	for _, fm := range caps.FirstMatch {
   161  		wslConfig, ok := fm["google:wslConfig"].(map[string]interface{})
   162  
   163  		if ok {
   164  			sessionID := "last"
   165  			if i, ok := caps.AlwaysMatch["google:sessionId"]; ok {
   166  				switch ii := i.(type) {
   167  				case string:
   168  					sessionID = ii
   169  				case float64:
   170  					sessionID = strconv.Itoa(int(ii))
   171  				default:
   172  					return "", nil, fmt.Errorf("google:sessionId %#v is not a string or number", i)
   173  				}
   174  			}
   175  
   176  			d, err := driver.New(ctx, h.localHost, sessionID, wslConfig)
   177  			if err != nil {
   178  				continue
   179  			}
   180  
   181  			s, err := d.NewSession(ctx, &capabilities.Capabilities{
   182  				AlwaysMatch: caps.AlwaysMatch,
   183  				FirstMatch:  []map[string]interface{}{fm},
   184  			}, w)
   185  			if err != nil {
   186  				ctx, cancel := context.WithTimeout(ctx, 3*time.Second)
   187  				defer cancel()
   188  				d.Shutdown(ctx)
   189  				continue
   190  			}
   191  
   192  			return s, d, nil
   193  		}
   194  	}
   195  
   196  	return "", nil, errors.New("No first match caps worked")
   197  }
   198  
   199  func (h *Hub) quitSession(session string, driver *driver.Driver, w http.ResponseWriter, r *http.Request) {
   200  	h.mu.Lock()
   201  	defer h.mu.Unlock()
   202  
   203  	driver.Forward(w, r)
   204  
   205  	ctx, cancel := context.WithTimeout(r.Context(), 3*time.Second)
   206  	defer cancel()
   207  	if err := driver.Shutdown(ctx); err != nil {
   208  		log.Printf("Error shutting down driver: %v", err)
   209  	}
   210  
   211  	delete(h.sessions, session)
   212  }
   213  
   214  func errorResponse(w http.ResponseWriter, httpStatus, status int, err, message string) {
   215  	w.Header().Set("Content-Type", "application/json; charset=utf-8")
   216  	httphelper.SetDefaultResponseHeaders(w.Header())
   217  	w.WriteHeader(httpStatus)
   218  
   219  	respJSON := map[string]interface{}{
   220  		"status": status,
   221  		"value": map[string]interface{}{
   222  			"error":   err,
   223  			"message": message,
   224  		},
   225  	}
   226  
   227  	json.NewEncoder(w).Encode(respJSON)
   228  }