github.com/fluhus/gostuff@v0.4.1-0.20240331134726-be71864f2b5d/bnry/read.go (about)

     1  package bnry
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/binary"
     6  	"fmt"
     7  	"io"
     8  	"math"
     9  	"reflect"
    10  	"slices"
    11  
    12  	"golang.org/x/exp/constraints"
    13  )
    14  
    15  // UnmarshalBinary decodes binary data into the given values.
    16  // Values should be pointers to any of the supported types.
    17  // Panics if a value is of an unsupported type.
    18  func UnmarshalBinary(data []byte, vals ...any) error {
    19  	return Read(bytes.NewBuffer(data), vals...)
    20  }
    21  
    22  // Read reads and decodes binary data into the given values.
    23  // Values should be pointers to any of the supported types.
    24  // Panics if a value is of an unsupported type.
    25  func Read(r io.ByteReader, vals ...any) error {
    26  	for i, val := range vals {
    27  		if err := readSingle(r, val); err != nil {
    28  			if i > 0 {
    29  				err = notExpectingEOF(err)
    30  			}
    31  			if err != io.EOF {
    32  				err = fmt.Errorf("reading value #%d: %w", i+1, err)
    33  			}
    34  			return err
    35  		}
    36  	}
    37  	return nil
    38  }
    39  
    40  // Decodes a single value from r.
    41  func readSingle(r io.ByteReader, val any) error {
    42  	switch val := val.(type) {
    43  	case *uint8:
    44  		return readUint8(r, val)
    45  	case *uint16:
    46  		return readUvarint(r, val)
    47  	case *uint32:
    48  		return readUvarint(r, val)
    49  	case *uint64:
    50  		return readUvarint(r, val)
    51  	case *uint:
    52  		return readUvarint(r, val)
    53  	case *int8:
    54  		return readInt8(r, val)
    55  	case *int16:
    56  		return readVarint(r, val)
    57  	case *int32:
    58  		return readVarint(r, val)
    59  	case *int64:
    60  		return readVarint(r, val)
    61  	case *int:
    62  		return readVarint(r, val)
    63  	case *float32:
    64  		return readFloat32(r, val)
    65  	case *float64:
    66  		return readFloat64(r, val)
    67  	case *bool:
    68  		return readBool(r, val)
    69  	case *string:
    70  		return readString(r, val)
    71  	case *[]uint8:
    72  		return readUint8Slice(r, val)
    73  	case *[]uint16:
    74  		return readUintSlice(r, val)
    75  	case *[]uint32:
    76  		return readUintSlice(r, val)
    77  	case *[]uint64:
    78  		return readUintSlice(r, val)
    79  	case *[]uint:
    80  		return readUintSlice(r, val)
    81  	case *[]int8:
    82  		return readInt8Slice(r, val)
    83  	case *[]int16:
    84  		return readIntSlice(r, val)
    85  	case *[]int32:
    86  		return readIntSlice(r, val)
    87  	case *[]int64:
    88  		return readIntSlice(r, val)
    89  	case *[]int:
    90  		return readIntSlice(r, val)
    91  	case *[]float32:
    92  		return readFloat32Slice(r, val)
    93  	case *[]float64:
    94  		return readFloat64Slice(r, val)
    95  	case *[]bool:
    96  		return readBoolSlice(r, val)
    97  	case *[]string:
    98  		return readStringSlice(r, val)
    99  	default:
   100  		panic(fmt.Sprintf("unsupported type: %v", reflect.TypeOf(val).Name()))
   101  	}
   102  }
   103  
   104  func readUint8(r io.ByteReader, val *uint8) error {
   105  	x, err := r.ReadByte()
   106  	*val = x
   107  	return err
   108  }
   109  
   110  func readInt8(r io.ByteReader, val *int8) error {
   111  	x, err := r.ReadByte()
   112  	*val = int8(x)
   113  	return err
   114  }
   115  
   116  func readUvarint[T constraints.Unsigned](r io.ByteReader, val *T) error {
   117  	x, err := binary.ReadUvarint(r)
   118  	*val = T(x)
   119  	return err
   120  }
   121  
   122  func readVarint[T constraints.Signed](r io.ByteReader, val *T) error {
   123  	x, err := binary.ReadVarint(r)
   124  	*val = T(x)
   125  	return err
   126  }
   127  
   128  func readFloat32(r io.ByteReader, val *float32) error {
   129  	x, err := binary.ReadUvarint(r)
   130  	*val = math.Float32frombits(uint32(x))
   131  	return err
   132  }
   133  
   134  func readFloat64(r io.ByteReader, val *float64) error {
   135  	x, err := binary.ReadUvarint(r)
   136  	*val = math.Float64frombits(x)
   137  	return err
   138  }
   139  
   140  func readBool(r io.ByteReader, val *bool) error {
   141  	b, err := r.ReadByte()
   142  	if err != nil {
   143  		return err
   144  	}
   145  	switch b {
   146  	case 0:
   147  		*val = false
   148  	case 1:
   149  		*val = true
   150  	default:
   151  		return fmt.Errorf("unexpected value for bool: %v, want 0 or 1", b)
   152  	}
   153  	return nil
   154  }
   155  
   156  func readString(r io.ByteReader, s *string) error {
   157  	var buf []byte
   158  	if err := readUint8Slice(r, &buf); err != nil {
   159  		return err
   160  	}
   161  	*s = string(buf)
   162  	return nil
   163  }
   164  
   165  func readUint8Slice(r io.ByteReader, val *[]uint8) error {
   166  	n, err := binary.ReadUvarint(r)
   167  	if err != nil {
   168  		return err
   169  	}
   170  	buf := slices.Grow(*val, int(n))[:0]
   171  	for range n {
   172  		b, err := r.ReadByte()
   173  		if err != nil {
   174  			return notExpectingEOF(err)
   175  		}
   176  		buf = append(buf, b)
   177  	}
   178  	*val = buf
   179  	return nil
   180  }
   181  
   182  func readInt8Slice(r io.ByteReader, val *[]int8) error {
   183  	n, err := binary.ReadUvarint(r)
   184  	if err != nil {
   185  		return err
   186  	}
   187  	buf := slices.Grow(*val, int(n))[:0]
   188  	for range n {
   189  		b, err := r.ReadByte()
   190  		if err != nil {
   191  			return notExpectingEOF(err)
   192  		}
   193  		buf = append(buf, int8(b))
   194  	}
   195  	*val = buf
   196  	return nil
   197  }
   198  
   199  func readUintSlice[T constraints.Unsigned](r io.ByteReader, val *[]T) error {
   200  	n, err := binary.ReadUvarint(r)
   201  	if err != nil {
   202  		return err
   203  	}
   204  	buf := slices.Grow(*val, int(n))[:0]
   205  	for range n {
   206  		x, err := binary.ReadUvarint(r)
   207  		if err != nil {
   208  			return notExpectingEOF(err)
   209  		}
   210  		buf = append(buf, T(x))
   211  	}
   212  	*val = buf
   213  	return nil
   214  }
   215  
   216  func readIntSlice[T constraints.Signed](r io.ByteReader, val *[]T) error {
   217  	n, err := binary.ReadUvarint(r)
   218  	if err != nil {
   219  		return err
   220  	}
   221  	buf := slices.Grow(*val, int(n))[:0]
   222  	for range n {
   223  		x, err := binary.ReadVarint(r)
   224  		if err != nil {
   225  			return notExpectingEOF(err)
   226  		}
   227  		buf = append(buf, T(x))
   228  	}
   229  	*val = buf
   230  	return nil
   231  }
   232  
   233  func readFloat32Slice(r io.ByteReader, val *[]float32) error {
   234  	n, err := binary.ReadUvarint(r)
   235  	if err != nil {
   236  		return err
   237  	}
   238  	buf := slices.Grow(*val, int(n))[:0]
   239  	for range n {
   240  		var x float32
   241  		if err := readFloat32(r, &x); err != nil {
   242  			return notExpectingEOF(err)
   243  		}
   244  		buf = append(buf, x)
   245  	}
   246  	*val = buf
   247  	return nil
   248  }
   249  
   250  func readFloat64Slice(r io.ByteReader, val *[]float64) error {
   251  	n, err := binary.ReadUvarint(r)
   252  	if err != nil {
   253  		return err
   254  	}
   255  	buf := slices.Grow(*val, int(n))[:0]
   256  	for range n {
   257  		var x float64
   258  		if err := readFloat64(r, &x); err != nil {
   259  			return notExpectingEOF(err)
   260  		}
   261  		buf = append(buf, x)
   262  	}
   263  	*val = buf
   264  	return nil
   265  }
   266  
   267  func readBoolSlice(r io.ByteReader, val *[]bool) error {
   268  	n, err := binary.ReadUvarint(r)
   269  	if err != nil {
   270  		return err
   271  	}
   272  	buf := slices.Grow(*val, int(n))[:0]
   273  	for range n {
   274  		var x bool
   275  		if err := readBool(r, &x); err != nil {
   276  			return notExpectingEOF(err)
   277  		}
   278  		buf = append(buf, x)
   279  	}
   280  	*val = buf
   281  	return nil
   282  }
   283  
   284  func readStringSlice(r io.ByteReader, val *[]string) error {
   285  	n, err := binary.ReadUvarint(r)
   286  	if err != nil {
   287  		return err
   288  	}
   289  	buf := slices.Grow(*val, int(n))[:0]
   290  	for range n {
   291  		var x string
   292  		if err := readString(r, &x); err != nil {
   293  			return notExpectingEOF(err)
   294  		}
   295  		buf = append(buf, x)
   296  	}
   297  	*val = buf
   298  	return nil
   299  }
   300  
   301  func notExpectingEOF(err error) error {
   302  	if err == io.EOF {
   303  		return io.ErrUnexpectedEOF
   304  	}
   305  	return err
   306  }