go-hep.org/x/hep@v0.38.1/sio/encoder.go (about)

     1  // Copyright ©2017 The go-hep Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package sio
     6  
     7  import (
     8  	"encoding/binary"
     9  	"reflect"
    10  )
    11  
    12  // Encoder encodes values into a SIO stream.
    13  // Encoder provides a nice API to deal with errors that may occur during encoding.
    14  type Encoder struct {
    15  	w   Writer
    16  	err error
    17  }
    18  
    19  // NewEncoder creates a new Encoder writing to the provided sio.Writer.
    20  func NewEncoder(w Writer) *Encoder {
    21  	return &Encoder{w: w}
    22  }
    23  
    24  // Encode writes the next value to the output sio stream.
    25  func (enc *Encoder) Encode(data any) {
    26  	if enc.err != nil {
    27  		return
    28  	}
    29  	enc.err = marshal(enc.w, data)
    30  }
    31  
    32  // Tag tags a pointer, assigning it a unique identifier, so links between values
    33  // (inside a given sio record) can be stored.
    34  func (enc *Encoder) Tag(ptr any) {
    35  	if enc.err != nil {
    36  		return
    37  	}
    38  	enc.err = enc.w.Tag(ptr)
    39  }
    40  
    41  // Pointer marks a (pointer to a) pointer, assigning it a unique identifier,
    42  // so links between values (inside a given SIO record) can be stored.
    43  func (enc *Encoder) Pointer(ptr any) {
    44  	if enc.err != nil {
    45  		return
    46  	}
    47  	enc.err = enc.w.Pointer(ptr)
    48  }
    49  
    50  // Err returns the first encountered error while decoding, if any.
    51  func (dec *Encoder) Err() error {
    52  	return dec.err
    53  }
    54  
    55  // marshal marshals ptr to a stream of bytes.
    56  // If ptr implements Codec, use it.
    57  func marshal(w Writer, ptr any) error {
    58  	if ptr, ok := ptr.(Marshaler); ok {
    59  		return ptr.MarshalSio(w)
    60  	}
    61  	return bwrite(w, ptr)
    62  }
    63  
    64  func bwrite(w Writer, data any) error {
    65  
    66  	bo := binary.BigEndian
    67  	rrv := reflect.ValueOf(data)
    68  	rv := reflect.Indirect(rrv)
    69  	// fmt.Printf("::: [%v] :::...\n", rrv.Type())
    70  	// defer fmt.Printf("### [%v] [done]\n", rrv.Type())
    71  
    72  	switch rv.Type().Kind() {
    73  	case reflect.Struct:
    74  		//fmt.Printf(">>> struct: [%v]...\n", rv.Type())
    75  		for i, n := 0, rv.NumField(); i < n; i++ {
    76  			//fmt.Printf(">>> i=%d [%v] (%v)...\n", i, rv.Field(i).Type(), rv.Type().Name())
    77  			err := marshal(w, rv.Field(i).Addr().Interface())
    78  			if err != nil {
    79  				return err
    80  			}
    81  			//fmt.Printf(">>> i=%d [%v] (%v)...[done]\n", i, rv.Field(i).Type(), rv.Type().Name())
    82  		}
    83  		//fmt.Printf(">>> struct: [%v]...[done]\n", rv.Type())
    84  		return nil
    85  	case reflect.String:
    86  		str := rv.String()
    87  		sz := uint32(len(str))
    88  		// fmt.Printf("++++> (%d) [%s]\n", sz, string(str))
    89  		err := bwrite(w, &sz)
    90  		if err != nil {
    91  			return err
    92  		}
    93  		bstr := []byte(str)
    94  		bstr = append(bstr, make([]byte, align4U32(sz)-sz)...)
    95  		_, err = w.Write(bstr)
    96  		if err != nil {
    97  			return err
    98  		}
    99  		// fmt.Printf("<++++ (%d) [%s]\n", sz, string(str))
   100  		return nil
   101  
   102  	case reflect.Slice:
   103  		// fmt.Printf(">>> slice: [%v|%v]...\n", rv.Type(), rv.Type().Elem().Kind())
   104  		sz := uint32(rv.Len())
   105  		// fmt.Printf(">>> slice: %d [%v]\n", sz, rv.Type())
   106  		err := bwrite(w, &sz)
   107  		if err != nil {
   108  			return err
   109  		}
   110  		for i := range int(sz) {
   111  			err = marshal(w, rv.Index(i).Addr().Interface())
   112  			if err != nil {
   113  				return err
   114  			}
   115  		}
   116  		// fmt.Printf(">>> slice: [%v]... [done] (%v)\n", rv.Type(), rv.Interface())
   117  		return err
   118  
   119  	case reflect.Map:
   120  		//fmt.Printf(">>> map: [%v]...\n", rv.Type())
   121  		sz := uint32(rv.Len())
   122  		err := bwrite(w, &sz)
   123  		if err != nil {
   124  			return err
   125  		}
   126  		//fmt.Printf(">>> map: %d [%v]\n", sz, rv.Type())
   127  		for _, kv := range rv.MapKeys() {
   128  			vv := rv.MapIndex(kv)
   129  			err = marshal(w, kv.Interface())
   130  			if err != nil {
   131  				return err
   132  			}
   133  			err = marshal(w, vv.Interface())
   134  			if err != nil {
   135  				return err
   136  			}
   137  			//fmt.Printf("m - %d: {%v} - {%v}\n", i, kv.Elem().Interface(), vv.Elem().Interface())
   138  		}
   139  		return nil
   140  
   141  	default:
   142  		//fmt.Printf(">>> binary - [%v]...\n", rv.Type())
   143  		err := binary.Write(w, bo, data)
   144  		//fmt.Printf(">>> binary - [%v]... [done]\n", rv.Type())
   145  		return err
   146  	}
   147  }