github.com/nya3jp/tast@v0.0.0-20230601000426-85c8e4d83a9b/src/go.chromium.org/tast/core/internal/faketlw/faketlw.go (about)

     1  // Copyright 2020 The ChromiumOS Authors
     2  // Use of this source code is governed by a BSD-style license that can be
     3  // found in the LICENSE file.
     4  
     5  // Package faketlw provides a fake implementation of the TLW service.
     6  package faketlw
     7  
     8  import (
     9  	"context"
    10  	"fmt"
    11  	"net"
    12  	"net/http"
    13  	"net/http/httptest"
    14  	"net/url"
    15  	"sync"
    16  	"testing"
    17  
    18  	"github.com/golang/protobuf/ptypes"
    19  	"github.com/golang/protobuf/ptypes/empty"
    20  	"go.chromium.org/chromiumos/config/go/api/test/tls"
    21  	"go.chromium.org/chromiumos/config/go/api/test/tls/dependencies/longrunning"
    22  	"google.golang.org/grpc"
    23  	"google.golang.org/grpc/codes"
    24  	"google.golang.org/grpc/status"
    25  )
    26  
// NamePort represents a simple name/port pair.
// NOTE(review): nothing in the visible part of this file references this
// type — confirm it still has users before changing it.
type NamePort struct {
	// Name is the name half of the pair.
	Name string
	// Port is the port half of the pair.
	Port int32
}
    32  
// wiringServerConfig holds the WiringServer settings that WiringServerOption
// functions can customize.
type wiringServerConfig struct {
	// cacheFileMap maps a source URL to the content of the file at that URL.
	// CacheForDut only accepts URLs that are keys of this map.
	cacheFileMap map[string][]byte
	// dutName is the DUT name that CacheForDut requests must carry.
	dutName      string
}
    37  
// WiringServerOption is an option passed to NewWiringServer to customize WiringServer.
// An option is a function that modifies the wiringServerConfig it is given.
type WiringServerOption func(cfg *wiringServerConfig)
    40  
    41  // WithCacheFileMap returns an option that sets the files to be fetched by
    42  // CacheForDUT requests.
    43  func WithCacheFileMap(m map[string][]byte) WiringServerOption {
    44  	return func(cfg *wiringServerConfig) {
    45  		cfg.cacheFileMap = m
    46  	}
    47  }
    48  
    49  // WithDUTName returns an option that sets the expected DUT name to be requested by
    50  // CacheForDut requests.
    51  func WithDUTName(n string) WiringServerOption {
    52  	return func(cfg *wiringServerConfig) {
    53  		cfg.dutName = n
    54  	}
    55  }
    56  
    57  type operation struct {
    58  	srcURL string
    59  }
    60  
    61  type operationsMap map[string]operation
    62  
