github.com/bluenviron/gomavlib/v2@v2.2.1-0.20240308101627-2c07e3da629c/pkg/message/readwriter.go (about)

     1  package message
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/binary"
     6  	"fmt"
     7  	"math"
     8  	"reflect"
     9  	"regexp"
    10  	"sort"
    11  	"strconv"
    12  	"strings"
    13  
    14  	"github.com/bluenviron/gomavlib/v2/pkg/x25"
    15  )
    16  
// fieldType identifies the wire type of a message field.
//
// The zero value is intentionally left unused (the constants start at 1) so
// that a failed lookup in fieldTypeFromGo can be detected by comparing the
// result with 0.
type fieldType int

// Supported field types. The C-style name and the size in bytes of each type
// are listed in fieldTypeString and fieldTypeSizes.
const (
	typeDouble fieldType = iota + 1
	typeUint64
	typeInt64
	typeFloat
	typeUint32
	typeInt32
	typeUint16
	typeInt16
	typeUint8
	typeInt8
	typeChar
)
    32  
// fieldTypeFromGo maps a Go type name to the corresponding fieldType.
// NewReadWriter looks up both the Go type name of a struct field and the
// value of its `mavenum` tag in this table; a missing key yields 0.
// Go strings are encoded as char fields.
var fieldTypeFromGo = map[string]fieldType{
	"float64": typeDouble,
	"uint64":  typeUint64,
	"int64":   typeInt64,
	"float32": typeFloat,
	"uint32":  typeUint32,
	"int32":   typeInt32,
	"uint16":  typeUint16,
	"int16":   typeInt16,
	"uint8":   typeUint8,
	"int8":    typeInt8,
	"string":  typeChar,
}
    46  
// fieldTypeString maps a fieldType to its C-style type name.
// These names are hashed into the CRC extra computed by NewReadWriter, so
// changing any of them changes the CRC extra of every message using the type.
var fieldTypeString = map[fieldType]string{
	typeDouble: "double",
	typeUint64: "uint64_t",
	typeInt64:  "int64_t",
	typeFloat:  "float",
	typeUint32: "uint32_t",
	typeInt32:  "int32_t",
	typeUint16: "uint16_t",
	typeInt16:  "int16_t",
	typeUint8:  "uint8_t",
	typeInt8:   "int8_t",
	typeChar:   "char",
}
    60  
