github.com/iDigitalFlame/xmt@v0.5.4/com/packet.go (about)

     1  // Copyright (C) 2020 - 2023 iDigitalFlame
     2  //
     3  // This program is free software: you can redistribute it and/or modify
     4  // it under the terms of the GNU General Public License as published by
     5  // the Free Software Foundation, either version 3 of the License, or
     6  // any later version.
     7  //
     8  // This program is distributed in the hope that it will be useful,
     9  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    10  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    11  // GNU General Public License for more details.
    12  //
    13  // You should have received a copy of the GNU General Public License
    14  // along with this program.  If not, see <https://www.gnu.org/licenses/>.
    15  //
    16  
    17  package com
    18  
    19  import (
    20  	"io"
    21  
    22  	"github.com/iDigitalFlame/xmt/data"
    23  	"github.com/iDigitalFlame/xmt/device"
    24  	"github.com/iDigitalFlame/xmt/util/bugtrack"
    25  	"github.com/iDigitalFlame/xmt/util/xerr"
    26  )
    27  
    28  const (
    29  	// PacketMaxTags is the max amount of tags that are allowed on a specific
    30  	// Packet. If the amount of tags exceed this limit, an error will occur
    31  	// doing writing.
    32  	PacketMaxTags = 2 << 14
    33  	// PacketHeaderSize is the length of the Packet header in bytes.
    34  	PacketHeaderSize = 46
    35  )
    36  
    37  // ErrMalformedTag is an error returned when a read on a Packet Tag returns
    38  // an empty (zero) tag value.
    39  var ErrMalformedTag = xerr.Sub("malformed Tag", 0x2A)
    40  
