github.com/iDigitalFlame/xmt@v0.5.4/c2/vars.go (about)

     1  // Copyright (C) 2020 - 2023 iDigitalFlame
     2  //
     3  // This program is free software: you can redistribute it and/or modify
     4  // it under the terms of the GNU General Public License as published by
     5  // the Free Software Foundation, either version 3 of the License, or
     6  // any later version.
     7  //
     8  // This program is distributed in the hope that it will be useful,
     9  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    10  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    11  // GNU General Public License for more details.
    12  //
    13  // You should have received a copy of the GNU General Public License
    14  // along with this program.  If not, see <https://www.gnu.org/licenses/>.
    15  //
    16  
    17  package c2
    18  
    19  import (
    20  	"context"
    21  	"io"
    22  	"net"
    23  	"sync"
    24  	"time"
    25  
    26  	"github.com/iDigitalFlame/xmt/c2/cfg"
    27  	"github.com/iDigitalFlame/xmt/c2/cout"
    28  	"github.com/iDigitalFlame/xmt/c2/task"
    29  	"github.com/iDigitalFlame/xmt/com"
    30  	"github.com/iDigitalFlame/xmt/com/limits"
    31  	"github.com/iDigitalFlame/xmt/com/pipe"
    32  	"github.com/iDigitalFlame/xmt/data"
    33  	"github.com/iDigitalFlame/xmt/device"
    34  	"github.com/iDigitalFlame/xmt/util"
    35  	"github.com/iDigitalFlame/xmt/util/bugtrack"
    36  	"github.com/iDigitalFlame/xmt/util/xerr"
    37  )
    38  
    39  // RvResult is the generic value for indicating a result value. Packets
    40  // that have this as their ID value will be forwarded to the authoritative
    41  // Mux and will be discarded if it does not match an active Job ID.
    42  const RvResult uint8 = 0x14
    43  
    44  const (
    45  	fragMax     = 0xFFFF
    46  	readTimeout = time.Millisecond * 350
    47  )
    48  
    49  // ID entries that start with 'Sv*' will be handed directly by the underlying
    50  // Session instead of being forwarded to the authoritative Mux.
    51  //
    52  // These Packet ID values are used for network congestion and flow control and
    53  // should not be used in standard Packet entries.
    54  const (
    55  	SvResync   uint8 = 0x1
    56  	SvHello    uint8 = 0x2
    57  	SvRegister uint8 = 0x3 // Considered a MvDrop.
    58  	SvComplete uint8 = 0x4
    59  	SvShutdown uint8 = 0x5
    60  	SvDrop     uint8 = 0x6
    61  )
    62  
    63  // ErrTooManyPackets is an error returned by many of the Packet writing
    64  // functions when attempts to combine Packets would create a Packet grouping
    65  // size larger than the maximum size (65535/0xFFFF).
    66  var ErrTooManyPackets = xerr.Sub("frag/multi count is larger than 0xFFFF", 0x56)
    67  
    68  var buffers = sync.Pool{
    69  	New: func() interface{} {
    70  		return new(data.Chunk)
    71  	},
    72  }
    73  
    74  func returnBuffer(c *data.Chunk) {
    75  	c.Clear()
    76  	buffers.Put(c)
    77  }
    78  func isPacketNoP(n *com.Packet) bool {
    79  	return n.ID < 2 && n.Empty() && (n.Flags == 0 || n.Flags == com.FlagProxy)
    80  }
    81  func mergeTags(one, two []uint32) []uint32 {
    82  	if len(one) == 0 && len(two) == 0 {
    83  		return nil
    84  	}
    85  	if len(one) == 0 && len(two) > 0 {
    86  		return two
    87  	}
    88  	if len(one) > 0 && len(two) == 0 {
    89  		return one
    90  	}
    91  	i := len(one)
    92  	if i < len(two) {
    93  		i = len(two)
    94  	}
    95  	t := make(map[uint32]struct{}, i)
    96  	for _, v := range one {
    97  		t[v] = wake
    98  	}
    99  	for _, v := range two {
   100  		t[v] = wake
   101  	}
   102  	r := make([]uint32, 0, len(t))
   103  	for v := range t {
   104  		r = append(r, v)
   105  	}
   106  	return r
   107  }
   108  func receiveSingle(s *Session, n *com.Packet) {
   109  	if s == nil {
   110  		return
   111  	}
   112  	if bugtrack.Enabled {
   113  		bugtrack.Track(
   114  			"c2.receiveSingle(): n.ID=%X, n=%s, n.Flags=%s, n.Device=%s", n.ID, n, n.Flags, n.Device,
   115  		)
   116  	}
   117  	switch n.ID {
   118  	case SvComplete:
   119  		if !n.Empty() && n.Flags&com.FlagCrypt != 0 {
   120  			s.keySessionSync(n)
   121  			n.Clear()
   122  			return
   123  		}
   124  	case SvResync:
   125  		if !s.hasJob(n.Job) {
   126  			if cout.Enabled {
   127  				s.log.Error("[%s/Cr0] Client sent a SvResync Packet not associated with an active Job!", s.ID, n.Job)
   128  			}
   129  			return
   130  		}
   131  		if cout.Enabled {
   132  			s.log.Debug("[%s/Cr0] Client sent a SvResync Packet associated with Job %d!", s.ID, n.Job)
   133  		}
   134  		t, err := n.Uint8()
   135  		if err != nil {
   136  			if cout.Enabled {
   137  				s.log.Error("[%s/Cr0] Error reading SvResync Packet: %s!", s.ID, err.Error())
   138  			}
   139  			return
   140  		}
   141  		if _, err := s.readDeviceInfo(t, n); err != nil {
   142  			if cout.Enabled {
   143  				s.log.Error("[%s/Cr0] Error reading SvResync Packet result: %s!", s.ID, err.Error())
   144  			}
   145  			return
   146  		}
   147  		if cout.Enabled {
   148  			s.log.Debug("[%s/Cr0] Client indicated that it changed profile/time, updating local Session information.", s.ID)
   149  		}
   150  		return
   151  	case SvShutdown:
   152  		if !s.IsClient() {
   153  			if cout.Enabled {
   154  				s.log.Info("[%s/Cr0] Client indicated shutdown, acknowledging and closing Session.", s.ID)
   155  			}
   156  			s.write(true, &com.Packet{ID: SvShutdown, Job: 1, Device: s.ID})
   157  			s.s.Remove(s.ID, false)
   158  			s.state.Set(stateShutdownWait)
   159  		} else {
   160  			if s.state.Closing() {
   161  				return
   162  			}
   163  			if cout.Enabled {
   164  				s.log.Info("[%s/Cr0] Server indicated shutdown, closing Session.", s.ID)
   165  			}
   166  		}
   167  		s.close(false)
   168  		return
   169  	case SvRegister:
   170  		if !s.IsClient() {
   171  			return
   172  		}
   173  		if cout.Enabled {
   174  			s.log.Info("[%s/Cr0] Server indicated that we must re-register, resending SvRegister info!", s.ID)
   175  		}
   176  		if s.proxy != nil && s.proxy.IsActive() {
   177  			s.proxy.subsRegister()
   178  		}
   179  		v := &com.Packet{ID: SvHello, Job: uint16(util.FastRand()), Device: s.ID}
   180  		s.writeDeviceInfo(infoHello, v)
   181  		s.keySessionGenerate(v)
   182  		if s.queue(v); len(s.send) <= 1 {
   183  			s.Wake()
   184  		}
   185  		return
   186  	}
   187  	if n.ID < task.MvRefresh {
   188  		return
   189  	}
   190  	if s.parent == nil {
   191  		s.m.queue(event{p: n, s: s, hf: defaultClientMux})
   192  		return
   193  	}
   194  	s.m.queue(event{p: n, s: s, af: s.handle})
   195  }
   196  func verifyPacket(n *com.Packet, i device.ID) bool {
   197  	if n.Job == 0 && n.Flags&com.FlagProxy == 0 && n.ID > 1 {
   198  		n.Job = uint16(util.FastRand())
   199  	}
   200  	if n.Device.Empty() {
   201  		n.Device = i
   202  		return true
   203  	}
   204  	return n.Device == i
   205  }
   206  func receive(s *Session, l *Listener, n *com.Packet) error {
   207  	if n == nil || n.Device.Empty() || isPacketNoP(n) || (l == nil && s == nil) {
   208  		return nil
   209  	}
   210  	if bugtrack.Enabled {
   211  		bugtrack.Track(
   212  			"c2.receive(): s == nil=%t, l == nil=%t, n.ID=%X, n=%s, n.Flags=%s, n.Device=%s",
   213  			s == nil, l == nil, n.ID, n, n.Flags, n.Device,
   214  		)
   215  	}
   216  	if s != nil && n.Flags&com.FlagMultiDevice == 0 && s.ID != n.Device {
   217  		if s.proxy != nil && s.proxy.IsActive() && s.proxy.accept(n) {
   218  			return nil
   219  		}
   220  		if n.Clear(); xerr.ExtendedInfo {
   221  			return xerr.Sub(`received Packet for "`+n.Device.String()+`" that does not match our own device ID "`+s.ID.String()+`"`, 0x57)
   222  		}
   223  		return xerr.Sub("received Packet that does not match our own device ID", 0x57)
   224  	}
   225  	if n.Flags&com.FlagOneshot != 0 {
   226  		l.oneshot(n)
   227  		return nil
   228  	}
   229  	if s == nil || (n.ID == SvComplete && n.Flags&com.FlagCrypt == 0) {
   230  		n.Clear()
   231  		return nil
   232  	}
   233  	switch {
   234  	case n.Flags&com.FlagMulti != 0:
   235  		x := n.Flags.Len()
   236  		if x == 0 {
   237  			return ErrInvalidPacketCount
   238  		}
   239  		for ; x > 0; x-- {
   240  			var v com.Packet
   241  			if err := v.UnmarshalStream(n); err != nil {
   242  				n.Clear()
   243  				v.Clear()
   244  				return err
   245  			}
   246  			if cout.Enabled {
   247  				s.log.Trace(`[%s] Unpacked Packet "%s"..`, s.ID, v)
   248  			}
   249  			if err := receive(s, l, &v); err != nil {
   250  				n.Clear()
   251  				v.Clear()
   252  				return err
   253  			}
   254  		}
   255  		n.Clear()
   256  		return nil
   257  	case n.Flags&com.FlagFrag != 0 && n.Flags&com.FlagMulti == 0:
   258  		if n.ID == SvDrop || n.ID == SvRegister {
   259  			if cout.Enabled {
   260  				s.log.Warning("[%s] Indicated to clear Frag Group 0x%X!", s.ID, n.Flags.Group())
   261  			}
   262  			if s.state.SetLast(n.Flags.Group()); n.ID != SvRegister {
   263  				n.Clear()
   264  				return nil
   265  			}
   266  			break
   267  		}
   268  		if n.Flags.Len() == 0 {
   269  			n.Clear()
   270  			return ErrInvalidPacketCount
   271  		}
   272  		if n.Flags.Len() == 1 {
   273  			if cout.Enabled {
   274  				s.log.Trace("[%s] Received a single frag (len=1) for Group 0x%X, clearing Flags!", s.ID, n.Flags.Group())
   275  			}
   276  			n.Flags.Clear()
   277  			return receive(s, l, n)
   278  		}
   279  		if cout.Enabled {
   280  			s.log.Trace("[%s] Received frag for Group 0x%X (%d of %d).", s.ID, n.Flags.Group(), n.Flags.Position()+1, n.Flags.Len())
   281  		}
   282  		var (
   283  			g     = n.Flags.Group()
   284  			c, ok = s.frags[g]
   285  		)
   286  		if !ok && n.Flags.Position() > 0 {
   287  			if s.write(true, &com.Packet{ID: SvDrop, Flags: n.Flags, Device: s.ID}); cout.Enabled {
   288  				s.log.Warning("[%s] Received an invalid Frag Group 0x%X, indicating to drop it!", s.ID, n.Flags.Group())
   289  			}
   290  			return nil
   291  		}
   292  		if !ok {
   293  			c = new(cluster)
   294  			s.frags[g] = c
   295  		}
   296  		if err := c.add(n); err != nil {
   297  			return err
   298  		}
   299  		if v := c.done(); v != nil {
   300  			if delete(s.frags, g); cout.Enabled {
   301  				s.log.Trace("[%s] Completed Frag Group 0x%X, %d total.", s.ID, n.Flags.Group(), n.Flags.Len())
   302  			}
   303  			return receive(s, l, v)
   304  		}
   305  		s.frag(n.Job, n.Flags.Group(), n.Flags.Len(), n.Flags.Position())
   306  		return nil
   307  	}
   308  	receiveSingle(s, n)
   309  	return nil
   310  }
   311  func writeUnpack(dst, src *com.Packet, flags, tags bool) error {
   312  	if src == nil || dst == nil {
   313  		return nil
   314  	}
   315  	if src.Flags&com.FlagMulti != 0 || src.Flags&com.FlagMultiDevice != 0 {
   316  		x := src.Flags.Len()
   317  		if x == 0 {
   318  			return ErrInvalidPacketCount
   319  		}
   320  		if x+dst.Flags.Len() > fragMax {
   321  			return ErrTooManyPackets
   322  		}
   323  		src.WriteTo(dst)
   324  		dst.Flags.SetLen(dst.Flags.Len() + x)
   325  		src.Clear()
   326  		return nil
   327  	}
   328  	if dst.Flags.Len()+1 > fragMax {
   329  		return ErrTooManyPackets
   330  	}
   331  	src.MarshalStream(dst)
   332  	if dst.Flags.SetLen(dst.Flags.Len() + 1); flags {
   333  		if src.Flags&com.FlagChannel != 0 {
   334  			dst.Flags |= com.FlagChannel
   335  		}
   336  		if src.Flags&com.FlagMultiDevice != 0 {
   337  			dst.Flags |= com.FlagMultiDevice
   338  		}
   339  	}
   340  	if dst.Flags |= com.FlagMulti; tags && len(src.Tags) > 0 {
   341  		dst.Tags = append(dst.Tags, src.Tags...)
   342  	}
   343  	src.Clear()
   344  	return nil
   345  }
   346  func readPacketFrom(c io.Reader, w cfg.Wrapper, n *com.Packet) error {
   347  	if w == nil {
   348  		if bugtrack.Enabled {
   349  			bugtrack.Track("c2.readPacketFrom(): Passing read to direct Unmarshal.")
   350  		}
   351  		return n.Unmarshal(c)
   352  	}
   353  	if bugtrack.Enabled {
   354  		bugtrack.Track("c2.readPacketFrom(): Starting read with Wrapper.")
   355  	}
   356  	i, err := w.Unwrap(c)
   357  	if err != nil {
   358  		return xerr.Wrap("unable to unwrap Reader", err)
   359  	}
   360  	if err = n.Unmarshal(i); err != nil {
   361  		return err
   362  	}
   363  	return nil
   364  }
   365  func writePacketTo(c *data.Chunk, w cfg.Wrapper, n *com.Packet) error {
   366  	if w == nil {
   367  		if bugtrack.Enabled {
   368  			bugtrack.Track("c2.writePacketTo(): Passing write to direct Marshal.")
   369  		}
   370  		return n.Marshal(c)
   371  	}
   372  	o, err := w.Wrap(c)
   373  	if err != nil {
   374  		return xerr.Wrap("unable to wrap Writer", err)
   375  	}
   376  	if bugtrack.Enabled {
   377  		bugtrack.Track("c2.writePacketTo(): n=%s, n.Len()=%d, n.Size()=%d", n, n.Size(), n.Size())
   378  	}
   379  	if err = n.Marshal(o); err != nil {
   380  		return err
   381  	}
   382  	if err = o.Close(); err != nil {
   383  		return xerr.Wrap("unable to close Wrapper", err)
   384  	}
   385  	return nil
   386  }
   387  func spinTimeout(x context.Context, n string, t time.Duration) net.Conn {
   388  	var (
   389  		y, f = context.WithTimeout(x, t)
   390  		c    net.Conn
   391  	)
   392  	for c == nil {
   393  		select {
   394  		case <-y.Done():
   395  			f()
   396  			return nil
   397  		case <-x.Done():
   398  			f()
   399  			return nil
   400  		default:
   401  			c, _ = pipe.DialContext(y, n)
   402  		}
   403  	}
   404  	f()
   405  	return c
   406  }