// fieldTypeSizes maps a fieldType to the size in bytes of a single element of
// that type. NewReadWriter uses it both to compute payload sizes and as the
// sort weight when reordering fields.
var fieldTypeSizes = map[fieldType]byte{
	typeDouble: 8,
	typeUint64: 8,
	typeInt64:  8,
	typeFloat:  4,
	typeUint32: 4,
	typeInt32:  4,
	typeUint16: 2,
	typeInt16:  2,
	typeUint8:  1,
	typeInt8:   1,
	typeChar:   1,
}
    74  
    75  func fieldGoToDef(in string) string {
    76  	re := regexp.MustCompile("([A-Z])")
    77  	in = re.ReplaceAllString(in, "_${1}")
    78  	return strings.ToLower(in[1:])
    79  }
    80  
    81  func msgGoToDef(in string) string {
    82  	re := regexp.MustCompile("([A-Z])")
    83  	in = re.ReplaceAllString(in, "_${1}")
    84  	return strings.ToUpper(in[1:])
    85  }
    86  
    87  func readValue(target reflect.Value, buf []byte, f *decEncoderField) int {
    88  	if f.isEnum {
    89  		switch f.ftype {
    90  		case typeUint8:
    91  			target.SetUint(uint64(buf[0]))
    92  			return 1
    93  
    94  		case typeInt8:
    95  			target.SetUint(uint64(buf[0]))
    96  			return 1
    97  
    98  		case typeUint16:
    99  			target.SetUint(uint64(binary.LittleEndian.Uint16(buf)))
   100  			return 2
   101  
   102  		case typeUint32:
   103  			target.SetUint(uint64(binary.LittleEndian.Uint32(buf)))
   104  			return 4
   105  
   106  		case typeInt32:
   107  			target.SetUint(uint64(binary.LittleEndian.Uint32(buf)))
   108  			return 4
   109  
   110  		case typeUint64:
   111  			target.SetUint(binary.LittleEndian.Uint64(buf))
   112  			return 8
   113  		}
   114  	}
   115  
   116  	switch tt := target.Addr().Interface().(type) {
   117  	case *string:
   118  		// find string end or NULL character
   119  		end := 0
   120  		for end < int(f.arrayLength) && buf[end] != 0 {
   121  			end++
   122  		}
   123  		*tt = string(buf[:end])
   124  		return int(f.arrayLength) // return length including zeros
   125  
   126  	case *int8:
   127  		*tt = int8(buf[0])
   128  		return 1
   129  
   130  	case *uint8:
   131  		*tt = buf[0]
   132  		return 1
   133  
   134  	case *int16:
   135  		*tt = int16(binary.LittleEndian.Uint16(buf))
   136  		return 2
   137  
   138  	case *uint16:
   139  		*tt = binary.LittleEndian.Uint16(buf)
   140  		return 2
   141  
   142  	case *int32:
   143  		*tt = int32(binary.LittleEndian.Uint32(buf))
   144  		return 4
   145  
   146  	case *uint32:
   147  		*tt = binary.LittleEndian.Uint32(buf)
   148  		return 4
   149  
   150  	case *int64:
   151  		*tt = int64(binary.LittleEndian.Uint64(buf))
   152  		return 8
   153  
   154  	case *uint64:
   155  		*tt = binary.LittleEndian.Uint64(buf)
   156  		return 8
   157  
   158  	case *float32:
   159  		*tt = math.Float32frombits(binary.LittleEndian.Uint32(buf))
   160  		return 4
   161  
   162  	case *float64:
   163  		*tt = math.Float64frombits(binary.LittleEndian.Uint64(buf))
   164  		return 8
   165  	}
   166  
   167  	return 0
   168  }
   169  
   170  func writeValue(buf []byte, target reflect.Value, f *decEncoderField) int {
   171  	if f.isEnum {
   172  		switch f.ftype {
   173  		case typeUint8:
   174  			buf[0] = byte(target.Uint())
   175  			return 1
   176  
   177  		case typeInt8:
   178  			buf[0] = byte(target.Uint())
   179  			return 1
   180  
   181  		case typeUint16:
   182  			binary.LittleEndian.PutUint16(buf, uint16(target.Uint()))
   183  			return 2
   184  
   185  		case typeUint32:
   186  			binary.LittleEndian.PutUint32(buf, uint32(target.Uint()))
   187  			return 4
   188  
   189  		case typeInt32:
   190  			binary.LittleEndian.PutUint32(buf, uint32(target.Uint()))
   191  			return 4
   192  
   193  		case typeUint64:
   194  			binary.LittleEndian.PutUint64(buf, target.Uint())
   195  			return 8
   196  		}
   197  	}
   198  
   199  	switch tt := target.Addr().Interface().(type) {
   200  	case *string:
   201  		copy(buf[:f.arrayLength], *tt)
   202  		return int(f.arrayLength) // return length including zeros
   203  
   204  	case *int8:
   205  		buf[0] = uint8(*tt)
   206  		return 1
   207  
   208  	case *uint8:
   209  		buf[0] = *tt
   210  		return 1
   211  
   212  	case *int16:
   213  		binary.LittleEndian.PutUint16(buf, uint16(*tt))
   214  		return 2
   215  
   216  	case *uint16:
   217  		binary.LittleEndian.PutUint16(buf, *tt)
   218  		return 2
   219  
   220  	case *int32:
   221  		binary.LittleEndian.PutUint32(buf, uint32(*tt))
   222  		return 4
   223  
   224  	case *uint32:
   225  		binary.LittleEndian.PutUint32(buf, *tt)
   226  		return 4
   227  
   228  	case *int64:
   229  		binary.LittleEndian.PutUint64(buf, uint64(*tt))
   230  		return 8
   231  
   232  	case *uint64:
   233  		binary.LittleEndian.PutUint64(buf, *tt)
   234  		return 8
   235  
   236  	case *float32:
   237  		binary.LittleEndian.PutUint32(buf, math.Float32bits(*tt))
   238  		return 4
   239  
   240  	case *float64:
   241  		binary.LittleEndian.PutUint64(buf, math.Float64bits(*tt))
   242  		return 8
   243  	}
   244  
   245  	return 0
   246  }
   247  
// decEncoderField describes how one struct field of a message is encoded.
// Instances are built by NewReadWriter.
type decEncoderField struct {
	// whether the struct field carries a `mavenum` tag
	isEnum      bool
	// type of a single element of the field
	ftype       fieldType
	// field name hashed into the CRC extra: the `mavname` tag if present,
	// otherwise the Go field name converted with fieldGoToDef
	name        string
	// number of elements for arrays, number of bytes for strings
	// (1 when a string has no `mavlen` tag), 0 for other scalar fields
	arrayLength byte
	// position of the field inside the Go struct
	index       int
	// whether the struct field carries the tag `mavext:"true"`
	isExtension bool
}
   256  