// Packet is a struct that is a Reader and Writer that can be generated to be
// sent, or received from a Connection.
//
// Acts as a data buffer and 'parent' of 'data.Chunk'. The embedded Chunk holds
// the Packet payload, while the remaining fields make up the Packet header
// that is written by Marshal and read by Unmarshal.
type Packet struct {
	// Tags is an optional list of non-zero uint32 values that are written after
	// the header and before the payload. A zero Tag is rejected with
	// ErrMalformedTag and no more than PacketMaxTags Tags may be written.
	Tags []uint32
	// Chunk is the payload buffer. The Packet uses its methods (Seek, WriteTo,
	// ReadFrom, Limit) directly, as they are promoted through embedding.
	data.Chunk

	// Flags is encoded into the header as 8 big-endian bytes.
	Flags Flag
	// Job is encoded into the header as 2 big-endian bytes. Fragments that
	// belong together share the same Job (see Belongs).
	Job uint16

	// Device is the ID written at the very start of the header.
	Device device.ID
	// ID is the Packet identifier byte. Add refuses to combine Packets that
	// have different IDs.
	ID uint8
	// len is the payload length decoded by readHeader. readBody uses it to
	// bound how much payload it reads and resets it after a successful read.
	len uint64
}
    56  
    57  // Size returns the amount of bytes written or contained in this Packet with the
    58  // header size added.
    59  func (p *Packet) Size() int {
    60  	if p.Empty() {
    61  		return PacketHeaderSize
    62  	}
    63  	switch s := uint64(p.Chunk.Size() + PacketHeaderSize + (4 * len(p.Tags))); {
    64  	case s < data.LimitSmall:
    65  		return int(s) + 1
    66  	case s < data.LimitMedium:
    67  		return int(s) + 2
    68  	case s < data.LimitLarge:
    69  		return int(s) + 4
    70  	default:
    71  		return int(s) + 8
    72  	}
    73  }
    74  
    75  // Add attempts to combine the data and properties the supplied Packet with the
    76  // existing Packet. This function will return an error if the ID's have a
    77  // mismatch or there was an error during the write operation.
    78  func (p *Packet) Add(n *Packet) error {
    79  	if n == nil || n.Empty() {
    80  		return nil
    81  	}
    82  	if p.ID != n.ID {
    83  		return xerr.Sub("packet ID does not match the supplied ID", 0x2C)
    84  	}
    85  	if _, err := n.WriteTo(p); err != nil {
    86  		return xerr.Wrap("unable to write to Packet", err)
    87  	}
    88  	// NOTE(dij): Preserve frag flags.
    89  	p.Flags |= Flag(uint16(n.Flags))
    90  	return nil
    91  }
    92  
    93  // Belongs returns true if the specified Packet is a Frag that was a part of the
    94  // split Chunks of this as the original packet.
    95  func (p *Packet) Belongs(n *Packet) bool {
    96  	return n != nil && p.Flags >= FlagFrag && n.Flags >= FlagFrag && p.ID == n.ID && p.Job == n.Job && p.Flags.Group() == n.Flags.Group()
    97  }
    98  
    99  // Marshal will attempt to write this Packet's data and headers to the specified
   100  // Writer. This function will return any errors that have occurred during writing.
   101  func (p *Packet) Marshal(w io.Writer) error {
   102  	if err := p.writeHeader(w); err != nil {
   103  		return xerr.Wrap("marshal header", err)
   104  	}
   105  	if err := p.writeBody(w); err != nil {
   106  		return xerr.Wrap("marshal body", err)
   107  	}
   108  	return nil
   109  }
   110  func (p *Packet) readBody(r io.Reader) error {
   111  	if len(p.Tags) > 0 {
   112  		if bugtrack.Enabled {
   113  			bugtrack.Track("com.(*Packet).readBody(): len(p.Tags)=%d", len(p.Tags))
   114  		}
   115  		var b [4]byte
   116  		for i := range p.Tags {
   117  			n, err := io.ReadFull(r, b[:])
   118  			if err != nil {
   119  				return err
   120  			}
   121  			if n != 4 {
   122  				return io.ErrUnexpectedEOF
   123  			}
   124  			if p.Tags[i] = uint32(b[3]) | uint32(b[2])<<8 | uint32(b[1])<<16 | uint32(b[0])<<24; p.Tags[i] == 0 {
   125  				return ErrMalformedTag
   126  			}
   127  		}
   128  	}
   129  	if bugtrack.Enabled {
   130  		bugtrack.Track("com.(*Packet).readBody(): p.len=%d", p.len)
   131  	}
   132  	if p.len == 0 {
   133  		return nil
   134  	}
   135  	p.Limit = int(p.len)
   136  	var (
   137  		t   uint64
   138  		err error
   139  	)
   140  	for n := int64(0); t < p.len && err == nil; {
   141  		n, err = p.ReadFrom(r)
   142  		if t += uint64(n); err != nil || n == 0 {
   143  			break
   144  		}
   145  	}
   146  	if p.Limit = 0; bugtrack.Enabled {
   147  		bugtrack.Track("com.(*Packet).readBody(): p.len=%d, t=%d, err=%s", p.len, t, err)
   148  	}
   149  	if t < p.len {
   150  		return io.ErrUnexpectedEOF
   151  	}
   152  	if t == p.len {
   153  		err = nil
   154  	}
   155  	p.len = 0
   156  	return err
   157  }
   158  
   159  // Unmarshal will attempt to read Packet data and headers from the specified Reader.
   160  // This function will return any errors that have occurred during reading.
   161  func (p *Packet) Unmarshal(r io.Reader) error {
   162  	if err := p.readHeader(r); err != nil {
   163  		return xerr.Wrap("unmarshal header", err)
   164  	}
   165  	if err := p.readBody(r); err != nil {
   166  		return xerr.Wrap("unmarshal body", err)
   167  	}
   168  	return nil
   169  }
   170  func (p *Packet) writeBody(w io.Writer) error {
   171  	if len(p.Tags) > 0 {
   172  		if bugtrack.Enabled {
   173  			bugtrack.Track("com.(*Packet).writeBody(): len(p.Tags)=%d", len(p.Tags))
   174  		}
   175  		var b [4]byte
   176  		for _, t := range p.Tags {
   177  			if t == 0 {
   178  				return ErrMalformedTag
   179  			}
   180  			b[0], b[1], b[2], b[3] = byte(t>>24), byte(t>>16), byte(t>>8), byte(t)
   181  			n, err := w.Write(b[0:4])
   182  			if err != nil {
   183  				return err
   184  			}
   185  			if n != 4 {
   186  				return io.ErrShortWrite
   187  			}
   188  		}
   189  	}
   190  	if p.Seek(0, 0); p.Chunk.Size() == 0 {
   191  		return nil
   192  	}
   193  	n, err := p.WriteTo(w)
   194  	if err != nil {
   195  		return err
   196  	}
   197  	if n != int64(p.Chunk.Size()) {
   198  		return io.ErrShortWrite
   199  	}
   200  	if bugtrack.Enabled {
   201  		bugtrack.Track("com.(*Packet).writeBody(): p.Chunk.Size()=%d, n=%d, err=%s", p.Chunk.Size(), n, err)
   202  	}
   203  	return nil
   204  }
   205  func (p *Packet) readHeader(r io.Reader) error {
   206  	if err := p.Device.Read(r); err != nil {
   207  		if bugtrack.Enabled {
   208  			bugtrack.Track("com.(*Packet).readHeader(): Read Device failed err=%s", err)
   209  		}
   210  		return err
   211  	}
   212  	var (
   213  		b      [14]byte
   214  		n, err = io.ReadFull(r, b[:])
   215  	)
   216  	if bugtrack.Enabled {
   217  		bugtrack.Track("com.(*Packet).readHeader(): n=%d, err=%s", n, err)
   218  	}
   219  	if n != 14 {
   220  		if err != nil {
   221  			return err
   222  		}
   223  		return io.ErrUnexpectedEOF
   224  	}
   225  	_ = b[13]
   226  	if bugtrack.Enabled {
   227  		bugtrack.Track(
   228  			"com.(*Packet).readHeader(): b=[%d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d]",
   229  			b[0], b[1], b[2], b[3], b[4], b[5], b[6], b[7], b[8], b[9], b[10], b[11], b[12], b[13],
   230  		)
   231  	}
   232  	p.ID, p.Job = b[0], uint16(b[2])|uint16(b[1])<<8
   233  	p.Flags = Flag(b[10]) | Flag(b[9])<<8 | Flag(b[8])<<16 | Flag(b[7])<<24 |
   234  		Flag(b[6])<<32 | Flag(b[5])<<40 | Flag(b[4])<<48 | Flag(b[3])<<56
   235  	if l := int(b[12]) | int(b[11])<<8; l > 0 {
   236  		p.Tags = make([]uint32, l)
   237  	}
   238  	switch b[13] {
   239  	case 0:
   240  		p.len, err = 0, nil
   241  	case 1:
   242  		if n, err = io.ReadFull(r, b[0:1]); n != 1 {
   243  			if err == nil {
   244  				err = io.ErrUnexpectedEOF
   245  			}
   246  			break
   247  		}
   248  		if bugtrack.Enabled {
   249  			bugtrack.Track("com.(*Packet).readHeader(): 1, n=%d, b=[%d]", n, b[0])
   250  		}
   251  		p.len, err = uint64(b[0]), nil
   252  	case 3:
   253  		if n, err = io.ReadFull(r, b[0:2]); n != 2 {
   254  			if err == nil {
   255  				err = io.ErrUnexpectedEOF
   256  			}
   257  			break
   258  		}
   259  		if bugtrack.Enabled {
   260  			bugtrack.Track("com.(*Packet).readHeader(): 3, n=%d, b=[%d, %d]", n, b[0], b[1])
   261  		}
   262  		p.len, err = uint64(b[1])|uint64(b[0])<<8, nil
   263  	case 5:
   264  		if n, err = io.ReadFull(r, b[0:4]); n != 4 {
   265  			if err == nil {
   266  				err = io.ErrUnexpectedEOF
   267  			}
   268  			break
   269  		}
   270  		if bugtrack.Enabled {
   271  			bugtrack.Track("com.(*Packet).readHeader(): 5, n=%d, b=[%d, %d, %d, %d]", n, b[0], b[1], b[2], b[3])
   272  		}
   273  		p.len, err = uint64(b[3])|uint64(b[2])<<8|uint64(b[1])<<16|uint64(b[0])<<24, nil
   274  	case 7:
   275  		if n, err = io.ReadFull(r, b[0:8]); n != 8 {
   276  			if err == nil {
   277  				err = io.ErrUnexpectedEOF
   278  			}
   279  			break
   280  		}
   281  		if bugtrack.Enabled {
   282  			bugtrack.Track(
   283  				"com.(*Packet).readHeader(): 7, n=%d, b=[%d, %d, %d, %d, %d, %d, %d, %d]",
   284  				n, b[0], b[1], b[2], b[3], b[4], b[5], b[6], b[7],
   285  			)
   286  		}
   287  		p.len, err = uint64(b[7])|uint64(b[6])<<8|uint64(b[5])<<16|uint64(b[4])<<24|
   288  			uint64(b[3])<<32|uint64(b[2])<<40|uint64(b[1])<<48|uint64(b[0])<<56, nil
   289  	default:
   290  		return data.ErrInvalidType
   291  	}
   292  	if bugtrack.Enabled {
   293  		bugtrack.Track("com.(*Packet).readHeader(): p.ID=%d, p.len=%d, err=%s", p.ID, p.len, err)
   294  	}
   295  	return err
   296  }
   297  func (p *Packet) writeHeader(w io.Writer) error {
   298  	t := len(p.Tags)
   299  	if t > PacketMaxTags {
   300  		return xerr.Sub("tags list is too large", 0x2B)
   301  	}
   302  	if bugtrack.Enabled {
   303  		if p.Device.Empty() {
   304  			bugtrack.Track("com.(*Packet).writeHeader(): Calling writeHeader with empty Device, p.ID=%d!", p.ID)
   305  		}
   306  	}
   307  	if err := p.Device.Write(w); err != nil {
   308  		return err
   309  	}
   310  	var (
   311  		b [22]byte
   312  		c int
   313  	)
   314  	_ = b[21]
   315  	b[0], b[1], b[2] = p.ID, byte(p.Job>>8), byte(p.Job)
   316  	b[3], b[4], b[5], b[6] = byte(p.Flags>>56), byte(p.Flags>>48), byte(p.Flags>>40), byte(p.Flags>>32)
   317  	b[7], b[8], b[9], b[10] = byte(p.Flags>>24), byte(p.Flags>>16), byte(p.Flags>>8), byte(p.Flags)
   318  	b[11], b[12] = byte(t>>8), byte(t)
   319  	switch l := uint64(p.Chunk.Size()); {
   320  	case l == 0:
   321  		b[13] = 0
   322  	case l < data.LimitSmall:
   323  		b[13], b[14], c = 1, byte(l), 1
   324  	case l < data.LimitMedium:
   325  		b[13], b[14], b[15], c = 3, byte(l>>8), byte(l), 2
   326  	case l < data.LimitLarge:
   327  		b[13], c = 5, 4
   328  		b[14], b[15], b[16], b[17] = byte(l>>24), byte(l>>16), byte(l>>8), byte(l)
   329  	default:
   330  		b[13], c = 7, 8
   331  		b[14], b[15], b[16], b[17] = byte(l>>56), byte(l>>48), byte(l>>40), byte(l>>32)
   332  		b[18], b[19], b[20], b[21] = byte(l>>24), byte(l>>16), byte(l>>8), byte(l)
   333  	}
   334  	// NOTE(dij): This write is split into two writes as some stateful writes (XOR)
   335  	//             require writes and reads to re-constructed in the same way.
   336  	n, err := w.Write(b[0:14])
   337  	if err != nil {
   338  		return err
   339  	}
   340  	if n != 14 {
   341  		return io.ErrShortWrite
   342  	}
   343  	if n, err = w.Write(b[14 : 14+c]); err != nil {
   344  		return err
   345  	}
   346  	if n != c {
   347  		return io.ErrShortWrite
   348  	}
   349  	if bugtrack.Enabled {
   350  		bugtrack.Track("com.(*Packet).writeHeader(): p.ID=%d, p.len=%d, n=%d", p.ID, p.Chunk.Size(), c+14+device.IDSize)
   351  	}
   352  	return nil
   353  }
   354  
   355  // MarshalStream writes the data of this Packet to the supplied Writer.
   356  func (p *Packet) MarshalStream(w data.Writer) error {
   357  	if bugtrack.Enabled {
   358  		if p.Device.Empty() {
   359  			bugtrack.Track("com.(*Packet).writeHeader(): Calling writeHeader with empty Device, p.ID=%d!", p.ID)
   360  		}
   361  	}
   362  	if err := w.WriteUint8(p.ID); err != nil {
   363  		return err
   364  	}
   365  	if err := w.WriteUint16(p.Job); err != nil {
   366  		return err
   367  	}
   368  	if err := w.WriteUint16(uint16(len(p.Tags))); err != nil {
   369  		return err
   370  	}
   371  	if err := p.Flags.MarshalStream(w); err != nil {
   372  		return err
   373  	}
   374  	if err := p.Device.MarshalStream(w); err != nil {
   375  		return err
   376  	}
   377  	for i := 0; i < len(p.Tags) && i < PacketMaxTags; i++ {
   378  		if err := w.WriteUint32(p.Tags[i]); err != nil {
   379  			return err
   380  		}
   381  	}
   382  	return p.Chunk.MarshalStream(w)
   383  }
   384  
   385  // UnmarshalStream reads the data of this Packet from the supplied Reader.
   386  func (p *Packet) UnmarshalStream(r data.Reader) error {
   387  	if err := r.ReadUint8(&p.ID); err != nil {
   388  		return err
   389  	}
   390  	if err := r.ReadUint16(&p.Job); err != nil {
   391  		return err
   392  	}
   393  	t, err := r.Uint16()
   394  	if err != nil {
   395  		return err
   396  	}
   397  	if err := p.Flags.UnmarshalStream(r); err != nil {
   398  		return err
   399  	}
   400  	if err := p.Device.UnmarshalStream(r); err != nil {
   401  		return err
   402  	}
   403  	if t > 0 {
   404  		p.Tags = make([]uint32, t)
   405  		for i := uint16(0); i < t && i < PacketMaxTags; i++ {
   406  			if err := r.ReadUint32(&p.Tags[i]); err != nil {
   407  				return err
   408  			}
   409  			if p.Tags[i] == 0 {
   410  				return ErrMalformedTag
   411  			}
   412  		}
   413  	}
   414  	return p.Chunk.UnmarshalStream(r)
   415  }