github.com/iDigitalFlame/xmt@v0.5.4/c2/vars.go (about) 1 // Copyright (C) 2020 - 2023 iDigitalFlame 2 // 3 // This program is free software: you can redistribute it and/or modify 4 // it under the terms of the GNU General Public License as published by 5 // the Free Software Foundation, either version 3 of the License, or 6 // any later version. 7 // 8 // This program is distributed in the hope that it will be useful, 9 // but WITHOUT ANY WARRANTY; without even the implied warranty of 10 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 11 // GNU General Public License for more details. 12 // 13 // You should have received a copy of the GNU General Public License 14 // along with this program. If not, see <https://www.gnu.org/licenses/>. 15 // 16 17 package c2 18 19 import ( 20 "context" 21 "io" 22 "net" 23 "sync" 24 "time" 25 26 "github.com/iDigitalFlame/xmt/c2/cfg" 27 "github.com/iDigitalFlame/xmt/c2/cout" 28 "github.com/iDigitalFlame/xmt/c2/task" 29 "github.com/iDigitalFlame/xmt/com" 30 "github.com/iDigitalFlame/xmt/com/limits" 31 "github.com/iDigitalFlame/xmt/com/pipe" 32 "github.com/iDigitalFlame/xmt/data" 33 "github.com/iDigitalFlame/xmt/device" 34 "github.com/iDigitalFlame/xmt/util" 35 "github.com/iDigitalFlame/xmt/util/bugtrack" 36 "github.com/iDigitalFlame/xmt/util/xerr" 37 ) 38 39 // RvResult is the generic value for indicating a result value. Packets 40 // that have this as their ID value will be forwarded to the authoritative 41 // Mux and will be discarded if it does not match an active Job ID. 42 const RvResult uint8 = 0x14 43 44 const ( 45 fragMax = 0xFFFF 46 readTimeout = time.Millisecond * 350 47 ) 48 49 // ID entries that start with 'Sv*' will be handed directly by the underlying 50 // Session instead of being forwarded to the authoritative Mux. 51 // 52 // These Packet ID values are used for network congestion and flow control and 53 // should not be used in standard Packet entries. 54 const ( 55 SvResync uint8 = 0x1 56 SvHello uint8 = 0x2 57 SvRegister uint8 = 0x3 // Considered a MvDrop. 58 SvComplete uint8 = 0x4 59 SvShutdown uint8 = 0x5 60 SvDrop uint8 = 0x6 61 ) 62 63 // ErrTooManyPackets is an error returned by many of the Packet writing 64 // functions when attempts to combine Packets would create a Packet grouping 65 // size larger than the maximum size (65535/0xFFFF). 66 var ErrTooManyPackets = xerr.Sub("frag/multi count is larger than 0xFFFF", 0x56) 67 68 var buffers = sync.Pool{ 69 New: func() interface{} { 70 return new(data.Chunk) 71 }, 72 } 73 74 func returnBuffer(c *data.Chunk) { 75 c.Clear() 76 buffers.Put(c) 77 } 78 func isPacketNoP(n *com.Packet) bool { 79 return n.ID < 2 && n.Empty() && (n.Flags == 0 || n.Flags == com.FlagProxy) 80 } 81 func mergeTags(one, two []uint32) []uint32 { 82 if len(one) == 0 && len(two) == 0 { 83 return nil 84 } 85 if len(one) == 0 && len(two) > 0 { 86 return two 87 } 88 if len(one) > 0 && len(two) == 0 { 89 return one 90 } 91 i := len(one) 92 if i < len(two) { 93 i = len(two) 94 } 95 t := make(map[uint32]struct{}, i) 96 for _, v := range one { 97 t[v] = wake 98 } 99 for _, v := range two { 100 t[v] = wake 101 } 102 r := make([]uint32, 0, len(t)) 103 for v := range t { 104 r = append(r, v) 105 } 106 return r 107 } 108 func receiveSingle(s *Session, n *com.Packet) { 109 if s == nil { 110 return 111 } 112 if bugtrack.Enabled { 113 bugtrack.Track( 114 "c2.receiveSingle(): n.ID=%X, n=%s, n.Flags=%s, n.Device=%s", n.ID, n, n.Flags, n.Device, 115 ) 116 } 117 switch n.ID { 118 case SvComplete: 119 if !n.Empty() && n.Flags&com.FlagCrypt != 0 { 120 s.keySessionSync(n) 121 n.Clear() 122 return 123 } 124 case SvResync: 125 if !s.hasJob(n.Job) { 126 if cout.Enabled { 127 s.log.Error("[%s/Cr0] Client sent a SvResync Packet not associated with an active Job!", s.ID, n.Job) 128 } 129 return 130 } 131 if cout.Enabled { 132 s.log.Debug("[%s/Cr0] Client sent a SvResync Packet associated with Job %d!", s.ID, n.Job) 133 } 134 t, err := n.Uint8() 135 if err != nil { 136 if cout.Enabled { 137 s.log.Error("[%s/Cr0] Error reading SvResync Packet: %s!", s.ID, err.Error()) 138 } 139 return 140 } 141 if _, err := s.readDeviceInfo(t, n); err != nil { 142 if cout.Enabled { 143 s.log.Error("[%s/Cr0] Error reading SvResync Packet result: %s!", s.ID, err.Error()) 144 } 145 return 146 } 147 if cout.Enabled { 148 s.log.Debug("[%s/Cr0] Client indicated that it changed profile/time, updating local Session information.", s.ID) 149 } 150 return 151 case SvShutdown: 152 if !s.IsClient() { 153 if cout.Enabled { 154 s.log.Info("[%s/Cr0] Client indicated shutdown, acknowledging and closing Session.", s.ID) 155 } 156 s.write(true, &com.Packet{ID: SvShutdown, Job: 1, Device: s.ID}) 157 s.s.Remove(s.ID, false) 158 s.state.Set(stateShutdownWait) 159 } else { 160 if s.state.Closing() { 161 return 162 } 163 if cout.Enabled { 164 s.log.Info("[%s/Cr0] Server indicated shutdown, closing Session.", s.ID) 165 } 166 } 167 s.close(false) 168 return 169 case SvRegister: 170 if !s.IsClient() { 171 return 172 } 173 if cout.Enabled { 174 s.log.Info("[%s/Cr0] Server indicated that we must re-register, resending SvRegister info!", s.ID) 175 } 176 if s.proxy != nil && s.proxy.IsActive() { 177 s.proxy.subsRegister() 178 } 179 v := &com.Packet{ID: SvHello, Job: uint16(util.FastRand()), Device: s.ID} 180 s.writeDeviceInfo(infoHello, v) 181 s.keySessionGenerate(v) 182 if s.queue(v); len(s.send) <= 1 { 183 s.Wake() 184 } 185 return 186 } 187 if n.ID < task.MvRefresh { 188 return 189 } 190 if s.parent == nil { 191 s.m.queue(event{p: n, s: s, hf: defaultClientMux}) 192 return 193 } 194 s.m.queue(event{p: n, s: s, af: s.handle}) 195 } 196 func verifyPacket(n *com.Packet, i device.ID) bool { 197 if n.Job == 0 && n.Flags&com.FlagProxy == 0 && n.ID > 1 { 198 n.Job = uint16(util.FastRand()) 199 } 200 if n.Device.Empty() { 201 n.Device = i 202 return true 203 } 204 return n.Device == i 205 } 206 func receive(s *Session, l *Listener, n *com.Packet) error { 207 if n == nil || n.Device.Empty() || isPacketNoP(n) || (l == nil && s == nil) { 208 return nil 209 } 210 if bugtrack.Enabled { 211 bugtrack.Track( 212 "c2.receive(): s == nil=%t, l == nil=%t, n.ID=%X, n=%s, n.Flags=%s, n.Device=%s", 213 s == nil, l == nil, n.ID, n, n.Flags, n.Device, 214 ) 215 } 216 if s != nil && n.Flags&com.FlagMultiDevice == 0 && s.ID != n.Device { 217 if s.proxy != nil && s.proxy.IsActive() && s.proxy.accept(n) { 218 return nil 219 } 220 if n.Clear(); xerr.ExtendedInfo { 221 return xerr.Sub(`received Packet for "`+n.Device.String()+`" that does not match our own device ID "`+s.ID.String()+`"`, 0x57) 222 } 223 return xerr.Sub("received Packet that does not match our own device ID", 0x57) 224 } 225 if n.Flags&com.FlagOneshot != 0 { 226 l.oneshot(n) 227 return nil 228 } 229 if s == nil || (n.ID == SvComplete && n.Flags&com.FlagCrypt == 0) { 230 n.Clear() 231 return nil 232 } 233 switch { 234 case n.Flags&com.FlagMulti != 0: 235 x := n.Flags.Len() 236 if x == 0 { 237 return ErrInvalidPacketCount 238 } 239 for ; x > 0; x-- { 240 var v com.Packet 241 if err := v.UnmarshalStream(n); err != nil { 242 n.Clear() 243 v.Clear() 244 return err 245 } 246 if cout.Enabled { 247 s.log.Trace(`[%s] Unpacked Packet "%s"..`, s.ID, v) 248 } 249 if err := receive(s, l, &v); err != nil { 250 n.Clear() 251 v.Clear() 252 return err 253 } 254 } 255 n.Clear() 256 return nil 257 case n.Flags&com.FlagFrag != 0 && n.Flags&com.FlagMulti == 0: 258 if n.ID == SvDrop || n.ID == SvRegister { 259 if cout.Enabled { 260 s.log.Warning("[%s] Indicated to clear Frag Group 0x%X!", s.ID, n.Flags.Group()) 261 } 262 if s.state.SetLast(n.Flags.Group()); n.ID != SvRegister { 263 n.Clear() 264 return nil 265 } 266 break 267 } 268 if n.Flags.Len() == 0 { 269 n.Clear() 270 return ErrInvalidPacketCount 271 } 272 if n.Flags.Len() == 1 { 273 if cout.Enabled { 274 s.log.Trace("[%s] Received a single frag (len=1) for Group 0x%X, clearing Flags!", s.ID, n.Flags.Group()) 275 } 276 n.Flags.Clear() 277 return receive(s, l, n) 278 } 279 if cout.Enabled { 280 s.log.Trace("[%s] Received frag for Group 0x%X (%d of %d).", s.ID, n.Flags.Group(), n.Flags.Position()+1, n.Flags.Len()) 281 } 282 var ( 283 g = n.Flags.Group() 284 c, ok = s.frags[g] 285 ) 286 if !ok && n.Flags.Position() > 0 { 287 if s.write(true, &com.Packet{ID: SvDrop, Flags: n.Flags, Device: s.ID}); cout.Enabled { 288 s.log.Warning("[%s] Received an invalid Frag Group 0x%X, indicating to drop it!", s.ID, n.Flags.Group()) 289 } 290 return nil 291 } 292 if !ok { 293 c = new(cluster) 294 s.frags[g] = c 295 } 296 if err := c.add(n); err != nil { 297 return err 298 } 299 if v := c.done(); v != nil { 300 if delete(s.frags, g); cout.Enabled { 301 s.log.Trace("[%s] Completed Frag Group 0x%X, %d total.", s.ID, n.Flags.Group(), n.Flags.Len()) 302 } 303 return receive(s, l, v) 304 } 305 s.frag(n.Job, n.Flags.Group(), n.Flags.Len(), n.Flags.Position()) 306 return nil 307 } 308 receiveSingle(s, n) 309 return nil 310 } 311 func writeUnpack(dst, src *com.Packet, flags, tags bool) error { 312 if src == nil || dst == nil { 313 return nil 314 } 315 if src.Flags&com.FlagMulti != 0 || src.Flags&com.FlagMultiDevice != 0 { 316 x := src.Flags.Len() 317 if x == 0 { 318 return ErrInvalidPacketCount 319 } 320 if x+dst.Flags.Len() > fragMax { 321 return ErrTooManyPackets 322 } 323 src.WriteTo(dst) 324 dst.Flags.SetLen(dst.Flags.Len() + x) 325 src.Clear() 326 return nil 327 } 328 if dst.Flags.Len()+1 > fragMax { 329 return ErrTooManyPackets 330 } 331 src.MarshalStream(dst) 332 if dst.Flags.SetLen(dst.Flags.Len() + 1); flags { 333 if src.Flags&com.FlagChannel != 0 { 334 dst.Flags |= com.FlagChannel 335 } 336 if src.Flags&com.FlagMultiDevice != 0 { 337 dst.Flags |= com.FlagMultiDevice 338 } 339 } 340 if dst.Flags |= com.FlagMulti; tags && len(src.Tags) > 0 { 341 dst.Tags = append(dst.Tags, src.Tags...) 342 } 343 src.Clear() 344 return nil 345 } 346 func readPacketFrom(c io.Reader, w cfg.Wrapper, n *com.Packet) error { 347 if w == nil { 348 if bugtrack.Enabled { 349 bugtrack.Track("c2.readPacketFrom(): Passing read to direct Unmarshal.") 350 } 351 return n.Unmarshal(c) 352 } 353 if bugtrack.Enabled { 354 bugtrack.Track("c2.readPacketFrom(): Starting read with Wrapper.") 355 } 356 i, err := w.Unwrap(c) 357 if err != nil { 358 return xerr.Wrap("unable to unwrap Reader", err) 359 } 360 if err = n.Unmarshal(i); err != nil { 361 return err 362 } 363 return nil 364 } 365 func writePacketTo(c *data.Chunk, w cfg.Wrapper, n *com.Packet) error { 366 if w == nil { 367 if bugtrack.Enabled { 368 bugtrack.Track("c2.writePacketTo(): Passing write to direct Marshal.") 369 } 370 return n.Marshal(c) 371 } 372 o, err := w.Wrap(c) 373 if err != nil { 374 return xerr.Wrap("unable to wrap Writer", err) 375 } 376 if bugtrack.Enabled { 377 bugtrack.Track("c2.writePacketTo(): n=%s, n.Len()=%d, n.Size()=%d", n, n.Size(), n.Size()) 378 } 379 if err = n.Marshal(o); err != nil { 380 return err 381 } 382 if err = o.Close(); err != nil { 383 return xerr.Wrap("unable to close Wrapper", err) 384 } 385 return nil 386 } 387 func spinTimeout(x context.Context, n string, t time.Duration) net.Conn { 388 var ( 389 y, f = context.WithTimeout(x, t) 390 c net.Conn 391 ) 392 for c == nil { 393 select { 394 case <-y.Done(): 395 f() 396 return nil 397 case <-x.Done(): 398 f() 399 return nil 400 default: 401 c, _ = pipe.DialContext(y, n) 402 } 403 } 404 f() 405 return c 406 } 407 func readPacket(c net.Conn, w cfg.Wrapper, t cfg.Transform) (*com.Packet, error) { 408 var n com.Packet 409 if w == nil && t == nil { 410 if err := n.Unmarshal(&readerTimeout{c: c, t: readTimeout}); err != nil { 411 return nil, xerr.Wrap("unable to read from stream", err) 412 } 413 if bugtrack.Enabled { 414 bugtrack.Track("c2.readPacket(): Direct Unmarshal result n=%s", n) 415 } 416 return &n, nil 417 } 418 var ( 419 b = buffers.Get().(*data.Chunk) 420 d, err = b.ReadDeadline(c, readTimeout) 421 ) 422 if bugtrack.Enabled { 423 bugtrack.Track("c2.readPacket(): ReadDeadline result d=%d, err=%s", d, err) 424 } 425 if d == 0 { 426 if returnBuffer(b); err != nil { 427 return nil, xerr.Wrap("unable to read from stream", err) 428 } 429 return nil, xerr.Wrap("unable to read from stream", io.ErrUnexpectedEOF) 430 } 431 if t != nil { 432 o := buffers.Get().(*data.Chunk) 433 err = t.Read(b.Payload(), o) 434 if returnBuffer(b); err != nil { 435 returnBuffer(o) 436 return nil, xerr.Wrap("unable to read from cache", err) 437 } 438 b = o 439 } 440 err = readPacketFrom(b, w, &n) 441 if returnBuffer(b); err != nil { 442 n.Clear() 443 return nil, err 444 } 445 if bugtrack.Enabled { 446 bugtrack.Track("c2.readPacket(): Unmarshal result n=%s", n) 447 } 448 return &n, nil 449 } 450 func writePacket(c net.Conn, w cfg.Wrapper, t cfg.Transform, n *com.Packet) error { 451 if w == nil && t == nil { 452 err := n.Marshal(c) 453 n.Clear() 454 return err 455 } 456 var ( 457 b = buffers.Get().(*data.Chunk) 458 err = writePacketTo(b, w, n) 459 ) 460 if n.Clear(); err != nil { 461 returnBuffer(b) 462 return xerr.Wrap("unable to write to cache", err) 463 } 464 if t != nil { 465 err = t.Write(b.Payload(), c) 466 } else { 467 _, err = b.WriteTo(c) 468 } 469 if returnBuffer(b); err != nil { 470 return xerr.Wrap("unable to write to stream", err) 471 } 472 return nil 473 } 474 func nextPacket(a notifier, q <-chan *com.Packet, n *com.Packet, i device.ID, t []uint32) (*com.Packet, *com.Packet) { 475 if n == nil && len(q) == 0 { 476 return nil, nil 477 } 478 // NOTE(dij): Fast path (if we have a strict limit OR we don't have 479 // anything in queue, but we got a packet). So just send that 480 // shit/wrap if needed. 481 if limits.Packets <= 1 || (n != nil && len(q) == 0) { 482 if n == nil { 483 if n = <-q; n == nil { 484 return nil, nil 485 } 486 } 487 if a.accept(n.Job); verifyPacket(n, i) { 488 n.Tags = append(n.Tags, t...) 489 return n, nil 490 } 491 o := &com.Packet{Device: i, Flags: com.FlagMulti | com.FlagMultiDevice} 492 writeUnpack(o, n, true, true) 493 o.Tags = append(o.Tags, t...) 494 return o, nil 495 } 496 var ( 497 o = &com.Packet{Device: i, Flags: com.FlagMulti} 498 k *com.Packet 499 ) 500 for x, s, m := 0, 0, false; x < limits.Packets && len(q) > 0; x++ { 501 if n == nil { 502 n = <-q 503 } 504 // TODO(dij): ?need to add a check here to see if len(c) == 0 505 // if so, drop a SvNop and return only the first 506 if isPacketNoP(n) && ((s > 0 && !m) || (n.Device.Empty() || n.Device == i)) { 507 n.Clear() 508 n = nil 509 continue 510 } 511 // Rare case a single packet (which was already chunked, 512 // is bigger than the frag size, shouldn't happen but *shrug*) 513 // s would be zero on the first round, so just send that one and "fuck it" 514 if s > 0 { 515 if s += n.Size(); s > limits.Frag { 516 k = n 517 break 518 } 519 } else { 520 s += n.Size() 521 } 522 // Set multi device flag if there's a packet in queue that doesn't match us. 523 if a.accept(n.Job); !verifyPacket(n, i) && !m { 524 o.Flags |= com.FlagMultiDevice 525 m = true 526 } 527 writeUnpack(o, n, true, true) 528 n = nil 529 } 530 // If we get a single packet, unpack it and send it instead. 531 // I don't think there's a super good way to do this, as we clear most of the 532 // data during write. IE: we have >1 NOPs and just a single data Packet. 533 if o.Flags.Len() == 1 && o.Flags&com.FlagMultiDevice == 0 && o.ID == 0 { 534 var v com.Packet 535 v.UnmarshalStream(o) 536 o.Clear() 537 // Remove reference 538 o = nil 539 o = &v 540 } 541 return o, k 542 }