go.chromium.org/luci@v0.0.0-20240309015107-7cdc2e660f33/cipkg/base/actions/reexec.go (about)

     1  // Copyright 2023 The LUCI Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //	http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package actions
    16  
    17  import (
    18  	"context"
    19  	"crypto"
    20  	"errors"
    21  	"fmt"
    22  	"log"
    23  	"os"
    24  	"runtime"
    25  	"sync"
    26  
    27  	"go.chromium.org/luci/cipkg/core"
    28  	luciproto "go.chromium.org/luci/common/proto"
    29  	"go.chromium.org/luci/common/system/environ"
    30  
    31  	"google.golang.org/protobuf/encoding/protojson"
    32  	"google.golang.org/protobuf/proto"
    33  	"google.golang.org/protobuf/reflect/protoreflect"
    34  	"google.golang.org/protobuf/types/known/anypb"
    35  )
    36  
// Executor is the function type that executes the action described by a
// message of type M. When invoked, it should carry out the action associated
// with msg. `out` is the output path for the artifacts generated by the
// executor. A non-nil error reports that the action failed.
type Executor[M proto.Message] func(ctx context.Context, msg M, out string) error
    41  
// envCipkgExec is the name of the environment variable that marks a process
// as a reexec invocation. ReexecDerivation sets it to "1" in the derivation's
// environment, and interceptWithArgs only dispatches to an executor when it
// can remove this variable from the environment it was given.
const envCipkgExec = "_CIPKG_EXEC_CMD"
    43  
// ReexecRegistry is the registry of executors for actions that are
// implemented in the binary itself and run by re-executing that binary.
//
// By default, NewReexecRegistry registers the following Reexec Executors:
//   - core.ActionURLFetch
//   - core.ActionFilesCopy
//   - core.ActionCIPDExport
//
// In order for a binary to work with reexec actions, you must call the
// .Intercept() function very early in your program's `main()`. This will
// intercept the invocation of this program and divert execution control to the
// registered Executor.
//
// Programs may register additional Executors in this registry using the
// SetExecutor or MustSetExecutor functions from this package.
type ReexecRegistry struct {
	// execs maps the full proto message name of an action spec to the
	// executor registered for it.
	execs map[protoreflect.FullName]Executor[proto.Message]

	// mu is held by SetExecutor for the whole registration and by
	// interceptWithArgs while it sets sealed.
	mu sync.Mutex
	// sealed is set once interception has started; SetExecutor returns
	// ErrReexecRegistrySealed afterwards.
	sealed bool
}
    65  
    66  func NewReexecRegistry() *ReexecRegistry {
    67  	m := &ReexecRegistry{
    68  		execs: make(map[protoreflect.FullName]Executor[proto.Message]),
    69  	}
    70  	MustSetExecutor[*core.ActionURLFetch](m, ActionURLFetchExecutor)
    71  	MustSetExecutor[*core.ActionFilesCopy](m, defaultFilesCopyExecutor.Execute)
    72  	MustSetExecutor[*core.ActionCIPDExport](m, ActionCIPDExportExecutor)
    73  	return m
    74  }
    75  
