github.com/sagernet/gvisor@v0.0.0-20240428053021-e691de28565f/pkg/state/state.go (about)

     1  // Copyright 2018 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  // Package state provides functionality related to saving and loading object
    16  // graphs.  For most types, it provides a set of default saving / loading logic
    17  // that will be invoked automatically if custom logic is not defined.
    18  //
    19  //	Kind             Support
    20  //	----             -------
    21  //	Bool             default
    22  //	Int              default
    23  //	Int8             default
    24  //	Int16            default
    25  //	Int32            default
    26  //	Int64            default
    27  //	Uint             default
    28  //	Uint8            default
    29  //	Uint16           default
    30  //	Uint32           default
    31  //	Uint64           default
    32  //	Float32          default
    33  //	Float64          default
    34  //	Complex64        default
    35  //	Complex128       default
    36  //	Array            default
    37  //	Chan             custom
    38  //	Func             custom
    39  //	Interface        default
    40  //	Map              default
    41  //	Ptr              default
    42  //	Slice            default
    43  //	String           default
    44  //	Struct           custom (*) Unless zero-sized.
    45  //	UnsafePointer    custom
    46  //
    47  // See README.md for an overview of how encoding and decoding works.
    48  package state
    49  
    50  import (
    51  	"context"
    52  	"fmt"
    53  	"io"
    54  	"reflect"
    55  	"runtime"
    56  
    57  	"github.com/sagernet/gvisor/pkg/state/wire"
    58  )
    59  
    60  // objectID is a unique identifier assigned to each object to be serialized.
    61  // Each instance of an object is considered separately, i.e. if there are two
    62  // objects of the same type in the object graph being serialized, they'll be
    63  // assigned unique objectIDs.
    64  type objectID uint32
    65  
    66  // typeID is the identifier for a type. Types are serialized and tracked
    67  // alongside objects in order to avoid the overhead of encoding field names in
    68  // all objects.
    69  type typeID uint32
    70  
    71  // ErrState is returned when an error is encountered during encode/decode.
    72  type ErrState struct {
    73  	// err is the underlying error.
    74  	err error
    75  
    76  	// trace is the stack trace.
    77  	trace string
    78  }
    79  
    80  // Error returns a sensible description of the state error.
    81  func (e *ErrState) Error() string {
    82  	return fmt.Sprintf("%v:\n%s", e.err, e.trace)
    83  }
    84  
    85  // Unwrap implements standard unwrapping.
    86  func (e *ErrState) Unwrap() error {
    87  	return e.err
    88  }
    89  
    90  // Save saves the given object state.
    91  func Save(ctx context.Context, w io.Writer, rootPtr any) (Stats, error) {
    92  	// Create the encoding state.
    93  	es := encodeState{
    94  		ctx:            ctx,
    95  		w:              w,
    96  		types:          makeTypeEncodeDatabase(),
    97  		zeroValues:     make(map[reflect.Type]*objectEncodeState),
    98  		pending:        make(map[objectID]*objectEncodeState),
    99  		encodedStructs: make(map[reflect.Value]*wire.Struct),
   100  	}
   101  
   102  	// Perform the encoding.
   103  	err := safely(func() {
   104  		es.Save(reflect.ValueOf(rootPtr).Elem())
   105  	})
   106  	return es.stats, err
   107  }
   108  
   109  // Load loads a checkpoint.
   110  func Load(ctx context.Context, r io.Reader, rootPtr any) (Stats, error) {
   111  	// Create the decoding state.
   112  	ds := decodeState{
   113  		ctx:      ctx,
   114  		r:        r,
   115  		types:    makeTypeDecodeDatabase(),
   116  		deferred: make(map[objectID]wire.Object),
   117  	}
   118  
   119  	// Attempt our decode.
   120  	err := safely(func() {
   121  		ds.Load(reflect.ValueOf(rootPtr).Elem())
   122  	})
   123  	return ds.stats, err
   124  }
   125  
   126  // Sink is used for Type.StateSave.
   127  type Sink struct {
   128  	internal objectEncoder
   129  }
   130  
   131  // Save adds the given object to the map.
   132  //
   133  // You should pass always pointers to the object you are saving. For example:
   134  //
   135  //	type X struct {
   136  //		A int
   137  //		B *int
   138  //	}
   139  //
   140  //	func (x *X) StateTypeInfo(m Sink) state.TypeInfo {
   141  //		return state.TypeInfo{
   142  //			Name:   "pkg.X",
   143  //			Fields: []string{
   144  //				"A",
   145  //				"B",
   146  //			},
   147  //		}
   148  //	}
   149  //
   150  //	func (x *X) StateSave(m Sink) {
   151  //		m.Save(0, &x.A) // Field is A.
   152  //		m.Save(1, &x.B) // Field is B.
   153  //	}
   154  //
   155  //	func (x *X) StateLoad(m Source) {
   156  //		m.Load(0, &x.A) // Field is A.
   157  //		m.Load(1, &x.B) // Field is B.
   158  //	}
   159  func (s Sink) Save(slot int, objPtr any) {
   160  	s.internal.save(slot, reflect.ValueOf(objPtr).Elem())
   161  }
   162  
   163  // SaveValue adds the given object value to the map.
   164  //
   165  // This should be used for values where pointers are not available, or casts
   166  // are required during Save/Load.
   167  //
   168  // For example, if we want to cast external package type P.Foo to int64:
   169  //
   170  //	func (x *X) StateSave(m Sink) {
   171  //		m.SaveValue(0, "A", int64(x.A))
   172  //	}
   173  //
   174  //	func (x *X) StateLoad(m Source) {
   175  //		m.LoadValue(0, new(int64), func(x any) {
   176  //			x.A = P.Foo(x.(int64))
   177  //		})
   178  //	}
   179  func (s Sink) SaveValue(slot int, obj any) {
   180  	s.internal.save(slot, reflect.ValueOf(obj))
   181  }
   182  
   183  // Context returns the context object provided at save time.
   184  func (s Sink) Context() context.Context {
   185  	return s.internal.es.ctx
   186  }
   187  
   188  // Type is an interface that must be implemented by Struct objects. This allows
   189  // these objects to be serialized while minimizing runtime reflection required.
   190  //
   191  // All these methods can be automatically generated by the go_statify tool.
   192  type Type interface {
   193  	// StateTypeName returns the type's name.
   194  	//
   195  	// This is used for matching type information during encoding and
   196  	// decoding, as well as dynamic interface dispatch. This should be
   197  	// globally unique.
   198  	StateTypeName() string
   199  
   200  	// StateFields returns information about the type.
   201  	//
   202  	// Fields is the set of fields for the object. Calls to Sink.Save and
   203  	// Source.Load must be made in-order with respect to these fields.
   204  	//
   205  	// This will be called at most once per serialization.
   206  	StateFields() []string
   207  }
   208  
   209  // SaverLoader must be implemented by struct types.
   210  type SaverLoader interface {
   211  	// StateSave saves the state of the object to the given Map.
   212  	StateSave(Sink)
   213  
   214  	// StateLoad loads the state of the object.
   215  	StateLoad(context.Context, Source)
   216  }
   217  
   218  // Source is used for Type.StateLoad.
   219  type Source struct {
   220  	internal objectDecoder
   221  }
   222  
   223  // Load loads the given object passed as a pointer..
   224  //
   225  // See Sink.Save for an example.
   226  func (s Source) Load(slot int, objPtr any) {
   227  	s.internal.load(slot, reflect.ValueOf(objPtr), false, nil)
   228  }
   229  
   230  // LoadWait loads the given objects from the map, and marks it as requiring all
   231  // AfterLoad executions to complete prior to running this object's AfterLoad.
   232  //
   233  // See Sink.Save for an example.
   234  func (s Source) LoadWait(slot int, objPtr any) {
   235  	s.internal.load(slot, reflect.ValueOf(objPtr), true, nil)
   236  }
   237  
   238  // LoadValue loads the given object value from the map.
   239  //
   240  // See Sink.SaveValue for an example.
   241  func (s Source) LoadValue(slot int, objPtr any, fn func(any)) {
   242  	o := reflect.ValueOf(objPtr)
   243  	s.internal.load(slot, o, true, func() { fn(o.Elem().Interface()) })
   244  }
   245  
   246  // AfterLoad schedules a function execution when all objects have been
   247  // allocated and their automated loading and customized load logic have been
   248  // executed. fn will not be executed until all of current object's
   249  // dependencies' AfterLoad() logic, if exist, have been executed.
   250  func (s Source) AfterLoad(fn func()) {
   251  	s.internal.afterLoad(fn)
   252  }
   253  
   254  // Context returns the context object provided at load time.
   255  func (s Source) Context() context.Context {
   256  	return s.internal.ds.ctx
   257  }
   258  
   259  // IsZeroValue checks if the given value is the zero value.
   260  //
   261  // This function is used by the stateify tool.
   262  func IsZeroValue(val any) bool {
   263  	return val == nil || reflect.ValueOf(val).Elem().IsZero()
   264  }
   265  
   266  // Failf is a wrapper around panic that should be used to generate errors that
   267  // can be caught during saving and loading.
   268  func Failf(fmtStr string, v ...any) {
   269  	panic(fmt.Errorf(fmtStr, v...))
   270  }
   271  
   272  // safely executes the given function, catching a panic and unpacking as an
   273  // error.
   274  //
   275  // The error flow through the state package uses panic and recover. There are
   276  // two important reasons for this:
   277  //
   278  // 1) Many of the reflection methods will already panic with invalid data or
   279  // violated assumptions. We would want to recover anyways here.
   280  //
   281  // 2) It allows us to eliminate boilerplate within Save() and Load() functions.
   282  // In nearly all cases, when the low-level serialization functions fail, you
   283  // will want the checkpoint to fail anyways. Plumbing errors through every
   284  // method doesn't add a lot of value. If there are specific error conditions
   285  // that you'd like to handle, you should add appropriate functionality to
   286  // objects themselves prior to calling Save() and Load().
   287  func safely(fn func()) (err error) {
   288  	defer func() {
   289  		if r := recover(); r != nil {
   290  			if es, ok := r.(*ErrState); ok {
   291  				err = es // Propagate.
   292  				return
   293  			}
   294  
   295  			// Build a new state error.
   296  			es := new(ErrState)
   297  			if e, ok := r.(error); ok {
   298  				es.err = e
   299  			} else {
   300  				es.err = fmt.Errorf("%v", r)
   301  			}
   302  
   303  			// Make a stack. We don't know how big it will be ahead
   304  			// of time, but want to make sure we get the whole
   305  			// thing. So we just do a stupid brute force approach.
   306  			var stack []byte
   307  			for sz := 1024; ; sz *= 2 {
   308  				stack = make([]byte, sz)
   309  				n := runtime.Stack(stack, false)
   310  				if n < sz {
   311  					es.trace = string(stack[:n])
   312  					break
   313  				}
   314  			}
   315  
   316  			// Set the error.
   317  			err = es
   318  		}
   319  	}()
   320  
   321  	// Execute the function.
   322  	fn()
   323  	return nil
   324  }