// ReadWriter is a Message Reader and Writer.
// It is bound to a single message type and must be created with NewReadWriter.
type ReadWriter struct {
	// fields of the message, sorted in the order in which they are encoded
	fields       []*decEncoderField
	// payload size in bytes without extension fields
	sizeNormal   byte
	// payload size in bytes with extension fields
	sizeExtended byte
	// struct type of the message (the type pointed to by the Message)
	elemType     reflect.Type
	// CRC extra of the message, returned by CRCExtra
	crcExtra     byte
}
   265  
   266  // NewReadWriter allocates a ReadWriter.
   267  func NewReadWriter(msg Message) (*ReadWriter, error) {
   268  	mde := &ReadWriter{}
   269  	mde.elemType = reflect.TypeOf(msg).Elem()
   270  
   271  	mde.fields = make([]*decEncoderField, mde.elemType.NumField())
   272  
   273  	// get name
   274  	if !strings.HasPrefix(mde.elemType.Name(), "Message") {
   275  		return nil, fmt.Errorf("struct name must begin with 'Message'")
   276  	}
   277  	msgName := msgGoToDef(mde.elemType.Name()[len("Message"):])
   278  
   279  	// collect message fields
   280  	for i := 0; i < mde.elemType.NumField(); i++ {
   281  		field := mde.elemType.Field(i)
   282  		arrayLength := byte(0)
   283  		goType := field.Type
   284  
   285  		// array
   286  		if goType.Kind() == reflect.Array {
   287  			arrayLength = byte(goType.Len())
   288  			goType = goType.Elem()
   289  		}
   290  
   291  		isEnum := false
   292  		var dialectType fieldType
   293  
   294  		// enum
   295  		if tagEnum := field.Tag.Get("mavenum"); tagEnum != "" {
   296  			isEnum = true
   297  
   298  			if goType.Kind() != reflect.Uint64 {
   299  				return nil, fmt.Errorf("an enum must be an uint64")
   300  			}
   301  
   302  			dialectType = fieldTypeFromGo[tagEnum]
   303  			if dialectType == 0 {
   304  				return nil, fmt.Errorf("unsupported Go type: %v", tagEnum)
   305  			}
   306  
   307  			switch dialectType {
   308  			case typeUint8:
   309  			case typeInt8:
   310  			case typeUint16:
   311  			case typeUint32:
   312  			case typeInt32:
   313  			case typeUint64:
   314  				break
   315  
   316  			default:
   317  				return nil, fmt.Errorf("type '%v' cannot be used as enum", tagEnum)
   318  			}
   319  		} else {
   320  			dialectType = fieldTypeFromGo[goType.Name()]
   321  			if dialectType == 0 {
   322  				return nil, fmt.Errorf("unsupported Go type: %v", goType.Name())
   323  			}
   324  
   325  			// string or char
   326  			if goType.Kind() == reflect.String {
   327  				tagLen := field.Tag.Get("mavlen")
   328  
   329  				if len(tagLen) == 0 { // char
   330  					arrayLength = 1
   331  				} else { // string
   332  					slen, err := strconv.Atoi(tagLen)
   333  					if err != nil {
   334  						return nil, fmt.Errorf("string has invalid length: %v", tagLen)
   335  					}
   336  					arrayLength = byte(slen)
   337  				}
   338  			}
   339  		}
   340  
   341  		// extension
   342  		isExtension := (field.Tag.Get("mavext") == "true")
   343  
   344  		// size
   345  		var size byte
   346  		if arrayLength > 0 {
   347  			size = fieldTypeSizes[dialectType] * arrayLength
   348  		} else {
   349  			size = fieldTypeSizes[dialectType]
   350  		}
   351  
   352  		mde.fields[i] = &decEncoderField{
   353  			isEnum: isEnum,
   354  			ftype:  dialectType,
   355  			name: func() string {
   356  				if mavname := field.Tag.Get("mavname"); mavname != "" {
   357  					return mavname
   358  				}
   359  				return fieldGoToDef(field.Name)
   360  			}(),
   361  			arrayLength: arrayLength,
   362  			index:       i,
   363  			isExtension: isExtension,
   364  		}
   365  
   366  		mde.sizeExtended += size
   367  		if !isExtension {
   368  			mde.sizeNormal += size
   369  		}
   370  	}
   371  
   372  	// reorder fields as described in
   373  	// https://mavlink.io/en/guide/serialization.html#field_reordering
   374  	sort.Slice(mde.fields, func(i, j int) bool {
   375  		// sort by weight if not extension
   376  		if !mde.fields[i].isExtension && !mde.fields[j].isExtension {
   377  			if w1, w2 := fieldTypeSizes[mde.fields[i].ftype], fieldTypeSizes[mde.fields[j].ftype]; w1 != w2 {
   378  				return w1 > w2
   379  			}
   380  		}
   381  		// sort by original index
   382  		return mde.fields[i].index < mde.fields[j].index
   383  	})
   384  
   385  	// generate CRC extra
   386  	// https://mavlink.io/en/guide/serialization.html#crc_extra
   387  	mde.crcExtra = func() byte {
   388  		h := x25.New()
   389  		h.Write([]byte(msgName + " "))
   390  
   391  		for _, f := range mde.fields {
   392  			// skip extensions
   393  			if f.isExtension {
   394  				continue
   395  			}
   396  
   397  			h.Write([]byte(fieldTypeString[f.ftype] + " "))
   398  			h.Write([]byte(f.name + " "))
   399  
   400  			if f.arrayLength > 0 {
   401  				h.Write([]byte{f.arrayLength})
   402  			}
   403  		}
   404  		sum := h.Sum16()
   405  		return byte((sum & 0xFF) ^ (sum >> 8))
   406  	}()
   407  
   408  	return mde, nil
   409  }
   410  
