github.com/nya3jp/tast@v0.0.0-20230601000426-85c8e4d83a9b/src/go.chromium.org/tast/core/internal/bundle/service.go (about)

     1  // Copyright 2021 The ChromiumOS Authors
     2  // Use of this source code is governed by a BSD-style license that can be
     3  // found in the LICENSE file.
     4  
     5  package bundle
     6  
import (
	"context"
	"fmt"
	"os"
	"os/exec"
	"sort"

	"go.chromium.org/tast/core/errors"
	"go.chromium.org/tast/core/internal/bundle/bundleclient"
	"go.chromium.org/tast/core/internal/logging"
	"go.chromium.org/tast/core/internal/protocol"
	"go.chromium.org/tast/core/internal/testing"
)
    19  
// testServer serves the TestService RPCs (ListEntities, GlobalRuntimeVars,
// RunTests) for this bundle. Any RPC not defined on testServer falls through to
// the embedded protocol.UnimplementedTestServiceServer.
//
// scfg is the bundle's static configuration; its registry supplies the
// fixtures, tests and runtime variables that are reported and run.
// bundleParams holds the parameters the bundle was initialized with; its
// bundle config's primary target, when set, is where recursive ListEntities
// requests are forwarded, and it is handed to runTests/runTestsRecursive.
type testServer struct {
	protocol.UnimplementedTestServiceServer
	scfg         *StaticConfig
	bundleParams *protocol.BundleInitParams
}
    25  
    26  func newTestServer(scfg *StaticConfig, bundleParams *protocol.BundleInitParams) *testServer {
    27  	exec.Command("logger", "New test server is setup in bundle to listen to requests").Run()
    28  	return &testServer{scfg: scfg, bundleParams: bundleParams}
    29  }
    30  
    31  func (s *testServer) ListEntities(ctx context.Context, req *protocol.ListEntitiesRequest) (*protocol.ListEntitiesResponse, error) {
    32  	var entities []*protocol.ResolvedEntity
    33  	// Logging added for b/213616631 to see ListEntities progress on the DUT.
    34  	execName, err := os.Executable()
    35  	if err != nil {
    36  		execName = "bundle"
    37  	}
    38  	logging.Debugf(ctx, "Serving ListEntities Request in %s (recursive flag: %v)", execName, req.GetRecursive())
    39  	exec.Command("logger", fmt.Sprintf("Serving ListEntities Request in %s", execName)).Run()
    40  	if req.GetRecursive() {
    41  		var cl *bundleclient.Client
    42  		if s.bundleParams.GetBundleConfig().GetPrimaryTarget() != nil {
    43  			var err error
    44  			cl, err = bundleclient.New(ctx, s.bundleParams.GetBundleConfig().GetPrimaryTarget(), s.scfg.registry.Name(), &protocol.HandshakeRequest{})
    45  			if err != nil {
    46  				return nil, err
    47  			}
    48  			defer cl.Close(ctx)
    49  		}
    50  
    51  		var err error
    52  		entities, err = listEntitiesRecursive(ctx, s.scfg.registry, req.Features, cl)
    53  		if err != nil {
    54  			return nil, err
    55  		}
    56  	} else {
    57  		entities = listEntities(s.scfg.registry, req.Features)
    58  	}
    59  	// Logging added for b/213616631 to see ListEntities progress on the DUT.
    60  	logging.Debugf(ctx, "Successfully serving ListEntities Request in %s ", execName)
    61  	exec.Command("logger", fmt.Sprintf("Successfully serving ListEntities Request in %s", execName)).Run()
    62  	return &protocol.ListEntitiesResponse{Entities: entities}, nil
    63  }
    64  
    65  func (s *testServer) GlobalRuntimeVars(ctx context.Context, req *protocol.GlobalRuntimeVarsRequest) (*protocol.GlobalRuntimeVarsResponse, error) {
    66  
    67  	vars := s.scfg.registry.AllVars()
    68  
    69  	var runTimeVars []*protocol.GlobalRuntimeVar
    70  	for _, v := range vars {
    71  		runTimeVars = append(runTimeVars, &protocol.GlobalRuntimeVar{Name: v.Name()})
    72  	}
    73  	return &protocol.GlobalRuntimeVarsResponse{Vars: runTimeVars}, nil
    74  }
    75  
    76  func (s *testServer) RunTests(srv protocol.TestService_RunTestsServer) error {
    77  	ctx := srv.Context()
    78  
    79  	initReq, err := srv.Recv()
    80  	if err != nil {
    81  		return errors.Wrap(err, "RunTests: failed to receive messages")
    82  	}
    83  	if _, ok := initReq.GetType().(*protocol.RunTestsRequest_RunTestsInit); !ok {
    84  		return errors.Errorf("RunTests: unexpected initial request message: got %T, want %T", initReq.GetType(), &protocol.RunTestsRequest_RunTestsInit{})
    85  	}
    86  
    87  	if initReq.GetRunTestsInit().GetRecursive() {
    88  		if err := runTestsRecursive(ctx, srv, initReq.GetRunTestsInit().GetRunConfig(), s.scfg, s.bundleParams); err != nil {
    89  			return errors.Wrap(err, "RunTests: failed in run tests recursively")
    90  		}
    91  		return nil
    92  	}
    93  	if err := runTests(ctx, srv, initReq.GetRunTestsInit().GetRunConfig(), s.scfg, s.bundleParams.GetBundleConfig()); err != nil {
    94  		return errors.Wrap(err, "RunTests: failed in run tests")
    95  	}
    96  	return nil
    97  }
    98  
    99  // listEntitiesRecursive lists all the entities this bundle has.
   100  // If cl is non-nil it also lists all the entities in the bundle cl points to.
   101  func listEntitiesRecursive(ctx context.Context, reg *testing.Registry, features *protocol.Features, cl *bundleclient.Client) ([]*protocol.ResolvedEntity, error) {
   102  	entities := listEntities(reg, features)
   103  	if cl == nil {
   104  		return entities, nil
   105  	}
   106  	es, err := cl.TestService().ListEntities(ctx, &protocol.ListEntitiesRequest{
   107  		Features:  features,
   108  		Recursive: true,
   109  	})
   110  	if err != nil {
   111  		return nil, err
   112  	}
   113  
   114  	for _, e := range es.Entities {
   115  		e.Hops++
   116  		entities = append(entities, e)
   117  	}
   118  	return entities, nil
   119  }
   120  
   121  func listEntities(reg *testing.Registry, features *protocol.Features) []*protocol.ResolvedEntity {
   122  	fixtures := reg.AllFixtures()
   123  	starts := buildStartFixtureMap(fixtures)
   124  
   125  	var resolved []*protocol.ResolvedEntity
   126  
   127  	for _, f := range fixtures {
   128  		resolved = append(resolved, &protocol.ResolvedEntity{
   129  			Entity:           f.EntityProto(),
   130  			StartFixtureName: starts[f.Name],
   131  		})
   132  	}
   133  
   134  	for _, t := range reg.AllTests() {
   135  		// If we encounter errors while checking test dependencies,
   136  		// treat the test as not skipped. When we actually try to
   137  		// run the test later, it will fail with errors.
   138  		var skip *protocol.Skip
   139  		if reasons, err := t.Deps().Check(features); err == nil && len(reasons) > 0 {
   140  			skip = &protocol.Skip{Reasons: reasons}
   141  		}
   142  		start, ok := starts[t.Fixture]
   143  		if !ok {
   144  			start = t.Fixture
   145  		}
   146  		resolved = append(resolved, &protocol.ResolvedEntity{
   147  			Entity:           t.EntityProto(),
   148  			Skip:             skip,
   149  			StartFixtureName: start,
   150  		})
   151  	}
   152  	return resolved
   153  }
   154  
   155  func buildStartFixtureMap(fixtures map[string]*testing.FixtureInstance) map[string]string {
   156  	starts := make(map[string]string)
   157  
   158  	// findStart is a recursive function to find a start fixture of f.
   159  	// It fills in results to starts for memoization.
   160  	var findStart func(f *testing.FixtureInstance) string
   161  	findStart = func(f *testing.FixtureInstance) string {
   162  		if start, ok := starts[f.Name]; ok {
   163  			return start // memoize
   164  		}
   165  		var start string
   166  		if parent, ok := fixtures[f.Parent]; ok {
   167  			start = findStart(parent)
   168  		} else {
   169  			start = f.Parent
   170  		}
   171  		starts[f.Name] = start
   172  		return start
   173  	}
   174  
   175  	for _, f := range fixtures {
   176  		findStart(f)
   177  	}
   178  	return starts
   179  }