var (
	// ErrExecutorExisted is returned by SetExecutor when an executor is
	// already registered for the message type.
	ErrExecutorExisted = errors.New("executor for the message type already existed")
	// ErrReexecRegistrySealed is returned by SetExecutor once the registry
	// has been sealed by interception.
	ErrReexecRegistrySealed = errors.New("executor can't be set after Intercept being called")
)
    80  
    81  // SetExecutor set the executor for the action specification M.
    82  // All executors must be set before calling .Intercept(). If there is a executor
    83  // already registed for M, SetExecutor will return ErrExecutorExisted.
    84  func SetExecutor[M proto.Message](r *ReexecRegistry, execFunc Executor[M]) error {
    85  	r.mu.Lock()
    86  	defer r.mu.Unlock()
    87  	if r.sealed {
    88  		return ErrReexecRegistrySealed
    89  	}
    90  
    91  	var msg M
    92  	name := proto.MessageName(msg)
    93  	if _, ok := r.execs[name]; ok {
    94  		return ErrExecutorExisted
    95  	}
    96  	r.execs[name] = func(ctx context.Context, msg proto.Message, out string) error {
    97  		return execFunc(ctx, msg.(M), out)
    98  	}
    99  	return nil
   100  }
   101  
   102  // MustSetExecutor set the executor for the action specification M similar to
   103  // SetExecutor, but will panic if any error happened.
   104  func MustSetExecutor[M proto.Message](r *ReexecRegistry, execFunc Executor[M]) {
   105  	if err := SetExecutor[M](r, execFunc); err != nil {
   106  		panic(err)
   107  	}
   108  }
   109  
   110  // Intercept executes the registed executor and exit if _CIPKG_EXEC_CMD is
   111  // found. This is REQUIRED for reexec to function properly and need to be
   112  // executed after init() because embed fs or other resources may be
   113  // registered in init().
   114  // Any application using the framework must call the .Intercept() function very
   115  // early in your program's `main()`. This will intercept the invocation of this
   116  // program and divert execution control to the registered Executor.
   117  // On windows, environment variable NoDefaultCurrentDirectoryInExePath will
   118  // always be set to prevent searching binaries from current workding directory
   119  // by default, which because of its relative nature, is forbidden by golang.
   120  func (r *ReexecRegistry) Intercept(ctx context.Context) {
   121  	if runtime.GOOS == "windows" {
   122  		if err := os.Setenv("NoDefaultCurrentDirectoryInExePath", "1"); err != nil {
   123  			panic(fmt.Sprintf("failed to set NoDefaultCurrentDirectoryInExePath on Windows: %s", err))
   124  		}
   125  	}
   126  	r.interceptWithArgs(ctx, environ.System(), os.Args, os.Exit)
   127  }
   128  
   129  func (r *ReexecRegistry) interceptWithArgs(ctx context.Context, env environ.Env, args []string, exit func(int)) {
   130  	r.mu.Lock()
   131  	if !r.sealed {
   132  		r.sealed = true
   133  	}
   134  	r.mu.Unlock()
   135  	if !env.Remove(envCipkgExec) {
   136  		return
   137  	}
   138  
   139  	if len(args) < 2 {
   140  		panic(fmt.Sprintf("usage: cipkg-exec <proto>: insufficient args: %s", args))
   141  	}
   142  
   143  	var any anypb.Any
   144  	if err := protojson.Unmarshal([]byte(args[1]), &any); err != nil {
   145  		panic(fmt.Sprintf("failed to unmarshal anypb: %s, %s", err, args))
   146  	}
   147  	msg, err := any.UnmarshalNew()
   148  	if err != nil {
   149  		panic(fmt.Sprintf("failed to unmarshal proto from any: %s, %s", err, args))
   150  	}
   151  	f := r.execs[proto.MessageName(msg)]
   152  	if f == nil {
   153  		panic(fmt.Sprintf("unknown cipkg-exec command: %s", args))
   154  	}
   155  
   156  	if err := f(env.SetInCtx(ctx), msg, env.Get("out")); err != nil {
   157  		log.Fatalln(err)
   158  	}
   159  
   160  	exit(0)
   161  }
   162  
// reexecVersion is the global reexec version. It is a prefix of every
// FixedOutput string produced by sha256String, so changing it affects the
// FixedOutput of all derivations generated from ReexecDerivation(...).
const reexecVersion = "v1"
   166  
   167  // ReexecDerivation returns a derivation for re-executing the binary. It sets
   168  // the FixedOutput using hash generated from action spec.
   169  func ReexecDerivation(m proto.Message, hostEnv bool) (*core.Derivation, error) {
   170  	self, err := os.Executable()
   171  	if err != nil {
   172  		return nil, err
   173  	}
   174  
   175  	m, err = anypb.New(m)
   176  	if err != nil {
   177  		return nil, err
   178  	}
   179  	b, err := protojson.Marshal(m)
   180  	if err != nil {
   181  		return nil, err
   182  	}
   183  
   184  	fixed, err := sha256String(m)
   185  	if err != nil {
   186  		return nil, err
   187  	}
   188  
   189  	env := environ.New(nil)
   190  	if hostEnv {
   191  		env = environ.System()
   192  	}
   193  	env.Set(envCipkgExec, "1")
   194  
   195  	return &core.Derivation{
   196  		Args:        []string{self, string(b)},
   197  		Env:         env.Sorted(),
   198  		FixedOutput: fixed,
   199  	}, nil
   200  }
   201  
   202  func sha256String(m proto.Message) (string, error) {
   203  	const algo = crypto.SHA256
   204  	h := algo.New()
   205  	if err := luciproto.StableHash(h, m); err != nil {
   206  		return "", err
   207  	}
   208  	return fmt.Sprintf("%s%s:%x", reexecVersion, algo, h.Sum(nil)), nil
   209  }