// CRCExtra returns the CRC extra of the message.
// The value is computed once by NewReadWriter from the message name and from
// the type, name and array length of every non-extension field.
func (mde *ReadWriter) CRCExtra() byte {
	return mde.crcExtra
}
   415  
   416  // Read converts a *MessageRaw into a Message.
   417  func (mde *ReadWriter) Read(m *MessageRaw, isV2 bool) (Message, error) {
   418  	rmsg := reflect.New(mde.elemType)
   419  
   420  	if isV2 {
   421  		// in V2 buffer length can be > message or < message
   422  		// in this latter case it must be filled with zeros to support empty-byte de-truncation
   423  		// and extension fields
   424  		if len(m.Payload) < int(mde.sizeExtended) {
   425  			m.Payload = append(m.Payload, bytes.Repeat([]byte{0x00}, int(mde.sizeExtended)-len(m.Payload))...)
   426  		}
   427  	} else {
   428  		// in V1 buffer must fit message perfectly
   429  		if len(m.Payload) != int(mde.sizeNormal) {
   430  			return nil, fmt.Errorf("wrong size: expected %d, got %d", mde.sizeNormal, len(m.Payload))
   431  		}
   432  	}
   433  
   434  	// decode field by field
   435  	for _, f := range mde.fields {
   436  		// skip extensions in V1 frames
   437  		if !isV2 && f.isExtension {
   438  			continue
   439  		}
   440  
   441  		target := rmsg.Elem().Field(f.index)
   442  
   443  		switch target.Kind() {
   444  		case reflect.Array:
   445  			length := target.Len()
   446  			for i := 0; i < length; i++ {
   447  				n := readValue(target.Index(i), m.Payload, f)
   448  				m.Payload = m.Payload[n:]
   449  			}
   450  
   451  		default:
   452  			n := readValue(target, m.Payload, f)
   453  			m.Payload = m.Payload[n:]
   454  		}
   455  	}
   456  
   457  	return rmsg.Interface().(Message), nil
   458  }
   459  
   460  func (mde *ReadWriter) size(isV2 bool) uint8 {
   461  	if isV2 {
   462  		return mde.sizeExtended
   463  	}
   464  	return mde.sizeNormal
   465  }
   466  
   467  // Write converts a Message into a *MessageRaw.
   468  func (mde *ReadWriter) Write(msg Message, isV2 bool) *MessageRaw {
   469  	buf := make([]byte, mde.size(isV2))
   470  	start := buf
   471  
   472  	// encode field by field
   473  	for _, f := range mde.fields {
   474  		// skip extensions in V1 frames
   475  		if !isV2 && f.isExtension {
   476  			continue
   477  		}
   478  
   479  		target := reflect.ValueOf(msg).Elem().Field(f.index)
   480  
   481  		switch target.Kind() {
   482  		case reflect.Array:
   483  			length := target.Len()
   484  			for i := 0; i < length; i++ {
   485  				n := writeValue(buf, target.Index(i), f)
   486  				buf = buf[n:]
   487  			}
   488  
   489  		default:
   490  			n := writeValue(buf, target, f)
   491  			buf = buf[n:]
   492  		}
   493  	}
   494  
   495  	buf = start
   496  
   497  	// empty-byte truncation
   498  	// even with truncation, message length must be at least 1 byte
   499  	// https://github.com/mavlink/c_library_v2/blob/master/mavlink_helpers.h#L103
   500  	if isV2 {
   501  		end := len(buf)
   502  		for end > 1 && buf[end-1] == 0x00 {
   503  			end--
   504  		}
   505  		buf = buf[:end]
   506  	}
   507  
   508  	return &MessageRaw{
   509  		ID:      msg.GetID(),
   510  		Payload: buf,
   511  	}
   512  }