github.com/psiphon-labs/psiphon-tunnel-core@v2.0.28+incompatible/psiphon/common/quic/gquic-go/packet_packer_legacy.go (about)

     1  package gquic
     2  
     3  import (
     4  	"bytes"
     5  	"errors"
     6  	"fmt"
     7  	"net"
     8  
     9  	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/quic/gquic-go/internal/ackhandler"
    10  	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/quic/gquic-go/internal/handshake"
    11  	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/quic/gquic-go/internal/protocol"
    12  	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/quic/gquic-go/internal/wire"
    13  )
    14  
// sentAndReceivedPacketManager bundles a SentPacketHandler and a
// ReceivedPacketHandler into one value so it can be handed to the legacy
// packer as its ackFrameSource. The packer calls GetAckFrame and
// GetStopWaitingFrame on it; presumably these are provided by the
// ReceivedPacketHandler and SentPacketHandler respectively (TODO confirm
// against the ackhandler interfaces).
//
// It is only needed until STOP_WAITING is removed.
type sentAndReceivedPacketManager struct {
	ackhandler.SentPacketHandler
	ackhandler.ReceivedPacketHandler
}

// Compile-time check that sentAndReceivedPacketManager implements ackFrameSource.
var _ ackFrameSource = &sentAndReceivedPacketManager{}
    22  
// packetPackerLegacy builds and seals outgoing packets for gQUIC versions,
// adding STOP_WAITING frames for versions that use them.
type packetPackerLegacy struct {
	// destConnID is written as the header's DestConnectionID, except for
	// forward-secure packets when omitConnectionID is set (see getHeader).
	destConnID protocol.ConnectionID
	// srcConnID is written as the SrcConnectionID of long headers.
	srcConnID protocol.ConnectionID

	perspective protocol.Perspective
	version     protocol.VersionNumber
	// cryptoSetup hands out the sealers used to encrypt packets.
	cryptoSetup sealingManager

	// divNonce is put into server headers sent at EncryptionSecure.
	divNonce []byte

	// packetNumberGenerator is peeked in getHeader and popped in
	// writeAndSealPacket once a packet has actually been sealed.
	packetNumberGenerator *packetNumberGenerator
	getPacketNumberLen    func(protocol.PacketNumber) protocol.PacketNumberLen
	// cryptoStream supplies handshake data packed by maybePackCryptoPacket.
	cryptoStream cryptoStream
	// framer supplies control and STREAM frames for regular packets.
	framer frameSource
	// acks supplies ACK and STOP_WAITING frames.
	acks ackFrameSource

	// omitConnectionID is set from the transport parameters
	// (HandleTransportParameters).
	omitConnectionID bool
	// maxPacketSize is derived from the remote address in the constructor.
	maxPacketSize protocol.ByteCount
	hasSentPacket bool // has the packetPacker already sent a packet
	// numNonRetransmittableAcks counts consecutive packets from PackPacket
	// without retransmittable frames; at MaxNonRetransmittableAcks a PING is added.
	numNonRetransmittableAcks int
}

