github.com/grailbio/bigslice@v0.0.0-20230519005545-30c4c12152ad/exec/invocation.go (about)

     1  // Copyright 2022 GRAIL, Inc. All rights reserved.
     2  // Use of this source code is governed by the Apache 2.0
     3  // license that can be found in the LICENSE file.
     4  
     5  package exec
     6  
     7  import (
     8  	"bytes"
     9  	"encoding/gob"
    10  	"fmt"
    11  	"reflect"
    12  
    13  	"github.com/grailbio/bigslice"
    14  )
    15  
// execInvocation is a bigslice.Invocation augmented with the fields needed to
// execute it.  It embeds bigslice.Invocation, so the invocation's fields and
// methods are promoted.  Serialization is customized by its GobEncode and
// GobDecode methods.
type execInvocation struct {
	bigslice.Invocation
	// Env is the compilation environment.  This is saved so that workers
	// compile slices with the same environment as the driver.
	Env CompileEnv
}
    24  
// invocationRef is a serializable reference to an invocation, suitable for
// use across process boundaries.  *Result arguments are converted to
// invocationRefs before transit; receivers then convert them back to *Result
// values in their own address space.
//
// NOTE(review): Index is presumably the Index of the referenced invocation;
// confirm against the code that constructs invocationRefs.
type invocationRef struct{ Index uint64 }
    30  
    31  func makeExecInvocation(inv bigslice.Invocation) execInvocation {
    32  	return execInvocation{
    33  		Invocation: inv,
    34  		Env:        makeCompileEnv(),
    35  	}
    36  }
    37  
    38  // GobEncode implements gob.GobEncoder.  The implementation handles the
    39  // encoding of the arbitrary interface{} argument types without registration of
    40  // those types using gob.Register, as we record the argument types in the call
    41  // to bigslice.Func.
    42  func (inv execInvocation) GobEncode() ([]byte, error) {
    43  	var (
    44  		b   bytes.Buffer
    45  		enc = gob.NewEncoder(&b)
    46  	)
    47  	for _, field := range inv.directEncodedFields() {
    48  		if err := enc.Encode(field.ptr); err != nil {
    49  			return nil, fmt.Errorf("encoding %s: %v", field.name, err)
    50  		}
    51  	}
    52  	fv := bigslice.FuncByIndex(inv.Func)
    53  	for i, arg := range inv.Args {
    54  		typ := fv.In(i)
    55  		if typ.Kind() == reflect.Interface {
    56  			// Pass the address of arg so Encode sees (and hence sends) a value
    57  			// of interface type.  If arg is of a concrete type and we passed
    58  			// it directly, it would see the concrete type instead.  See the
    59  			// blog post "The Laws of Reflection" for background:
    60  			// https://go.dev/blog/laws-of-reflection.
    61  			if err := enc.Encode(&arg); err != nil {
    62  				return nil, fmt.Errorf("encoding arg %d of type %v: %v", i, typ, err)
    63  			}
    64  			continue
    65  		}
    66  		if err := enc.Encode(arg); err != nil {
    67  			return nil, fmt.Errorf("encoding arg %d of type %v: %v", i, typ, err)
    68  		}
    69  	}
    70  	return b.Bytes(), nil
    71  }
    72  
// Reflection types used by (*execInvocation).GobDecode to choose the value
// each argument is decoded into: typResultPtr is the type *Result,
// typEmptyInterface is the type interface{} (taken as the element type of
// *interface{}), and typInvocationRef is the type invocationRef.
var (
	typResultPtr      = reflect.TypeOf((*Result)(nil))
	typEmptyInterface = reflect.TypeOf((*interface{})(nil)).Elem()
	typInvocationRef  = reflect.TypeOf(invocationRef{})
)
    78  
    79  // GobDecode implements gob.GobDecoder.  The implementation handles the
    80  // decoding of the arbitrary interface{} argument types without registration of
    81  // those types using gob.Register, as we record the argument types in the call
    82  // to bigslice.Func.
    83  func (inv *execInvocation) GobDecode(p []byte) error {
    84  	var (
    85  		b   = bytes.NewBuffer(p)
    86  		dec = gob.NewDecoder(b)
    87  	)
    88  	for _, field := range inv.directEncodedFields() {
    89  		if err := dec.Decode(field.ptr); err != nil {
    90  			return fmt.Errorf("decoding %s: %v", field.name, err)
    91  		}
    92  	}
    93  	fv := bigslice.FuncByIndex(inv.Func)
    94  	inv.Args = make([]interface{}, fv.NumIn())
    95  	for i := range inv.Args {
    96  		typ := fv.In(i)
    97  		var v reflect.Value
    98  		switch {
    99  		case typ == typResultPtr:
   100  			// *Result arguments are replaced with invocationRefs for transit
   101  			// across process boundaries.  See (*bigmachine.Executor).compile.
   102  			v = reflect.New(typInvocationRef)
   103  		case typ.Kind() == reflect.Interface:
   104  			// A *Result can also be passed as an interface that invocationRef
   105  			// does not implement, like Slice, so we decode into an empty
   106  			// interface.  The invocationRef is replaced back with a *Result
   107  			// (in the new address space) elsewhere.  Arguments are typechecked
   108  			// later, so we can safely decode into an empty interface in all
   109  			// cases here.
   110  			v = reflect.New(typEmptyInterface)
   111  		default:
   112  			v = reflect.New(typ)
   113  		}
   114  		if err := dec.DecodeValue(v); err != nil {
   115  			return fmt.Errorf("decoding arg %d of type %v: %v", i, typ, err)
   116  		}
   117  		inv.Args[i] = v.Elem().Interface()
   118  	}
   119  	return nil
   120  }
   121  
// field describes one field of an *execInvocation that can be gob-encoded and
// gob-decoded directly, without the special handling that the arguments
// require.
type field struct {
	// name is the name of the field.  We use it in error messages.
	name string
	// ptr is the address of the field.  We use the address so that we can
	// decode into it; the same pointer is passed to the encoder when
	// encoding.
	ptr interface{}
}
   131  
   132  // directEncodedFields returns the fields that we can gob-encode/decode
   133  // directly in our custom encoding/decoding.
   134  func (inv *execInvocation) directEncodedFields() []field {
   135  	return []field{
   136  		{"Index", &inv.Index},
   137  		{"Func", &inv.Func},
   138  		{"Exclusive", &inv.Exclusive},
   139  		{"Location", &inv.Location},
   140  		{"Env", &inv.Env},
   141  	}
   142  }