// WiringServer is a fake implementation of tls.WiringServer and of the
// longrunning operations service, provided for exercising CacheForDut.
// It embeds the Unimplemented* servers of both services and defines the
// methods found in this file on top of them.
type WiringServer struct {
	tls.UnimplementedWiringServer
	longrunning.UnimplementedOperationsServer
	// cfg holds the settings collected from the WiringServerOption values.
	cfg         wiringServerConfig
	// cacheServer is the HTTP server behind the cache URLs handed to clients.
	cacheServer *cacheServer

	// operationsMu guards operations.
	operationsMu sync.RWMutex
	// operations holds the operations started by CacheForDut, keyed by
	// operation name.
	operations   operationsMap
}
    74  
    75  var _ tls.WiringServer = &WiringServer{}
    76  
    77  // NewWiringServer constructs a new WiringServer from given options.
    78  // The caller is responsible for calling Shutdown() of the returned object.
    79  func NewWiringServer(opts ...WiringServerOption) *WiringServer {
    80  	var cfg wiringServerConfig
    81  	for _, opt := range opts {
    82  		opt(&cfg)
    83  	}
    84  
    85  	return &WiringServer{
    86  		cfg:         cfg,
    87  		operations:  map[string]operation{},
    88  		cacheServer: newCacheServer(),
    89  	}
    90  }
    91  
    92  // Shutdown shuts down the serer.
    93  func (s *WiringServer) Shutdown() {
    94  	s.cacheServer.hs.Close()
    95  }
    96  
    97  // OpenDutPort implements tls.WiringServer.OpenDutPort.
    98  func (s *WiringServer) OpenDutPort(ctx context.Context, req *tls.OpenDutPortRequest) (*tls.OpenDutPortResponse, error) {
    99  	return nil, status.Error(codes.Unimplemented, "not implemented")
   100  }
   101  
   102  // CacheForDut implements tls WiringServer.CacheForDUT
   103  func (s *WiringServer) CacheForDut(ctx context.Context, req *tls.CacheForDutRequest) (*longrunning.Operation, error) {
   104  	if req.DutName != s.cfg.dutName {
   105  		return nil, status.Errorf(codes.InvalidArgument, "wrong DUT name: got %q, want %q", req.DutName, s.cfg.dutName)
   106  	}
   107  	_, ok := s.cfg.cacheFileMap[req.Url]
   108  	if !ok {
   109  		return nil, status.Errorf(codes.NotFound, "not found in cache file map: %s", req.Url)
   110  	}
   111  
   112  	operationName := fmt.Sprintf("CacheForDUTOperation_%s", req.Url)
   113  	s.beginOperation(operationName, req.Url)
   114  	op := longrunning.Operation{
   115  		Name: operationName,
   116  		// Pretend operation is not finished yet in order to test the code path for
   117  		// waiting the operation to finish.
   118  		Done: false,
   119  	}
   120  
   121  	return &op, nil
   122  }
   123  
   124  func (s *WiringServer) beginOperation(name, srcURL string) {
   125  	s.operationsMu.Lock()
   126  	defer s.operationsMu.Unlock()
   127  	s.operations[name] = operation{
   128  		srcURL: srcURL,
   129  	}
   130  }
   131  
   132  func (s *WiringServer) operation(name string) (oper *operation, exists bool) {
   133  	s.operationsMu.RLock()
   134  	defer s.operationsMu.RUnlock()
   135  	o, ok := s.operations[name]
   136  	return &o, ok
   137  }
   138  
   139  func (s *WiringServer) fillCache(srcURL string) (string, error) {
   140  	cacheURL := fmt.Sprintf("%s/?s=%s",
   141  		s.cacheServer.hs.URL, url.QueryEscape(srcURL))
   142  	k, err := cacheKey(cacheURL)
   143  	if err != nil {
   144  		return "", status.Errorf(codes.InvalidArgument, "failed to generate cache key for %s: %s", cacheURL, err)
   145  	}
   146  	content, ok := s.cfg.cacheFileMap[srcURL]
   147  	if !ok {
   148  		// CacheForDUT examines existence of the resource first.
   149  		// If the file seen missing here, it indicates the server resource was lost after it.
   150  		// (This should not happen with current implementation of the this fake TLW.)
   151  		return "", status.Errorf(codes.DataLoss, "requrested URL does not exist: %s", srcURL)
   152  	}
   153  	s.cacheServer.fillCache(k, content)
   154  	return cacheURL, nil
   155  }
   156  
   157  // GetOperation implements longrunning.GetOperation.
   158  func (s *WiringServer) GetOperation(ctx context.Context, req *longrunning.GetOperationRequest) (*longrunning.Operation, error) {
   159  	return s.finishOperation(req.Name)
   160  }
   161  
   162  // WaitOperation implements longrunning.WaitOperation.
   163  func (s *WiringServer) WaitOperation(ctx context.Context, req *longrunning.WaitOperationRequest) (*longrunning.Operation, error) {
   164  	return s.finishOperation(req.Name)
   165  }
   166  
   167  func (s *WiringServer) finishOperation(name string) (*longrunning.Operation, error) {
   168  	o, ok := s.operation(name)
   169  	if !ok {
   170  		return nil, status.Errorf(codes.InvalidArgument, "invalid argument: %s", name)
   171  	}
   172  	cacheURL, err := s.fillCache(o.srcURL)
   173  	if err != nil {
   174  		return nil, err
   175  	}
   176  	m, err := ptypes.MarshalAny(&tls.CacheForDutResponse{Url: cacheURL})
   177  	if err != nil {
   178  		return nil, status.Errorf(codes.Internal, "failed to marshal data: %s", err)
   179  	}
   180  	return &longrunning.Operation{
   181  		Done: true,
   182  		Name: name,
   183  		Result: &longrunning.Operation_Response{
   184  			Response: m,
   185  		},
   186  	}, nil
   187  }
   188  
   189  // CancelOperation implements longrunning.CancelOperation.
   190  func (s *WiringServer) CancelOperation(ctx context.Context, req *longrunning.CancelOperationRequest) (*empty.Empty, error) {
   191  	return nil, status.Error(codes.Unimplemented, "not implemented")
   192  }
   193  
   194  // DeleteOperation implements longrunning.CancelOperation.
   195  func (s *WiringServer) DeleteOperation(ctx context.Context, req *longrunning.DeleteOperationRequest) (*empty.Empty, error) {
   196  	return nil, status.Error(codes.Unimplemented, "not implemented")
   197  }
   198  
   199  // ListOperations implements longrunning.ListOperations.
   200  func (s *WiringServer) ListOperations(ctx context.Context, req *longrunning.ListOperationsRequest) (*longrunning.ListOperationsResponse, error) {
   201  	return nil, status.Error(codes.Unimplemented, "not implemented")
   202  }
   203  
   204  // cacheKey generates the internal key used for matching a URL generated by CacheForDUT,
   205  // and a one that is passed to the HTTP handler.
   206  func cacheKey(cacheURL string) (string, error) {
   207  	u, err := url.Parse(cacheURL)
   208  	if err != nil {
   209  		return "", err
   210  	}
   211  
   212  	// The query parameter name should be kept consistent with WiringServer.fillCache().
   213  	q := u.Query()["s"]
   214  	if len(q) != 1 {
   215  		return "", fmt.Errorf("failed to find query in cache URL %s", cacheURL)
   216  	}
   217  	return q[0], nil
   218  }
   219  
   220  // StartWiringServer is a convenient method for unit tests which starts a gRPC
   221  // server serving WiringServer in the background. It also starts an HTTP server
   222  // for serving cached files by CacheForDUT.
   223  // Callers are responsible for stopping the server by stopFunc().
   224  func StartWiringServer(t *testing.T, opts ...WiringServerOption) (stopFunc func(), addr string) {
   225  	ws := NewWiringServer(opts...)
   226  
   227  	srv := grpc.NewServer()
   228  	tls.RegisterWiringServer(srv, ws)
   229  	longrunning.RegisterOperationsServer(srv, ws)
   230  
   231  	lis, err := net.Listen("tcp", "localhost:0")
   232  	if err != nil {
   233  		t.Fatal("Failed to listen: ", err)
   234  	}
   235  
   236  	go srv.Serve(lis)
   237  
   238  	return func() {
   239  		ws.Shutdown()
   240  		srv.Stop()
   241  	}, lis.Addr().String()
   242  }
   243  
