github.com/protolambda/zssz@v0.1.5/types/ssz_container.go (about) 1 package types 2 3 import ( 4 "fmt" 5 . "github.com/protolambda/zssz/dec" 6 . "github.com/protolambda/zssz/enc" 7 . "github.com/protolambda/zssz/htr" 8 "github.com/protolambda/zssz/merkle" 9 . "github.com/protolambda/zssz/pretty" 10 "github.com/protolambda/zssz/util/tags" 11 "reflect" 12 "unsafe" 13 ) 14 15 const SSZ_TAG = "ssz" 16 const OMIT_FLAG = "omit" 17 const SQUASH_FLAG = "squash" 18 19 type FieldPtrFn func(p unsafe.Pointer) unsafe.Pointer 20 21 func (fn FieldPtrFn) WrapOffset(memOffset uintptr) FieldPtrFn { 22 return func(p unsafe.Pointer) unsafe.Pointer { 23 return fn(unsafe.Pointer(uintptr(p) + memOffset)) 24 } 25 } 26 27 type ContainerField struct { 28 ssz SSZ 29 name string 30 pureName string 31 ptrFn FieldPtrFn 32 isFixed bool 33 } 34 35 func (c *ContainerField) Wrap(name string, memOffset uintptr) ContainerField { 36 return ContainerField{ 37 ssz: c.ssz, 38 name: name + ">" + c.name, 39 pureName: c.name, 40 ptrFn: c.ptrFn.WrapOffset(memOffset), 41 isFixed: c.ssz.IsFixed(), 42 } 43 } 44 45 type SquashableFields interface { 46 // Get the ContainerFields 47 SquashFields() []ContainerField 48 } 49 50 func GetOffsetPtrFn(memOffset uintptr) FieldPtrFn { 51 return func(p unsafe.Pointer) unsafe.Pointer { 52 return unsafe.Pointer(uintptr(p) + memOffset) 53 } 54 } 55 56 type SSZContainer struct { 57 Fields []ContainerField 58 isFixedLen bool 59 fixedLen uint64 60 minLen uint64 61 maxLen uint64 62 offsetCount uint64 // includes offsets for fields that are squashed in 63 fuzzMinLen uint64 64 fuzzMaxLen uint64 65 } 66 67 func (v *SSZContainer) SquashFields() []ContainerField { 68 return v.Fields 69 } 70 71 // Get the container fields for the given struct field 72 // 0 fields (nil) if struct field is ignored 73 // 1 field for normal struct fields 74 // 0 or more fields when a struct field is squashed (recursively adding to the total field collection) 75 func getFields(factory SSZFactoryFn, f *reflect.StructField) (out []ContainerField, err error) { 76 if tags.HasFlag(f, SSZ_TAG, OMIT_FLAG) { 77 return nil, nil 78 } 79 fieldSSZ, err := factory(f.Type) 80 if err != nil { 81 return nil, err 82 } 83 84 forceSquash := tags.HasFlag(f, SSZ_TAG, SQUASH_FLAG) 85 86 if f.Anonymous || forceSquash { 87 if squashable, ok := fieldSSZ.(SquashableFields); ok { 88 for _, sq := range squashable.SquashFields() { 89 out = append(out, sq.Wrap(f.Name, f.Offset)) 90 } 91 return out, nil 92 } 93 // anonymous fields can be handled as normal fields. Only error when it was tagged to be squashed. 94 if forceSquash { 95 return nil, fmt.Errorf("could not squash field %s", f.Name) 96 } 97 } 98 99 out = append(out, ContainerField{ 100 ssz: fieldSSZ, pureName: f.Name, name: f.Name, 101 ptrFn: GetOffsetPtrFn(f.Offset), isFixed: fieldSSZ.IsFixed()}) 102 return 103 } 104 105 func NewSSZContainer(factory SSZFactoryFn, typ reflect.Type) (*SSZContainer, error) { 106 if typ.Kind() != reflect.Struct { 107 return nil, fmt.Errorf("typ is not a struct") 108 } 109 res := new(SSZContainer) 110 for i, c := 0, typ.NumField(); i < c; i++ { 111 // get the Go struct field 112 sField := typ.Field(i) 113 // For this field, get the SSZ field(s). There may be more if the Go field is squashed. 114 fields, err := getFields(factory, &sField) 115 if err != nil { 116 return nil, err 117 } 118 res.Fields = append(res.Fields, fields...) 119 } 120 for _, field := range res.Fields { 121 if field.ssz.IsFixed() { 122 fixed, min, max := field.ssz.FixedLen(), field.ssz.MinLen(), field.ssz.MaxLen() 123 if fixed != min || fixed != max { 124 return nil, fmt.Errorf("fixed-size field ('%s') in struct has invalid min/max length", field.name) 125 } 126 res.fixedLen += fixed 127 res.minLen += fixed 128 res.maxLen += fixed 129 } else { 130 res.fixedLen += BYTES_PER_LENGTH_OFFSET 131 res.minLen += BYTES_PER_LENGTH_OFFSET + field.ssz.MinLen() 132 res.maxLen += BYTES_PER_LENGTH_OFFSET + field.ssz.MaxLen() 133 res.offsetCount++ 134 } 135 res.fuzzMinLen += field.ssz.FuzzMinLen() 136 res.fuzzMaxLen += field.ssz.FuzzMaxLen() 137 } 138 res.isFixedLen = res.offsetCount == 0 139 return res, nil 140 } 141 142 func (v *SSZContainer) FuzzMinLen() uint64 { 143 return v.fuzzMinLen 144 } 145 146 func (v *SSZContainer) FuzzMaxLen() uint64 { 147 return v.fuzzMaxLen 148 } 149 150 func (v *SSZContainer) MinLen() uint64 { 151 return v.minLen 152 } 153 154 func (v *SSZContainer) MaxLen() uint64 { 155 return v.maxLen 156 } 157 158 func (v *SSZContainer) FixedLen() uint64 { 159 return v.fixedLen 160 } 161 162 func (v *SSZContainer) IsFixed() bool { 163 return v.isFixedLen 164 } 165 166 func (v *SSZContainer) SizeOf(p unsafe.Pointer) uint64 { 167 out := v.fixedLen 168 for _, f := range v.Fields { 169 if !f.ssz.IsFixed() { 170 out += f.ssz.SizeOf(f.ptrFn(p)) 171 } 172 } 173 return out 174 } 175 176 func (v *SSZContainer) Encode(eb *EncodingWriter, p unsafe.Pointer) error { 177 // hot-path for common case of fixed-size container 178 if v.isFixedLen { 179 for i := range v.Fields { 180 f := &v.Fields[i] 181 if err := f.ssz.Encode(eb, f.ptrFn(p)); err != nil { 182 return err 183 } 184 } 185 return nil 186 } 187 // the previous offset, to calculate a new offset from, starting after the fixed data. 188 prevOffset := v.fixedLen 189 // span of the previous var-size element 190 prevSize := uint64(0) 191 for i := range v.Fields { 192 f := &v.Fields[i] 193 if f.isFixed { 194 if err := f.ssz.Encode(eb, f.ptrFn(p)); err != nil { 195 return err 196 } 197 } else { 198 if offset, err := eb.WriteOffset(prevOffset, prevSize); err != nil { 199 return err 200 } else { 201 prevOffset = offset 202 } 203 prevSize = f.ssz.SizeOf(f.ptrFn(p)) 204 } 205 } 206 // Only iterate over and write dynamic parts if we need to. 207 if !v.isFixedLen { 208 for i := range v.Fields { 209 f := &v.Fields[i] 210 if !f.isFixed { 211 if err := f.ssz.Encode(eb, f.ptrFn(p)); err != nil { 212 return err 213 } 214 } 215 } 216 } 217 return nil 218 } 219 220 func (v *SSZContainer) decodeVarSizeFuzzmode(dr *DecodingReader, p unsafe.Pointer) error { 221 lengthLeftOver := v.fuzzMinLen 222 223 for _, f := range v.Fields { 224 lengthLeftOver -= f.ssz.FuzzMinLen() 225 span := dr.GetBytesSpan() 226 if span < lengthLeftOver { 227 return fmt.Errorf("under estimated length requirements for fuzzing input, not enough data available to fuzz") 228 } 229 available := span - lengthLeftOver 230 231 scoped, err := dr.Scope(available) 232 if err != nil { 233 return err 234 } 235 scoped.EnableFuzzMode() 236 if err := f.ssz.Decode(scoped, f.ptrFn(p)); err != nil { 237 return err 238 } 239 dr.UpdateIndexFromScoped(scoped) 240 } 241 return nil 242 } 243 244 func (v *SSZContainer) decodeDynamicPart(dr *DecodingReader, offsets []uint64, fieldHandler func(dr *DecodingReader, f *ContainerField) error) error { 245 i := 0 246 for fi := range v.Fields { 247 f := &v.Fields[fi] 248 // ignore fixed-size fields 249 if f.ssz.IsFixed() { 250 continue 251 } 252 // calculate the scope based on next offset, and max. value of this scope for the last value 253 var scope uint64 254 { 255 currentOffset := offsets[i] 256 if next := i + 1; next < len(offsets) { 257 if nextOffset := offsets[next]; nextOffset >= currentOffset { 258 scope = nextOffset - currentOffset 259 } else { 260 return fmt.Errorf("offset %d for field %s is invalid", i, f.name) 261 } 262 } else { 263 scope = dr.Max() - currentOffset 264 } 265 } 266 { 267 realOffset := dr.Index() 268 if expectedOffset := offsets[i]; expectedOffset != realOffset { 269 return fmt.Errorf("expected to be at %d bytes, but currently at %d", expectedOffset, realOffset) 270 } 271 scoped, err := dr.Scope(scope) 272 if err != nil { 273 return err 274 } 275 if err := fieldHandler(scoped, f); err != nil { 276 return err 277 } 278 dr.UpdateIndexFromScoped(scoped) 279 } 280 // go to next offset 281 i++ 282 } 283 return nil 284 } 285 286 func (v *SSZContainer) processFixedPart(dr *DecodingReader, fieldHandler func(f *ContainerField) error) ([]uint64, error) { 287 // technically we could also ignore offset correctness and skip ahead, 288 // but we may want to enforce proper offsets. 289 offsets := make([]uint64, 0, v.offsetCount) 290 startIndex := dr.Index() 291 fixedI := dr.Index() 292 for fi := range v.Fields { 293 f := &v.Fields[fi] 294 if f.ssz.IsFixed() { 295 fixedI += f.ssz.FixedLen() 296 // No need to redefine the scope for fixed-length SSZ objects. 297 if err := fieldHandler(f); err != nil { 298 return nil, err 299 } 300 } else { 301 fixedI += BYTES_PER_LENGTH_OFFSET 302 // write an offset to the fixed data, to find the dynamic data with as a reader 303 offset, err := dr.ReadOffset() 304 if err != nil { 305 return nil, err 306 } 307 offsets = append(offsets, offset) 308 } 309 if i := dr.Index(); i != fixedI { 310 return nil, fmt.Errorf("fixed part had different size than expected, now at %d, expected to be at %d", i, fixedI) 311 } 312 } 313 pivotIndex := dr.Index() 314 if expectedIndex := v.fixedLen + startIndex; pivotIndex != expectedIndex { 315 return nil, fmt.Errorf("expected to read to %d bytes for fixed part of container, got to %d", expectedIndex, pivotIndex) 316 } 317 return offsets, nil 318 } 319 320 func (v *SSZContainer) decodeVarSize(dr *DecodingReader, p unsafe.Pointer) error { 321 offsets, err := v.processFixedPart(dr, func(f *ContainerField) error { 322 return f.ssz.Decode(dr, f.ptrFn(p)) 323 }) 324 if err != nil { 325 return err 326 } 327 return v.decodeDynamicPart(dr, offsets, func(scopedDr *DecodingReader, f *ContainerField) error { 328 return f.ssz.Decode(scopedDr, f.ptrFn(p)) 329 }) 330 } 331 332 func (v *SSZContainer) Decode(dr *DecodingReader, p unsafe.Pointer) error { 333 if dr.IsFuzzMode() { 334 return v.decodeVarSizeFuzzmode(dr, p) 335 } else { 336 return v.decodeVarSize(dr, p) 337 } 338 } 339 340 func (v *SSZContainer) DryCheck(dr *DecodingReader) error { 341 offsets, err := v.processFixedPart(dr, func(f *ContainerField) error { 342 return f.ssz.DryCheck(dr) 343 }) 344 if err != nil { 345 return err 346 } 347 return v.decodeDynamicPart(dr, offsets, func(scopedDr *DecodingReader, f *ContainerField) error { 348 return f.ssz.DryCheck(scopedDr) 349 }) 350 } 351 352 func (v *SSZContainer) HashTreeRoot(h MerkleFn, p unsafe.Pointer) [32]byte { 353 leaf := func(i uint64) []byte { 354 f := v.Fields[i] 355 r := f.ssz.HashTreeRoot(h, f.ptrFn(p)) 356 return r[:] 357 } 358 leafCount := uint64(len(v.Fields)) 359 return merkle.Merkleize(h, leafCount, leafCount, leaf) 360 } 361 362 func (v *SSZContainer) SigningRoot(h MerkleFn, p unsafe.Pointer) [32]byte { 363 leaf := func(i uint64) []byte { 364 f := v.Fields[i] 365 r := f.ssz.HashTreeRoot(h, f.ptrFn(p)) 366 return r[:] 367 } 368 // truncate last field 369 leafCount := uint64(len(v.Fields)) 370 if leafCount != 0 { 371 leafCount-- 372 } 373 return merkle.Merkleize(h, leafCount, leafCount, leaf) 374 } 375 376 func (v *SSZContainer) Pretty(indent uint32, w *PrettyWriter, p unsafe.Pointer) { 377 w.WriteIndent(indent) 378 w.Write("{\n") 379 for i, f := range v.Fields { 380 w.WriteIndent(indent + 1) 381 w.Write(f.pureName) 382 w.Write(":\n") 383 f.ssz.Pretty(indent+3, w, f.ptrFn(p)) 384 if i == len(v.Fields)-1 { 385 w.Write("\n") 386 } else { 387 w.Write(",\n") 388 } 389 } 390 w.WriteIndent(indent) 391 w.Write("}") 392 }