// readPacket reads and returns a single Packet from the connection 'c'.
//
// When both the Wrapper 'w' and the Transform 't' are nil, the Packet is
// unmarshaled directly from 'c' through a readerTimeout that uses readTimeout.
//
// Otherwise the work happens in three steps, all using pooled buffers:
//  1. The incoming bytes are read into a pooled buffer, with readTimeout as
//     the deadline.
//  2. If 't' is non-nil, the bytes are passed through it into a second pooled
//     buffer.
//  3. The result is handed to readPacketFrom, which unwraps with 'w' when it
//     is non-nil and unmarshals.
//
// Every pooled buffer taken here is given back with returnBuffer before this
// function returns.
//
// An error is returned when no data could be read, or when the transform,
// unwrap or unmarshal step fails.
func readPacket(c net.Conn, w cfg.Wrapper, t cfg.Transform) (*com.Packet, error) {
	var n com.Packet
	if w == nil && t == nil {
		// Fast path: no intermediate buffering is needed.
		if err := n.Unmarshal(&readerTimeout{c: c, t: readTimeout}); err != nil {
			return nil, xerr.Wrap("unable to read from stream", err)
		}
		if bugtrack.Enabled {
			bugtrack.Track("c2.readPacket(): Direct Unmarshal result n=%s", n)
		}
		return &n, nil
	}
	var (
		b      = buffers.Get().(*data.Chunk)
		d, err = b.ReadDeadline(c, readTimeout)
	)
	if bugtrack.Enabled {
		bugtrack.Track("c2.readPacket(): ReadDeadline result d=%d, err=%s", d, err)
	}
	// Only a read that produced no data counts as a failure. When d > 0, a
	// non-nil 'err' is never checked; it is overwritten below.
	// NOTE(review): this presumably tolerates a deadline error after a partial
	// read. Confirm against data.Chunk.ReadDeadline.
	if d == 0 {
		if returnBuffer(b); err != nil {
			return nil, xerr.Wrap("unable to read from stream", err)
		}
		return nil, xerr.Wrap("unable to read from stream", io.ErrUnexpectedEOF)
	}
	if t != nil {
		// Pass the raw bytes through the Transform into a second pooled buffer.
		// The raw buffer is returned, and 'b' now refers to the Transform
		// output.
		o := buffers.Get().(*data.Chunk)
		err = t.Read(b.Payload(), o)
		if returnBuffer(b); err != nil {
			returnBuffer(o)
			return nil, xerr.Wrap("unable to read from cache", err)
		}
		b = o
	}
	err = readPacketFrom(b, w, &n)
	if returnBuffer(b); err != nil {
		n.Clear()
		return nil, err
	}
	if bugtrack.Enabled {
		bugtrack.Track("c2.readPacket(): Unmarshal result n=%s", n)
	}
	return &n, nil
}
   450  func writePacket(c net.Conn, w cfg.Wrapper, t cfg.Transform, n *com.Packet) error {
   451  	if w == nil && t == nil {
   452  		err := n.Marshal(c)
   453  		n.Clear()
   454  		return err
   455  	}
   456  	var (
   457  		b   = buffers.Get().(*data.Chunk)
   458  		err = writePacketTo(b, w, n)
   459  	)
   460  	if n.Clear(); err != nil {
   461  		returnBuffer(b)
   462  		return xerr.Wrap("unable to write to cache", err)
   463  	}
   464  	if t != nil {
   465  		err = t.Write(b.Payload(), c)
   466  	} else {
   467  		_, err = b.WriteTo(c)
   468  	}
   469  	if returnBuffer(b); err != nil {
   470  		return xerr.Wrap("unable to write to stream", err)
   471  	}
   472  	return nil
   473  }