// Compile-time check that packetPackerLegacy implements packer.
var _ packer = &packetPackerLegacy{}
    46  
    47  func newPacketPackerLegacy(
    48  	destConnID protocol.ConnectionID,
    49  	srcConnID protocol.ConnectionID,
    50  	getPacketNumberLen func(protocol.PacketNumber) protocol.PacketNumberLen,
    51  	remoteAddr net.Addr, // only used for determining the max packet size
    52  	divNonce []byte,
    53  	cryptoStream cryptoStream,
    54  	cryptoSetup sealingManager,
    55  	framer frameSource,
    56  	acks ackFrameSource,
    57  	perspective protocol.Perspective,
    58  	version protocol.VersionNumber,
    59  ) *packetPackerLegacy {
    60  	return &packetPackerLegacy{
    61  		cryptoStream:          cryptoStream,
    62  		cryptoSetup:           cryptoSetup,
    63  		divNonce:              divNonce,
    64  		destConnID:            destConnID,
    65  		srcConnID:             srcConnID,
    66  		perspective:           perspective,
    67  		version:               version,
    68  		framer:                framer,
    69  		acks:                  acks,
    70  		getPacketNumberLen:    getPacketNumberLen,
    71  		packetNumberGenerator: newPacketNumberGenerator(1, protocol.SkipPacketAveragePeriodLength),
    72  		maxPacketSize:         getMaxPacketSize(remoteAddr),
    73  	}
    74  }
    75  
    76  // PackConnectionClose packs a packet that ONLY contains a ConnectionCloseFrame
    77  func (p *packetPackerLegacy) PackConnectionClose(ccf *wire.ConnectionCloseFrame) (*packedPacket, error) {
    78  	frames := []wire.Frame{ccf}
    79  	encLevel, sealer := p.cryptoSetup.GetSealer()
    80  	header := p.getHeader(encLevel)
    81  	raw, err := p.writeAndSealPacket(header, frames, sealer)
    82  	return &packedPacket{
    83  		header:          header,
    84  		raw:             raw,
    85  		frames:          frames,
    86  		encryptionLevel: encLevel,
    87  	}, err
    88  }
    89  
    90  func (p *packetPackerLegacy) MaybePackAckPacket() (*packedPacket, error) {
    91  	ack := p.acks.GetAckFrame()
    92  	if ack == nil {
    93  		return nil, nil
    94  	}
    95  	encLevel, sealer := p.cryptoSetup.GetSealer()
    96  	header := p.getHeader(encLevel)
    97  	frames := []wire.Frame{ack}
    98  	// add a STOP_WAITING frame, if necessary
    99  	if p.version.UsesStopWaitingFrames() {
   100  		if swf := p.acks.GetStopWaitingFrame(false); swf != nil {
   101  			swf.PacketNumber = header.PacketNumber
   102  			swf.PacketNumberLen = header.PacketNumberLen
   103  			frames = append(frames, swf)
   104  		}
   105  	}
   106  	raw, err := p.writeAndSealPacket(header, frames, sealer)
   107  	return &packedPacket{
   108  		header:          header,
   109  		raw:             raw,
   110  		frames:          frames,
   111  		encryptionLevel: encLevel,
   112  	}, err
   113  }
   114  
   115  // PackRetransmission packs a retransmission
   116  // For packets sent after completion of the handshake, it might happen that 2 packets have to be sent.
   117  // This can happen e.g. when a longer packet number is used in the header.
   118  func (p *packetPackerLegacy) PackRetransmission(packet *ackhandler.Packet) ([]*packedPacket, error) {
   119  	if packet.EncryptionLevel != protocol.EncryptionForwardSecure {
   120  		p, err := p.packHandshakeRetransmission(packet)
   121  		return []*packedPacket{p}, err
   122  	}
   123  
   124  	var controlFrames []wire.Frame
   125  	var streamFrames []*wire.StreamFrame
   126  	for _, f := range packet.Frames {
   127  		if sf, ok := f.(*wire.StreamFrame); ok {
   128  			sf.DataLenPresent = true
   129  			streamFrames = append(streamFrames, sf)
   130  		} else {
   131  			controlFrames = append(controlFrames, f)
   132  		}
   133  	}
   134  
   135  	var packets []*packedPacket
   136  	encLevel, sealer := p.cryptoSetup.GetSealer()
   137  	var swf *wire.StopWaitingFrame
   138  	// add a STOP_WAITING for *every* retransmission
   139  	if p.version.UsesStopWaitingFrames() {
   140  		swf = p.acks.GetStopWaitingFrame(true)
   141  	}
   142  	for len(controlFrames) > 0 || len(streamFrames) > 0 {
   143  		var frames []wire.Frame
   144  		var length protocol.ByteCount
   145  
   146  		header := p.getHeader(encLevel)
   147  		headerLength, err := header.GetLength(p.version)
   148  		if err != nil {
   149  			return nil, err
   150  		}
   151  		maxSize := p.maxPacketSize - protocol.ByteCount(sealer.Overhead()) - headerLength
   152  
   153  		if p.version.UsesStopWaitingFrames() {
   154  			// create a new STOP_WAIITNG Frame, since we might need to send more than one packet as a retransmission
   155  			stopWaitingFrame := &wire.StopWaitingFrame{
   156  				LeastUnacked:    swf.LeastUnacked,
   157  				PacketNumber:    header.PacketNumber,
   158  				PacketNumberLen: header.PacketNumberLen,
   159  			}
   160  			length += stopWaitingFrame.Length(p.version)
   161  			frames = append(frames, stopWaitingFrame)
   162  		}
   163  
   164  		for len(controlFrames) > 0 {
   165  			frame := controlFrames[0]
   166  			frameLen := frame.Length(p.version)
   167  			if length+frameLen > maxSize {
   168  				break
   169  			}
   170  			length += frameLen
   171  			frames = append(frames, frame)
   172  			controlFrames = controlFrames[1:]
   173  		}
   174  
   175  		// temporarily increase the maxFrameSize by the (minimum) length of the DataLen field
   176  		// this leads to a properly sized packet in all cases, since we do all the packet length calculations with StreamFrames that have the DataLen set
   177  		// however, for the last STREAM frame in the packet, we can omit the DataLen, thus yielding a packet of exactly the correct size
   178  		maxSize += 2
   179  
   180  		for len(streamFrames) > 0 && length+protocol.MinStreamFrameSize < maxSize {
   181  			frame := streamFrames[0]
   182  			frameToAdd := frame
   183  
   184  			sf, err := frame.MaybeSplitOffFrame(maxSize-length, p.version)
   185  			if err != nil {
   186  				return nil, err
   187  			}
   188  			if sf != nil {
   189  				frameToAdd = sf
   190  			} else {
   191  				streamFrames = streamFrames[1:]
   192  			}
   193  			length += frameToAdd.Length(p.version)
   194  			frames = append(frames, frameToAdd)
   195  		}
   196  		if sf, ok := frames[len(frames)-1].(*wire.StreamFrame); ok {
   197  			sf.DataLenPresent = false
   198  		}
   199  		raw, err := p.writeAndSealPacket(header, frames, sealer)
   200  		if err != nil {
   201  			return nil, err
   202  		}
   203  		packets = append(packets, &packedPacket{
   204  			header:          header,
   205  			raw:             raw,
   206  			frames:          frames,
   207  			encryptionLevel: encLevel,
   208  		})
   209  	}
   210  	return packets, nil
   211  }
   212  
   213  // packHandshakeRetransmission retransmits a handshake packet, that was sent with less than forward-secure encryption
   214  func (p *packetPackerLegacy) packHandshakeRetransmission(packet *ackhandler.Packet) (*packedPacket, error) {
   215  	sealer, err := p.cryptoSetup.GetSealerWithEncryptionLevel(packet.EncryptionLevel)
   216  	if err != nil {
   217  		return nil, err
   218  	}
   219  	// make sure that the retransmission for an Initial packet is sent as an Initial packet
   220  	if packet.PacketType == protocol.PacketTypeInitial {
   221  		p.hasSentPacket = false
   222  	}
   223  	header := p.getHeader(packet.EncryptionLevel)
   224  	header.Type = packet.PacketType
   225  	var frames []wire.Frame
   226  	if p.version.UsesStopWaitingFrames() { // pack a STOP_WAITING first
   227  		swf := p.acks.GetStopWaitingFrame(true)
   228  		swf.PacketNumber = header.PacketNumber
   229  		swf.PacketNumberLen = header.PacketNumberLen
   230  		frames = append([]wire.Frame{swf}, packet.Frames...)
   231  	} else {
   232  		frames = packet.Frames
   233  	}
   234  	raw, err := p.writeAndSealPacket(header, frames, sealer)
   235  	return &packedPacket{
   236  		header:          header,
   237  		raw:             raw,
   238  		frames:          frames,
   239  		encryptionLevel: packet.EncryptionLevel,
   240  	}, err
   241  }
   242  
   243  // PackPacket packs a new packet
   244  // the other controlFrames are sent in the next packet, but might be queued and sent in the next packet if the packet would overflow MaxPacketSize otherwise
   245  func (p *packetPackerLegacy) PackPacket() (*packedPacket, error) {
   246  	packet, err := p.maybePackCryptoPacket()
   247  	if err != nil {
   248  		return nil, err
   249  	}
   250  	if packet != nil {
   251  		return packet, nil
   252  	}
   253  	// if this is the first packet to be send, make sure it contains stream data
   254  	if !p.hasSentPacket && packet == nil {
   255  		return nil, nil
   256  	}
   257  
   258  	encLevel, sealer := p.cryptoSetup.GetSealer()
   259  
   260  	header := p.getHeader(encLevel)
   261  	headerLength, err := header.GetLength(p.version)
   262  	if err != nil {
   263  		return nil, err
   264  	}
   265  
   266  	maxSize := p.maxPacketSize - protocol.ByteCount(sealer.Overhead()) - headerLength
   267  	frames, err := p.composeNextPacket(header, maxSize, p.canSendData(encLevel))
   268  	if err != nil {
   269  		return nil, err
   270  	}
   271  
   272  	// Check if we have enough frames to send
   273  	if len(frames) == 0 {
   274  		return nil, nil
   275  	}
   276  	// check if this packet only contains an ACK (and maybe a STOP_WAITING)
   277  	if !ackhandler.HasRetransmittableFrames(frames) {
   278  		if p.numNonRetransmittableAcks >= protocol.MaxNonRetransmittableAcks {
   279  			frames = append(frames, &wire.PingFrame{})
   280  			p.numNonRetransmittableAcks = 0
   281  		} else {
   282  			p.numNonRetransmittableAcks++
   283  		}
   284  	} else {
   285  		p.numNonRetransmittableAcks = 0
   286  	}
   287  
   288  	raw, err := p.writeAndSealPacket(header, frames, sealer)
   289  	if err != nil {
   290  		return nil, err
   291  	}
   292  	return &packedPacket{
   293  		header:          header,
   294  		raw:             raw,
   295  		frames:          frames,
   296  		encryptionLevel: encLevel,
   297  	}, nil
   298  }
   299  
   300  func (p *packetPackerLegacy) maybePackCryptoPacket() (*packedPacket, error) {
   301  	if !p.cryptoStream.hasData() {
   302  		return nil, nil
   303  	}
   304  	encLevel, sealer := p.cryptoSetup.GetSealerForCryptoStream()
   305  	header := p.getHeader(encLevel)
   306  	headerLength, err := header.GetLength(p.version)
   307  	if err != nil {
   308  		return nil, err
   309  	}
   310  	maxLen := p.maxPacketSize - protocol.ByteCount(sealer.Overhead()) - protocol.NonForwardSecurePacketSizeReduction - headerLength
   311  	sf, _ := p.cryptoStream.popStreamFrame(maxLen)
   312  	sf.DataLenPresent = false
   313  	frames := []wire.Frame{sf}
   314  	raw, err := p.writeAndSealPacket(header, frames, sealer)
   315  	if err != nil {
   316  		return nil, err
   317  	}
   318  	return &packedPacket{
   319  		header:          header,
   320  		raw:             raw,
   321  		frames:          frames,
   322  		encryptionLevel: encLevel,
   323  	}, nil
   324  }
   325  
   326  func (p *packetPackerLegacy) composeNextPacket(
   327  	header *wire.Header, // only needed to fill in the STOP_WAITING frame
   328  	maxFrameSize protocol.ByteCount,
   329  	canSendStreamFrames bool,
   330  ) ([]wire.Frame, error) {
   331  	var length protocol.ByteCount
   332  	var frames []wire.Frame
   333  
   334  	// STOP_WAITING and ACK will always fit
   335  	// ACKs need to go first, so that the sentPacketHandler will recognize them
   336  	if ack := p.acks.GetAckFrame(); ack != nil {
   337  		frames = append(frames, ack)
   338  		length += ack.Length(p.version)
   339  		// add a STOP_WAITING, for gQUIC
   340  		if p.version.UsesStopWaitingFrames() {
   341  			if swf := p.acks.GetStopWaitingFrame(false); swf != nil {
   342  				swf.PacketNumber = header.PacketNumber
   343  				swf.PacketNumberLen = header.PacketNumberLen
   344  				frames = append(frames, swf)
   345  				length += swf.Length(p.version)
   346  			}
   347  		}
   348  	}
   349  
   350  	var lengthAdded protocol.ByteCount
   351  	frames, lengthAdded = p.framer.AppendControlFrames(frames, maxFrameSize-length)
   352  	length += lengthAdded
   353  
   354  	if !canSendStreamFrames {
   355  		return frames, nil
   356  	}
   357  
   358  	// temporarily increase the maxFrameSize by the (minimum) length of the DataLen field
   359  	// this leads to a properly sized packet in all cases, since we do all the packet length calculations with StreamFrames that have the DataLen set
   360  	// however, for the last STREAM frame in the packet, we can omit the DataLen, thus yielding a packet of exactly the correct size
   361  	maxFrameSize += 2
   362  
   363  	frames = p.framer.AppendStreamFrames(frames, maxFrameSize-length)
   364  	if len(frames) > 0 {
   365  		lastFrame := frames[len(frames)-1]
   366  		if sf, ok := lastFrame.(*wire.StreamFrame); ok {
   367  			sf.DataLenPresent = false
   368  		}
   369  	}
   370  	return frames, nil
   371  }
   372  
   373  func (p *packetPackerLegacy) getHeader(encLevel protocol.EncryptionLevel) *wire.Header {
   374  	pnum := p.packetNumberGenerator.Peek()
   375  	packetNumberLen := p.getPacketNumberLen(pnum)
   376  
   377  	header := &wire.Header{
   378  		PacketNumber:    pnum,
   379  		PacketNumberLen: packetNumberLen,
   380  		Version:         p.version,
   381  	}
   382  
   383  	if p.version.UsesIETFHeaderFormat() && encLevel != protocol.EncryptionForwardSecure {
   384  		header.IsLongHeader = true
   385  		header.SrcConnectionID = p.srcConnID
   386  		header.PacketNumberLen = protocol.PacketNumberLen4
   387  		if !p.hasSentPacket && p.perspective == protocol.PerspectiveClient {
   388  			header.Type = protocol.PacketTypeInitial
   389  		} else {
   390  			header.Type = protocol.PacketTypeHandshake
   391  		}
   392  	}
   393  
   394  	if !p.omitConnectionID || encLevel != protocol.EncryptionForwardSecure {
   395  		header.DestConnectionID = p.destConnID
   396  	}
   397  	if p.perspective == protocol.PerspectiveServer && encLevel == protocol.EncryptionSecure {
   398  		header.Type = protocol.PacketType0RTT
   399  		header.DiversificationNonce = p.divNonce
   400  	}
   401  	if p.perspective == protocol.PerspectiveClient && encLevel != protocol.EncryptionForwardSecure {
   402  		header.VersionFlag = true
   403  	}
   404  	return header
   405  }
   406  
   407  func (p *packetPackerLegacy) writeAndSealPacket(
   408  	header *wire.Header,
   409  	frames []wire.Frame,
   410  	sealer handshake.Sealer,
   411  ) ([]byte, error) {
   412  	raw := *getPacketBuffer()
   413  	buffer := bytes.NewBuffer(raw[:0])
   414  
   415  	if err := header.Write(buffer, p.perspective, p.version); err != nil {
   416  		return nil, err
   417  	}
   418  	payloadStartIndex := buffer.Len()
   419  
   420  	// [Psiphon]
   421  	// In our tests, gQUICv44 works only when we restore this block of code that was refactored out in:
   422  	// https://github.com/lucas-clemente/quic-go/commit/5df98dc3891df7349ee6ff41714fbe6bb4d72440.
   423  
   424  	// the Initial packet needs to be padded, so the last STREAM frame must have the data length present
   425  	if header.Type == protocol.PacketTypeInitial {
   426  		lastFrame := frames[len(frames)-1]
   427  		if sf, ok := lastFrame.(*wire.StreamFrame); ok {
   428  			sf.DataLenPresent = true
   429  		}
   430  	}
   431  	// [Psiphon]
   432  
   433  	for _, frame := range frames {
   434  		if err := frame.Write(buffer, p.version); err != nil {
   435  			return nil, err
   436  		}
   437  	}
   438  
   439  	if size := protocol.ByteCount(buffer.Len() + sealer.Overhead()); size > p.maxPacketSize {
   440  		return nil, fmt.Errorf("PacketPacker BUG: packet too large (%d bytes, allowed %d bytes)", size, p.maxPacketSize)
   441  	}
   442  
   443  	raw = raw[0:buffer.Len()]
   444  	_ = sealer.Seal(raw[payloadStartIndex:payloadStartIndex], raw[payloadStartIndex:], header.PacketNumber, raw[:payloadStartIndex])
   445  	raw = raw[0 : buffer.Len()+sealer.Overhead()]
   446  
   447  	num := p.packetNumberGenerator.Pop()
   448  	if num != header.PacketNumber {
   449  		return nil, errors.New("packetPacker BUG: Peeked and Popped packet numbers do not match")
   450  	}
   451  	p.hasSentPacket = true
   452  	return raw, nil
   453  }
   454  
   455  func (p *packetPackerLegacy) canSendData(encLevel protocol.EncryptionLevel) bool {
   456  	if p.perspective == protocol.PerspectiveClient {
   457  		return encLevel >= protocol.EncryptionSecure
   458  	}
   459  	return encLevel == protocol.EncryptionForwardSecure
   460  }
   461  
   462  func (p *packetPackerLegacy) ChangeDestConnectionID(connID protocol.ConnectionID) error {
   463  	// [Psiphon]
   464  	// - Don't panic, since this may be invoked due to invalid input.
   465  	//panic("changing connection IDs not supported by gQUIC")
   466  	return errors.New("changing connection IDs not supported by gQUIC")
   467  	// [Psiphon]
   468  }
   469  
   470  func (p *packetPackerLegacy) HandleTransportParameters(params *handshake.TransportParameters) {
   471  	p.omitConnectionID = params.OmitConnectionID
   472  }