go-hep.org/x/hep@v0.38.1/groot/rbytes/rbuffer.go (about)

     1  // Copyright ©2017 The go-hep Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package rbytes
     6  
     7  import (
     8  	"encoding/hex"
     9  	"fmt"
    10  	"io"
    11  	"math"
    12  
    13  	"go-hep.org/x/hep/groot/root"
    14  	"go-hep.org/x/hep/groot/rtypes"
    15  	"go-hep.org/x/hep/groot/rvers"
    16  )
    17  
    18  type rbuff struct {
    19  	p []byte // buffer of data to read from
    20  	c int    // current position in buffer of data
    21  }
    22  
    23  func (r *rbuff) Read(p []byte) (int, error) {
    24  	if r.c >= len(r.p) {
    25  		return 0, io.EOF
    26  	}
    27  	n := copy(p, r.p[r.c:])
    28  	r.c += n
    29  	return n, nil
    30  }
    31  
    32  func (r *rbuff) ReadByte() (byte, error) {
    33  	if r.c >= len(r.p) {
    34  		return 0, io.EOF
    35  	}
    36  	v := r.p[r.c]
    37  	r.c++
    38  	return v, nil
    39  }
    40  
    41  func (r *rbuff) Seek(offset int64, whence int) (int64, error) {
    42  	switch whence {
    43  	case io.SeekStart:
    44  		r.c = int(offset)
    45  	case io.SeekCurrent:
    46  		r.c += int(offset)
    47  	case io.SeekEnd:
    48  		r.c = len(r.p) - int(offset)
    49  	default:
    50  		return 0, fmt.Errorf("rbytes: invalid whence")
    51  	}
    52  	if r.c < 0 {
    53  		return 0, fmt.Errorf("rbytes: negative position")
    54  	}
    55  	return int64(r.c), nil
    56  }
    57  