// nextPacket builds the next Packet to send for the Device 'i'. It draws on
// the optional already-dequeued Packet 'n' and the Packets waiting in 'q'.
//
// It returns two values:
//   - the Packet to send, which is nil when 'n' is nil and 'q' is empty;
//   - a dequeued Packet that did not fit within limits.Frag, or nil if none.
//
// NOTE(review): presumably the caller keeps the second value and passes it back
// as 'n' on the next call. Confirm this with the callers.
//
// Every Packet that is sent has its Job passed to a.accept and is checked with
// verifyPacket, which may assign a Job ID and/or the Device ID.
//
// Fast path (limits.Packets <= 1, or 'n' set with an empty queue): a single
// Packet is returned with the tags 't' appended. If that Packet belongs to
// another Device, it is wrapped in a new Multi/MultiDevice Packet instead.
//
// Batch path: up to limits.Packets Packets are combined into a single Multi
// Packet.
//   - NoP Packets matching the condition in the loop are dropped.
//   - The loop stops early when adding a Packet would push the running size
//     past limits.Frag. The first Packet is always accepted, whatever its size.
//     The Packet that did not fit is returned as the second value.
//   - If the resulting non-MultiDevice batch holds a count of exactly one (and
//     its ID is 0), it is unpacked and that Packet is returned directly.
//
// NOTE(review): the tags 't' are only appended on the fast path; the batch
// path never uses 't'. Confirm whether that is intended.
// NOTE(review): if every dequeued Packet in the batch path is dropped as a
// NoP, the returned Multi Packet has a count of zero. 'receive' rejects that
// with ErrInvalidPacketCount (see the TODO in the loop).
// NOTE(review): errors from writeUnpack and UnmarshalStream are ignored here.
// NOTE(review): unlike the fast path, the batch path does not check for a nil
// Packet received from 'q'.
func nextPacket(a notifier, q <-chan *com.Packet, n *com.Packet, i device.ID, t []uint32) (*com.Packet, *com.Packet) {
	if n == nil && len(q) == 0 {
		return nil, nil
	}
	// NOTE(dij): Fast path (if we have a strict limit OR we don't have
	//            anything in queue, but we got a packet). So just send that
	//            shit/wrap if needed.
	if limits.Packets <= 1 || (n != nil && len(q) == 0) {
		if n == nil {
			if n = <-q; n == nil {
				return nil, nil
			}
		}
		if a.accept(n.Job); verifyPacket(n, i) {
			n.Tags = append(n.Tags, t...)
			return n, nil
		}
		// The Packet belongs to another Device: wrap it so the receiver treats
		// it as a MultiDevice grouping.
		o := &com.Packet{Device: i, Flags: com.FlagMulti | com.FlagMultiDevice}
		writeUnpack(o, n, true, true)
		o.Tags = append(o.Tags, t...)
		return o, nil
	}
	var (
		o = &com.Packet{Device: i, Flags: com.FlagMulti}
		k *com.Packet
	)
	// x counts loop iterations (including dropped NoPs), s tracks the running
	// size in bytes, and m records whether FlagMultiDevice was already set.
	for x, s, m := 0, 0, false; x < limits.Packets && len(q) > 0; x++ {
		if n == nil {
			n = <-q
		}
		// TODO(dij): ?need to add a check here to see if len(c) == 0
		//            if so, drop a SvNop and return only the first
		if isPacketNoP(n) && ((s > 0 && !m) || (n.Device.Empty() || n.Device == i)) {
			n.Clear()
			n = nil
			continue
		}
		// Rare case: a single Packet (which was already chunked) is bigger than
		// the frag size. This shouldn't happen, but 's' is zero on the first
		// round, so the first Packet is always sent regardless of its size.
		if s > 0 {
			if s += n.Size(); s > limits.Frag {
				k = n
				break
			}
		} else {
			s += n.Size()
		}
		// Set multi device flag if there's a packet in queue that doesn't match us.
		if a.accept(n.Job); !verifyPacket(n, i) && !m {
			o.Flags |= com.FlagMultiDevice
			m = true
		}
		writeUnpack(o, n, true, true)
		n = nil
	}
	// If we get a single packet, unpack it and send it instead.
	// I don't think there's a super good way to do this, as we clear most of the
	// data during write. IE: we have >1 NOPs and just a single data Packet.
	if o.Flags.Len() == 1 && o.Flags&com.FlagMultiDevice == 0 && o.ID == 0 {
		var v com.Packet
		v.UnmarshalStream(o)
		o.Clear()
		// Remove reference
		o = nil
		o = &v
	}
	return o, k
}