github.com/nicocha30/gvisor-ligolo@v0.0.0-20230726075806-989fa2c0a413/pkg/state/state.go (about)

     1  // Copyright 2018 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  // Package state provides functionality related to saving and loading object
    16  // graphs.  For most types, it provides a set of default saving / loading logic
    17  // that will be invoked automatically if custom logic is not defined.
    18  //
    19  //	Kind             Support
    20  //	----             -------
    21  //	Bool             default
    22  //	Int              default
    23  //	Int8             default
    24  //	Int16            default
    25  //	Int32            default
    26  //	Int64            default
    27  //	Uint             default
    28  //	Uint8            default
    29  //	Uint16           default
    30  //	Uint32           default
    31  //	Uint64           default
    32  //	Float32          default
    33  //	Float64          default
    34  //	Complex64        default
    35  //	Complex128       default
    36  //	Array            default
    37  //	Chan             custom
    38  //	Func             custom
    39  //	Interface        default
    40  //	Map              default
    41  //	Ptr              default
    42  //	Slice            default
    43  //	String           default
    44  //	Struct           custom (*) Unless zero-sized.
    45  //	UnsafePointer    custom
    46  //
    47  // See README.md for an overview of how encoding and decoding works.
    48  package state
    49  
    50  import (
    51  	"context"
    52  	"fmt"
    53  	"reflect"
    54  	"runtime"
    55  
    56  	"github.com/nicocha30/gvisor-ligolo/pkg/state/wire"
    57  )
    58  
    59  // objectID is a unique identifier assigned to each object to be serialized.
    60  // Each instance of an object is considered separately, i.e. if there are two
    61  // objects of the same type in the object graph being serialized, they'll be
    62  // assigned unique objectIDs.
    63  type objectID uint32
    64  
    65  // typeID is the identifier for a type. Types are serialized and tracked
    66  // alongside objects in order to avoid the overhead of encoding field names in
    67  // all objects.
    68  type typeID uint32
    69  
    70  // ErrState is returned when an error is encountered during encode/decode.
    71  type ErrState struct {
    72  	// err is the underlying error.
    73  	err error
    74  
    75  	// trace is the stack trace.
    76  	trace string
    77  }
    78  
    79  // Error returns a sensible description of the state error.
    80  func (e *ErrState) Error() string {
    81  	return fmt.Sprintf("%v:\n%s", e.err, e.trace)
    82  }
    83  
    84  // Unwrap implements standard unwrapping.
    85  func (e *ErrState) Unwrap() error {
    86  	return e.err
    87  }
    88  
    89  // Save saves the given object state.
    90  func Save(ctx context.Context, w wire.Writer, rootPtr any) (Stats, error) {
    91  	// Create the encoding state.
    92  	es := encodeState{
    93  		ctx:            ctx,
    94  		w:              w,
    95  		types:          makeTypeEncodeDatabase(),
    96  		zeroValues:     make(map[reflect.Type]*objectEncodeState),
    97  		pending:        make(map[objectID]*objectEncodeState),
    98  		encodedStructs: make(map[reflect.Value]*wire.Struct),
    99  	}
   100  
   101  	// Perform the encoding.
   102  	err := safely(func() {
   103  		es.Save(reflect.ValueOf(rootPtr).Elem())
   104  	})
   105  	return es.stats, err
   106  }
   107  
   108  // Load loads a checkpoint.
   109  func Load(ctx context.Context, r wire.Reader, rootPtr any) (Stats, error) {
   110  	// Create the decoding state.
   111  	ds := decodeState{
   112  		ctx:      ctx,
   113  		r:        r,
   114  		types:    makeTypeDecodeDatabase(),
   115  		deferred: make(map[objectID]wire.Object),
   116  	}
   117  
   118  	// Attempt our decode.
   119  	err := safely(func() {
   120  		ds.Load(reflect.ValueOf(rootPtr).Elem())
   121  	})
   122  	return ds.stats, err
   123  }
   124  
   125  // Sink is used for Type.StateSave.
   126  type Sink struct {
   127  	internal objectEncoder
   128  }
   129  
   130  // Save adds the given object to the map.
   131  //
   132  // You should pass always pointers to the object you are saving. For example:
   133  //
   134  //	type X struct {
   135  //		A int
   136  //		B *int
   137  //	}
   138  //
   139  //	func (x *X) StateTypeInfo(m Sink) state.TypeInfo {
   140  //		return state.TypeInfo{
   141  //			Name:   "pkg.X",
   142  //			Fields: []string{
   143  //				"A",
   144  //				"B",
   145  //			},
   146  //		}
   147  //	}
   148  //
   149  //	func (x *X) StateSave(m Sink) {
   150  //		m.Save(0, &x.A) // Field is A.
   151  //		m.Save(1, &x.B) // Field is B.
   152  //	}
   153  //
   154  //	func (x *X) StateLoad(m Source) {
   155  //		m.Load(0, &x.A) // Field is A.
   156  //		m.Load(1, &x.B) // Field is B.
   157  //	}
   158  func (s Sink) Save(slot int, objPtr any) {
   159  	s.internal.save(slot, reflect.ValueOf(objPtr).Elem())
   160  }
   161  
   162  // SaveValue adds the given object value to the map.
   163  //
   164  // This should be used for values where pointers are not available, or casts
   165  // are required during Save/Load.
   166  //
   167  // For example, if we want to cast external package type P.Foo to int64:
   168  //
   169  //	func (x *X) StateSave(m Sink) {
   170  //		m.SaveValue(0, "A", int64(x.A))
   171  //	}
   172  //
   173  //	func (x *X) StateLoad(m Source) {
   174  //		m.LoadValue(0, new(int64), func(x any) {
   175  //			x.A = P.Foo(x.(int64))
   176  //		})
   177  //	}
   178  func (s Sink) SaveValue(slot int, obj any) {
   179  	s.internal.save(slot, reflect.ValueOf(obj))
   180  }
   181  
   182  // Context returns the context object provided at save time.
   183  func (s Sink) Context() context.Context {
   184  	return s.internal.es.ctx
   185  }
   186  
   187  // Type is an interface that must be implemented by Struct objects. This allows
   188  // these objects to be serialized while minimizing runtime reflection required.
   189  //
   190  // All these methods can be automatically generated by the go_statify tool.
   191  type Type interface {
   192  	// StateTypeName returns the type's name.
   193  	//
   194  	// This is used for matching type information during encoding and
   195  	// decoding, as well as dynamic interface dispatch. This should be
   196  	// globally unique.
   197  	StateTypeName() string
   198  
   199  	// StateFields returns information about the type.
   200  	//
   201  	// Fields is the set of fields for the object. Calls to Sink.Save and
   202  	// Source.Load must be made in-order with respect to these fields.
   203  	//
   204  	// This will be called at most once per serialization.
   205  	StateFields() []string
   206  }
   207  
   208  // SaverLoader must be implemented by struct types.
   209  type SaverLoader interface {
   210  	// StateSave saves the state of the object to the given Map.
   211  	StateSave(Sink)
   212  
   213  	// StateLoad loads the state of the object.
   214  	StateLoad(Source)
   215  }
   216  
   217  // Source is used for Type.StateLoad.
   218  type Source struct {
   219  	internal objectDecoder
   220  }
   221  
   222  // Load loads the given object passed as a pointer..
   223  //
   224  // See Sink.Save for an example.
   225  func (s Source) Load(slot int, objPtr any) {
   226  	s.internal.load(slot, reflect.ValueOf(objPtr), false, nil)
   227  }
   228  
   229  // LoadWait loads the given objects from the map, and marks it as requiring all
   230  // AfterLoad executions to complete prior to running this object's AfterLoad.
   231  //
   232  // See Sink.Save for an example.
   233  func (s Source) LoadWait(slot int, objPtr any) {
   234  	s.internal.load(slot, reflect.ValueOf(objPtr), true, nil)
   235  }
   236  
   237  // LoadValue loads the given object value from the map.
   238  //
   239  // See Sink.SaveValue for an example.
   240  func (s Source) LoadValue(slot int, objPtr any, fn func(any)) {
   241  	o := reflect.ValueOf(objPtr)
   242  	s.internal.load(slot, o, true, func() { fn(o.Elem().Interface()) })
   243  }
   244  
   245  // AfterLoad schedules a function execution when all objects have been
   246  // allocated and their automated loading and customized load logic have been
   247  // executed. fn will not be executed until all of current object's
   248  // dependencies' AfterLoad() logic, if exist, have been executed.
   249  func (s Source) AfterLoad(fn func()) {
   250  	s.internal.afterLoad(fn)
   251  }
   252  
   253  // Context returns the context object provided at load time.
   254  func (s Source) Context() context.Context {
   255  	return s.internal.ds.ctx
   256  }
   257  
   258  // IsZeroValue checks if the given value is the zero value.
   259  //
   260  // This function is used by the stateify tool.
   261  func IsZeroValue(val any) bool {
   262  	return val == nil || reflect.ValueOf(val).Elem().IsZero()
   263  }
   264  
   265  // Failf is a wrapper around panic that should be used to generate errors that
   266  // can be caught during saving and loading.
   267  func Failf(fmtStr string, v ...any) {
   268  	panic(fmt.Errorf(fmtStr, v...))
   269  }
   270  
   271  // safely executes the given function, catching a panic and unpacking as an
   272  // error.
   273  //
   274  // The error flow through the state package uses panic and recover. There are
   275  // two important reasons for this:
   276  //
   277  // 1) Many of the reflection methods will already panic with invalid data or
   278  // violated assumptions. We would want to recover anyways here.
   279  //
   280  // 2) It allows us to eliminate boilerplate within Save() and Load() functions.
   281  // In nearly all cases, when the low-level serialization functions fail, you
   282  // will want the checkpoint to fail anyways. Plumbing errors through every
   283  // method doesn't add a lot of value. If there are specific error conditions
   284  // that you'd like to handle, you should add appropriate functionality to
   285  // objects themselves prior to calling Save() and Load().
   286  func safely(fn func()) (err error) {
   287  	defer func() {
   288  		if r := recover(); r != nil {
   289  			if es, ok := r.(*ErrState); ok {
   290  				err = es // Propagate.
   291  				return
   292  			}
   293  
   294  			// Build a new state error.
   295  			es := new(ErrState)
   296  			if e, ok := r.(error); ok {
   297  				es.err = e
   298  			} else {
   299  				es.err = fmt.Errorf("%v", r)
   300  			}
   301  
   302  			// Make a stack. We don't know how big it will be ahead
   303  			// of time, but want to make sure we get the whole
   304  			// thing. So we just do a stupid brute force approach.
   305  			var stack []byte
   306  			for sz := 1024; ; sz *= 2 {
   307  				stack = make([]byte, sz)
   308  				n := runtime.Stack(stack, false)
   309  				if n < sz {
   310  					es.trace = string(stack[:n])
   311  					break
   312  				}
   313  			}
   314  
   315  			// Set the error.
   316  			err = es
   317  		}
   318  	}()
   319  
   320  	// Execute the function.
   321  	fn()
   322  	return nil
   323  }