// RBuffer is a read-only ROOT buffer for streaming.
//
// Most read methods return immediately once err is set, so a caller can
// issue a sequence of reads and check Err once at the end.
type RBuffer struct {
	r      rbuff               // cursor over the raw bytes being decoded
	err    error               // sticky error; checked at the top of most read methods
	offset uint32              // added to the local cursor by Pos, subtracted by setPos
	refs   map[int64]any       // tag -> object or class factory, filled by ReadObjectAny
	sictx  StreamerInfoContext // source of StreamerInfos; may be nil
}
    66  
    67  func NewRBuffer(data []byte, refs map[int64]any, offset uint32, ctx StreamerInfoContext) *RBuffer {
    68  	if refs == nil {
    69  		refs = make(map[int64]any)
    70  	}
    71  
    72  	return &RBuffer{
    73  		r:      rbuff{p: data, c: 0},
    74  		refs:   refs,
    75  		offset: offset,
    76  		sictx:  ctx,
    77  	}
    78  }
    79  
    80  func (r *RBuffer) Reset(data []byte, refs map[int64]any, offset uint32, ctx StreamerInfoContext) *RBuffer {
    81  	if r == nil {
    82  		return NewRBuffer(data, refs, offset, ctx)
    83  	}
    84  	if refs == nil {
    85  		for k := range r.refs {
    86  			delete(r.refs, k)
    87  		}
    88  		refs = r.refs
    89  	}
    90  
    91  	r.r = rbuff{p: data, c: 0}
    92  	r.refs = refs
    93  	r.offset = offset
    94  	r.sictx = ctx
    95  	return r
    96  }
    97  
    98  // ReadHeader reads the serialization header for the given class and its known maximum version.
    99  func (r *RBuffer) ReadHeader(class string, vmax int16) Header {
   100  	hdr := Header{
   101  		Name: class,
   102  		Pos:  r.Pos(),
   103  	}
   104  	if r.err != nil {
   105  		return hdr
   106  	}
   107  
   108  	bcnt := r.ReadU32()
   109  	if (int64(bcnt) & kByteCountMask) != 0 {
   110  		hdr.Len = int32(int64(bcnt) & ^kByteCountMask)
   111  		hdr.Vers = int16(r.ReadU16())
   112  	} else {
   113  		// no byte count. rewind and read version
   114  		r.setPos(int64(hdr.Pos))
   115  		hdr.Vers = int16(r.ReadU16())
   116  	}
   117  
   118  	hdr.MemberWise = hdr.Vers&StreamedMemberWise != 0
   119  	if hdr.MemberWise {
   120  		hdr.Vers &= ^StreamedMemberWise
   121  	}
   122  
   123  	if hdr.Vers <= 0 {
   124  		if hdr.Name != "" && r.sictx != nil {
   125  			si, err := r.sictx.StreamerInfo(hdr.Name, -1)
   126  			if err == nil && si.ClassVersion() != int(hdr.Vers) {
   127  				chksum := r.ReadU32()
   128  				if si.CheckSum() == int(chksum) {
   129  					hdr.Vers = int16(si.ClassVersion())
   130  				}
   131  			}
   132  		}
   133  	}
   134  
   135  	// vmax=-1 is a special "garbage-in/garbage-out" mode.
   136  	if vmax >= 0 && hdr.Vers > vmax {
   137  		panic(fmt.Errorf("rbytes: invalid version for %q: got=%d > max=%d", class, hdr.Vers, vmax))
   138  	}
   139  
   140  	return hdr
   141  }
   142  
   143  // StreamerInfo returns the named StreamerInfo.
   144  // If version is negative, the latest version should be returned.
   145  func (r *RBuffer) StreamerInfo(name string, version int) (StreamerInfo, error) {
   146  	if r.sictx == nil {
   147  		return nil, fmt.Errorf("rbytes: no streamers")
   148  	}
   149  	return r.sictx.StreamerInfo(name, version)
   150  }
   151  
   152  func (r *RBuffer) Pos() int64 {
   153  	return int64(r.r.c) + int64(r.offset)
   154  }
   155  
   156  func (r *RBuffer) SetPos(pos int64) { r.setPos(pos) }
   157  func (r *RBuffer) setPos(pos int64) {
   158  	pos -= int64(r.offset)
   159  	r.r.c = int(pos)
   160  }
   161  
   162  func (r *RBuffer) Len() int64 {
   163  	return int64(len(r.r.p) - r.r.c)
   164  }
   165  
   166  func (r *RBuffer) Err() error       { return r.err }
   167  func (r *RBuffer) SetErr(err error) { r.err = err }
   168  
   169  func (r *RBuffer) read(data []byte) {
   170  	if r.err != nil {
   171  		return
   172  	}
   173  	n := copy(data, r.r.p[r.r.c:])
   174  	r.r.c += n
   175  }
   176  
   177  func (r *RBuffer) Read(p []byte) (int, error) {
   178  	if r.err != nil {
   179  		return 0, r.err
   180  	}
   181  	n, err := r.r.Read(p)
   182  	r.err = err
   183  	return n, r.err
   184  }
   185  
   186  func (r *RBuffer) bytes() []byte {
   187  	return r.r.p[r.r.c:]
   188  }
   189  
   190  // func (r *RBuffer) dumpRefs() {
   191  // 	fmt.Printf("--- refs ---\n")
   192  // 	ids := make([]int64, 0, len(r.refs))
   193  // 	for k := range r.refs {
   194  // 		ids = append(ids, k)
   195  // 	}
   196  // 	sort.Sort(int64Slice(ids))
   197  // 	for _, id := range ids {
   198  // 		fmt.Printf(" id=%4d -> %v\n", id, r.refs[id])
   199  // 	}
   200  // }
   201  //
   202  // type int64Slice []int64
   203  //
   204  // func (p int64Slice) Len() int           { return len(p) }
   205  // func (p int64Slice) Less(i, j int) bool { return p[i] < p[j] }
   206  // func (p int64Slice) Swap(i, j int)      { p[i], p[j] = p[j], p[i] }
   207  
   208  func (r *RBuffer) DumpHex(n int) {
   209  	buf := r.bytes()
   210  	if len(buf) > n {
   211  		buf = buf[:n]
   212  	}
   213  	fmt.Printf("--- hex --- (pos=%d len=%d end=%d)\n%s\n", r.Pos(), n, r.Len(), string(hex.Dump(buf)))
   214  }
   215  
   216  func (r *RBuffer) ReadStdString() string {
   217  	if r.Err() != nil {
   218  		return ""
   219  	}
   220  
   221  	hdr := r.ReadHeader("string", rvers.StreamerBaseSTL) // FIXME(sbinet): streamline with RStreamROOT
   222  	if hdr.Vers > rvers.StreamerBaseSTL {
   223  		r.SetErr(fmt.Errorf("rbytes: invalid version for std::string. got=%v, want=%v", hdr.Vers, rvers.StreamerBaseSTL))
   224  		return ""
   225  	}
   226  
   227  	o := r.ReadString()
   228  	r.CheckHeader(hdr)
   229  
   230  	return o
   231  }
   232  
   233  func (r *RBuffer) ReadString() string {
   234  	if r.err != nil {
   235  		return ""
   236  	}
   237  
   238  	u8 := r.ReadU8()
   239  	n := int(u8)
   240  	if u8 == 255 {
   241  		// large string
   242  		n = int(r.ReadU32())
   243  	}
   244  	if n == 0 {
   245  		return ""
   246  	}
   247  	v := r.ReadU8()
   248  	if v == 0 {
   249  		return ""
   250  	}
   251  	buf := make([]byte, n)
   252  	buf[0] = v
   253  	if n != 0 {
   254  		r.read(buf[1:])
   255  		if r.err != nil {
   256  			return ""
   257  		}
   258  		return string(buf)
   259  	}
   260  	return ""
   261  }
   262  
   263  func (r *RBuffer) ReadCString(n int) string {
   264  	if r.err != nil {
   265  		return ""
   266  	}
   267  
   268  	buf := make([]byte, n)
   269  	for i := range n {
   270  		r.read(buf[i : i+1])
   271  		if buf[i] == 0 {
   272  			buf = buf[:i]
   273  			break
   274  		}
   275  	}
   276  	return string(buf)
   277  }
   278  
   279  func (r *RBuffer) ReadBool() bool {
   280  	if r.err != nil {
   281  		return false
   282  	}
   283  	return r.readBool()
   284  }
   285  
   286  func (r *RBuffer) readBool() bool {
   287  	v := r.readI8()
   288  	return v != 0
   289  }
   290  
   291  func (r *RBuffer) ReadU8() uint8 {
   292  	if r.err != nil {
   293  		return 0
   294  	}
   295  	return r.readU8()
   296  }
   297  
   298  func (r *RBuffer) readU8() uint8 {
   299  	beg := r.r.c
   300  	r.r.c++
   301  	v := r.r.p[beg]
   302  	return uint8(v)
   303  }
   304  
   305  func (r *RBuffer) ReadI8() int8 {
   306  	if r.err != nil {
   307  		return 0
   308  	}
   309  	return r.readI8()
   310  }
   311  
   312  func (r *RBuffer) readI8() int8 {
   313  	beg := r.r.c
   314  	r.r.c++
   315  	v := r.r.p[beg]
   316  	return int8(v)
   317  }
   318  
   319  func (r *RBuffer) ReadF16(elm StreamerElement) root.Float16 {
   320  	switch {
   321  	case elm != nil && elm.Factor() != 0:
   322  		return r.readWithFactorF16(elm.Factor(), elm.XMin())
   323  	default:
   324  		var nbits uint32
   325  		if elm != nil {
   326  			nbits = uint32(elm.XMin())
   327  		}
   328  		if nbits == 0 {
   329  			nbits = 12
   330  		}
   331  		return r.readWithNbitsF16(nbits)
   332  	}
   333  }
   334  
   335  func (r *RBuffer) readWithFactorF16(f, xmin float64) root.Float16 {
   336  	v := float64(r.ReadU32())
   337  	return root.Float16(v/f + xmin)
   338  }
   339  
   340  func (r *RBuffer) readWithNbitsF16(nbits uint32) root.Float16 {
   341  	var (
   342  		exp = uint32(r.ReadU8())
   343  		man = uint32(r.ReadU16())
   344  		val = uint32(exp)
   345  	)
   346  	val <<= 23
   347  	val |= (man & ((1 << (nbits + 1)) - 1)) << (23 - nbits)
   348  
   349  	f := math.Float32frombits(val)
   350  	if (1 << (nbits + 1) & man) != 0 {
   351  		f = -f
   352  	}
   353  
   354  	return root.Float16(f)
   355  }
   356  
   357  func (r *RBuffer) ReadD32(elm StreamerElement) root.Double32 {
   358  	switch {
   359  	case elm != nil && elm.Factor() != 0:
   360  		return r.readWithFactorD32(elm.Factor(), elm.XMin())
   361  	default:
   362  		var nbits uint32
   363  		if elm != nil {
   364  			nbits = uint32(elm.XMin())
   365  		}
   366  		if nbits == 0 {
   367  			f32 := r.ReadF32()
   368  			return root.Double32(f32)
   369  		}
   370  		return r.readWithNbitsD32(nbits)
   371  	}
   372  }
   373  
   374  func (r *RBuffer) readWithFactorD32(f, xmin float64) root.Double32 {
   375  	v := float64(r.ReadU32())
   376  	return root.Double32(v/f + xmin)
   377  }
   378  
   379  func (r *RBuffer) readWithNbitsD32(nbits uint32) root.Double32 {
   380  	var (
   381  		exp = uint32(r.ReadU8())
   382  		man = uint32(r.ReadU16())
   383  		val = uint32(exp)
   384  	)
   385  	val <<= 23
   386  	val |= (man & ((1 << (nbits + 1)) - 1)) << (23 - nbits)
   387  
   388  	f := math.Float32frombits(val)
   389  	if (1 << (nbits + 1) & man) != 0 {
   390  		f = -f
   391  	}
   392  
   393  	return root.Double32(f)
   394  }
   395  
   396  func (r *RBuffer) ReadStaticArrayI32() []int32 {
   397  	if r.err != nil {
   398  		return nil
   399  	}
   400  
   401  	n := int(r.ReadI32())
   402  	if n <= 0 || int64(n) > r.Len() {
   403  		return nil
   404  	}
   405  
   406  	arr := make([]int32, n)
   407  	for i := range arr {
   408  		arr[i] = r.ReadI32()
   409  	}
   410  
   411  	if r.err != nil {
   412  		return nil
   413  	}
   414  
   415  	return arr
   416  }
   417  
   418  func (r *RBuffer) ReadArrayBool(arr []bool) {
   419  	if r.err != nil {
   420  		return
   421  	}
   422  	n := len(arr)
   423  	if n <= 0 || int64(n) > r.Len() {
   424  		return
   425  	}
   426  
   427  	for i := range arr {
   428  		arr[i] = r.readBool()
   429  	}
   430  }
   431  
   432  func (r *RBuffer) ReadArrayI8(arr []int8) {
   433  	if r.err != nil {
   434  		return
   435  	}
   436  	n := len(arr)
   437  	if n <= 0 || int64(n) > r.Len() {
   438  		return
   439  	}
   440  
   441  	for i := range arr {
   442  		arr[i] = r.readI8()
   443  	}
   444  }
   445  
   446  func (r *RBuffer) ReadArrayU8(arr []uint8) {
   447  	if r.err != nil {
   448  		return
   449  	}
   450  	n := len(arr)
   451  	if n <= 0 || int64(n) > r.Len() {
   452  		return
   453  	}
   454  
   455  	for i := range arr {
   456  		arr[i] = r.readU8()
   457  	}
   458  }
   459  
   460  func (r *RBuffer) ReadArrayF16(arr []root.Float16, elm StreamerElement) {
   461  	if r.err != nil {
   462  		return
   463  	}
   464  	n := len(arr)
   465  	if n <= 0 || int64(n) > r.Len() {
   466  		return
   467  	}
   468  
   469  	for i := range arr {
   470  		arr[i] = r.ReadF16(elm)
   471  	}
   472  }
   473  
   474  func (r *RBuffer) ReadArrayD32(arr []root.Double32, elm StreamerElement) {
   475  	if r.err != nil {
   476  		return
   477  	}
   478  
   479  	n := len(arr)
   480  	if n <= 0 || int64(n) > r.Len() {
   481  		return
   482  	}
   483  
   484  	for i := range arr {
   485  		arr[i] = r.ReadD32(elm)
   486  	}
   487  }
   488  
   489  func (r *RBuffer) ReadArrayString(arr []string) {
   490  	if r.err != nil {
   491  		return
   492  	}
   493  	n := len(arr)
   494  	if n <= 0 || int64(n) > r.Len() {
   495  		return
   496  	}
   497  
   498  	for i := range arr {
   499  		arr[i] = r.ReadString()
   500  	}
   501  }
   502  
   503  func (r *RBuffer) ReadStdVectorStrs(sli *[]string) {
   504  	if r.err != nil {
   505  		return
   506  	}
   507  
   508  	hdr := r.ReadHeader("vector<string>", rvers.StreamerBaseSTL)
   509  	n := int(r.ReadI32())
   510  	*sli = ResizeStr(*sli, n)
   511  	for i := range *sli {
   512  		(*sli)[i] = r.ReadString()
   513  	}
   514  	r.CheckHeader(hdr)
   515  }
   516  
   517  func (r *RBuffer) SkipVersion(class string) {
   518  	if r.err != nil {
   519  		return
   520  	}
   521  
   522  	version := r.ReadI16()
   523  
   524  	if int64(version)&kByteCountVMask != 0 {
   525  		_ = r.ReadI16()
   526  		_ = r.ReadI16()
   527  	}
   528  
   529  	if class != "" && version <= 1 {
   530  		panic("not implemented")
   531  	}
   532  }
   533  
   534  //func (r *RBuffer) chk(pos, count int32) bool {
   535  //	if count <= 0 {
   536  //		return true
   537  //	}
   538  //
   539  //	var (
   540  //		want = int64(pos) + int64(count) + 4
   541  //		got  = r.Pos()
   542  //	)
   543  //
   544  //	return got == want
   545  //}
   546  
   547  func (r *RBuffer) CheckHeader(hdr Header) {
   548  	if r.err != nil {
   549  		return
   550  	}
   551  
   552  	if hdr.Len <= 0 {
   553  		return
   554  	}
   555  
   556  	var (
   557  		want = hdr.Pos + int64(hdr.Len) + 4
   558  		got  = r.Pos()
   559  	)
   560  
   561  	switch {
   562  	case got == want:
   563  		return
   564  
   565  	case got > want:
   566  		r.err = fmt.Errorf("rbytes: read too many bytes. got=%d, want=%d (pos=%d count=%d) [class=%q]",
   567  			got, want, hdr.Pos, hdr.Len, hdr.Name,
   568  		)
   569  		return
   570  
   571  	case got < want:
   572  		r.err = fmt.Errorf("rbytes: read too few bytes. got=%d, want=%d (pos=%d count=%d) [class=%q]",
   573  			got, want, hdr.Pos, hdr.Len, hdr.Name,
   574  		)
   575  		return
   576  	}
   577  }
   578  
   579  func (r *RBuffer) SkipObject() {
   580  	if r.err != nil {
   581  		return
   582  	}
   583  	vers := r.ReadI16()
   584  	if vers&kByteCountVMask != 0 {
   585  		_, r.err = r.r.Seek(4, io.SeekCurrent)
   586  		if r.err != nil {
   587  			return
   588  		}
   589  	}
   590  	_ = r.ReadU32() // fUniqueID
   591  	fbits := r.ReadU32() | kIsOnHeap
   592  
   593  	if fbits&kIsReferenced != 0 {
   594  		_, r.err = r.r.Seek(2, io.SeekCurrent)
   595  		if r.err != nil {
   596  			return
   597  		}
   598  	}
   599  }
   600  
   601  func (r *RBuffer) ReadObject(obj Unmarshaler) {
   602  	if r.err != nil {
   603  		return
   604  	}
   605  
   606  	r.err = obj.UnmarshalROOT(r)
   607  }
   608  
   609  func (r *RBuffer) ReadObjectAny() (obj root.Object) {
   610  	if r.err != nil {
   611  		return obj
   612  	}
   613  
   614  	beg := r.Pos()
   615  	var (
   616  		tag   uint32
   617  		vers  int32
   618  		start int64
   619  		bcnt  = r.ReadU32()
   620  	)
   621  
   622  	if int64(bcnt)&kByteCountMask == 0 || int64(bcnt) == kNewClassTag {
   623  		tag = bcnt
   624  		bcnt = 0
   625  	} else {
   626  		vers = 1
   627  		start = r.Pos()
   628  		tag = r.ReadU32()
   629  	}
   630  
   631  	tag64 := int64(tag)
   632  	switch {
   633  	case tag64&kClassMask == 0:
   634  		if tag64 == kNullTag {
   635  			return nil
   636  		}
   637  		// FIXME(sbinet): tag==1 means "self". not implemented yet.
   638  		if tag == 1 {
   639  			r.err = fmt.Errorf("rbytes: tag == 1 means 'self'. not implemented yet")
   640  			return nil
   641  		}
   642  
   643  		o, ok := r.refs[tag64]
   644  		if !ok {
   645  			r.setPos(beg + int64(bcnt) + 4)
   646  			// r.err = fmt.Errorf("rbytes: invalid tag [%v] found", tag64)
   647  			return nil
   648  		}
   649  		obj, ok = o.(root.Object)
   650  		if !ok {
   651  			r.err = fmt.Errorf("rbytes: invalid tag [%v] found (not a root.Object)", tag64)
   652  			return nil
   653  		}
   654  		return obj
   655  
   656  	case tag64 == kNewClassTag:
   657  		cname := r.ReadCString(80)
   658  		fct := rtypes.Factory.Get(cname)
   659  
   660  		if vers > 0 {
   661  			r.refs[start+kMapOffset] = fct
   662  		} else {
   663  			r.refs[int64(len(r.refs))+1] = fct
   664  		}
   665  
   666  		obj = fct().Interface().(root.Object)
   667  		r.ReadObject(obj.(Unmarshaler))
   668  		if r.Err() != nil {
   669  			return nil
   670  		}
   671  
   672  		if vers > 0 {
   673  			r.refs[beg+kMapOffset] = obj
   674  		} else {
   675  			r.refs[int64(len(r.refs))+1] = obj
   676  		}
   677  		return obj
   678  
   679  	default:
   680  		ref := tag64 & ^kClassMask
   681  		cls, ok := r.refs[ref]
   682  		if !ok {
   683  			r.err = fmt.Errorf("rbytes: invalid class-tag reference [%v] found", ref)
   684  			return nil
   685  		}
   686  
   687  		fct, ok := cls.(rtypes.FactoryFct)
   688  		if !ok {
   689  			r.err = fmt.Errorf("rbytes: invalid class-tag reference [%v] found (not a rypes.FactoryFct: %T)", ref, cls)
   690  			return nil
   691  		}
   692  
   693  		obj = fct().Interface().(root.Object)
   694  		if vers > 0 {
   695  			r.refs[beg+kMapOffset] = obj
   696  		} else {
   697  			r.refs[int64(len(r.refs))+1] = obj
   698  		}
   699  
   700  		r.ReadObject(obj.(Unmarshaler))
   701  		if r.Err() != nil {
   702  			return nil
   703  		}
   704  		return obj
   705  	}
   706  }
   707  
   708  func (r *RBuffer) RStream(si StreamerInfo, ptr any) error {
   709  	const kind = ObjectWise
   710  	dec, err := si.NewDecoder(kind, r)
   711  	if err != nil {
   712  		return fmt.Errorf("rbytes: could not create %v decoder for %q (version=%d): %w", kind, si.Name(), si.ClassVersion(), err)
   713  	}
   714  	return dec.DecodeROOT(ptr)
   715  }
   716  
   717  func (r *RBuffer) ReadStdBitset(v []uint8) {
   718  	n := len(v)
   719  	for i := range v {
   720  		v[n-1-i] = r.readU8()
   721  	}
   722  }
   723  
   724  var (
   725  	_ StreamerInfoContext = (*RBuffer)(nil)
   726  )