// cacheServer is an HTTP handler, run on its own httptest.Server, that
// serves the content previously stored through fillCache.
type cacheServer struct {
	// hs is the test HTTP server that dispatches requests to this cacheServer.
	hs *httptest.Server

	// cachedFilesMu guards cachedFiles.
	cachedFilesMu sync.RWMutex
	// cachedFiles maps a cache key (see cacheKey) to the content to serve.
	cachedFiles   map[string][]byte
}
   250  
   251  func newCacheServer() *cacheServer {
   252  	c := cacheServer{
   253  		cachedFiles: map[string][]byte{},
   254  	}
   255  	c.hs = httptest.NewServer(&c)
   256  	return &c
   257  }
   258  
   259  func (c *cacheServer) fillCache(key string, content []byte) {
   260  	c.cachedFilesMu.Lock()
   261  	defer c.cachedFilesMu.Unlock()
   262  	c.cachedFiles[key] = content
   263  }
   264  
   265  func (c *cacheServer) cachedFile(key string) (content []byte, exists bool) {
   266  	c.cachedFilesMu.RLock()
   267  	defer c.cachedFilesMu.RUnlock()
   268  	content, ok := c.cachedFiles[key]
   269  	return content, ok
   270  }
   271  
   272  func (c *cacheServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
   273  	k, err := cacheKey(r.URL.String())
   274  	if err != nil {
   275  		http.Error(w, err.Error(), http.StatusBadRequest)
   276  		return
   277  	}
   278  	content, ok := c.cachedFile(k)
   279  	if !ok {
   280  		http.NotFound(w, r)
   281  		return
   282  	}
   283  	w.Write(content)
   284  }