github.com/psiphon-labs/psiphon-tunnel-core@v2.0.28+incompatible/psiphon/common/quic/gquic-go/packet_packer.go (about)

     1  package gquic
     2  
     3  import (
     4  	"bytes"
     5  	"errors"
     6  	"fmt"
     7  	"net"
     8  	"time"
     9  
    10  	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/quic/gquic-go/internal/ackhandler"
    11  	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/quic/gquic-go/internal/handshake"
    12  	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/quic/gquic-go/internal/protocol"
    13  	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/quic/gquic-go/internal/utils"
    14  	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/quic/gquic-go/internal/wire"
    15  )
    16  
    17  type packer interface {
    18  	PackPacket() (*packedPacket, error)
    19  	MaybePackAckPacket() (*packedPacket, error)
    20  	PackRetransmission(packet *ackhandler.Packet) ([]*packedPacket, error)
    21  	PackConnectionClose(*wire.ConnectionCloseFrame) (*packedPacket, error)
    22  
    23  	HandleTransportParameters(*handshake.TransportParameters)
    24  
    25  	// [Psiphon]
    26  	// - Add error return value.
    27  	ChangeDestConnectionID(protocol.ConnectionID) error
    28  	// [Psiphon]
    29  }
    30  
    31  type packedPacket struct {
    32  	header          *wire.Header
    33  	raw             []byte
    34  	frames          []wire.Frame
    35  	encryptionLevel protocol.EncryptionLevel
    36  }
    37  
    38  func (p *packedPacket) ToAckHandlerPacket() *ackhandler.Packet {
    39  	return &ackhandler.Packet{
    40  		PacketNumber:    p.header.PacketNumber,
    41  		PacketType:      p.header.Type,
    42  		Frames:          p.frames,
    43  		Length:          protocol.ByteCount(len(p.raw)),
    44  		EncryptionLevel: p.encryptionLevel,
    45  		SendTime:        time.Now(),
    46  	}
    47  }
    48  
    49  func getMaxPacketSize(addr net.Addr) protocol.ByteCount {
    50  	maxSize := protocol.ByteCount(protocol.MinInitialPacketSize)
    51  	// If this is not a UDP address, we don't know anything about the MTU.
    52  	// Use the minimum size of an Initial packet as the max packet size.
    53  	if udpAddr, ok := addr.(*net.UDPAddr); ok {
    54  		// If ip is not an IPv4 address, To4 returns nil.
    55  		// Note that there might be some corner cases, where this is not correct.
    56  		// See https://stackoverflow.com/questions/22751035/golang-distinguish-ipv4-ipv6.
    57  		if udpAddr.IP.To4() == nil {
    58  			maxSize = protocol.MaxPacketSizeIPv6
    59  		} else {
    60  			maxSize = protocol.MaxPacketSizeIPv4
    61  		}
    62  	}
    63  	return maxSize
    64  }
    65  
    66  type sealingManager interface {
    67  	GetSealer() (protocol.EncryptionLevel, handshake.Sealer)
    68  	GetSealerForCryptoStream() (protocol.EncryptionLevel, handshake.Sealer)
    69  	GetSealerWithEncryptionLevel(protocol.EncryptionLevel) (handshake.Sealer, error)
    70  }
    71  
    72  type frameSource interface {
    73  	AppendStreamFrames([]wire.Frame, protocol.ByteCount) []wire.Frame
    74  	AppendControlFrames([]wire.Frame, protocol.ByteCount) ([]wire.Frame, protocol.ByteCount)
    75  }
    76  
    77  type ackFrameSource interface {
    78  	GetAckFrame() *wire.AckFrame
    79  	GetStopWaitingFrame(bool) *wire.StopWaitingFrame
    80  }
    81  
    82  type packetPacker struct {
    83  	destConnID protocol.ConnectionID
    84  	srcConnID  protocol.ConnectionID
    85  
    86  	perspective protocol.Perspective
    87  	version     protocol.VersionNumber
    88  	cryptoSetup sealingManager
    89  
    90  	token []byte
    91  
    92  	packetNumberGenerator *packetNumberGenerator
    93  	getPacketNumberLen    func(protocol.PacketNumber) protocol.PacketNumberLen
    94  	cryptoStream          cryptoStream
    95  	framer                frameSource
    96  	acks                  ackFrameSource
    97  
    98  	maxPacketSize             protocol.ByteCount
    99  	hasSentPacket             bool // has the packetPacker already sent a packet
   100  	numNonRetransmittableAcks int
   101  }
   102  
   103  var _ packer = &packetPacker{}
   104  
   105  func newPacketPacker(
   106  	destConnID protocol.ConnectionID,
   107  	srcConnID protocol.ConnectionID,
   108  	initialPacketNumber protocol.PacketNumber,
   109  	getPacketNumberLen func(protocol.PacketNumber) protocol.PacketNumberLen,
   110  	remoteAddr net.Addr, // only used for determining the max packet size
   111  	token []byte,
   112  	cryptoStream cryptoStream,
   113  	cryptoSetup sealingManager,
   114  	framer frameSource,
   115  	acks ackFrameSource,
   116  	perspective protocol.Perspective,
   117  	version protocol.VersionNumber,
   118  ) *packetPacker {
   119  	return &packetPacker{
   120  		cryptoStream:          cryptoStream,
   121  		cryptoSetup:           cryptoSetup,
   122  		token:                 token,
   123  		destConnID:            destConnID,
   124  		srcConnID:             srcConnID,
   125  		perspective:           perspective,
   126  		version:               version,
   127  		framer:                framer,
   128  		acks:                  acks,
   129  		getPacketNumberLen:    getPacketNumberLen,
   130  		packetNumberGenerator: newPacketNumberGenerator(initialPacketNumber, protocol.SkipPacketAveragePeriodLength),
   131  		maxPacketSize:         getMaxPacketSize(remoteAddr),
   132  	}
   133  }
   134  
   135  // PackConnectionClose packs a packet that ONLY contains a ConnectionCloseFrame
   136  func (p *packetPacker) PackConnectionClose(ccf *wire.ConnectionCloseFrame) (*packedPacket, error) {
   137  	frames := []wire.Frame{ccf}
   138  	encLevel, sealer := p.cryptoSetup.GetSealer()
   139  	header := p.getHeader(encLevel)
   140  	raw, err := p.writeAndSealPacket(header, frames, sealer)
   141  	return &packedPacket{
   142  		header:          header,
   143  		raw:             raw,
   144  		frames:          frames,
   145  		encryptionLevel: encLevel,
   146  	}, err
   147  }
   148  
   149  func (p *packetPacker) MaybePackAckPacket() (*packedPacket, error) {
   150  	ack := p.acks.GetAckFrame()
   151  	if ack == nil {
   152  		return nil, nil
   153  	}
   154  	encLevel, sealer := p.cryptoSetup.GetSealer()
   155  	header := p.getHeader(encLevel)
   156  	frames := []wire.Frame{ack}
   157  	raw, err := p.writeAndSealPacket(header, frames, sealer)
   158  	return &packedPacket{
   159  		header:          header,
   160  		raw:             raw,
   161  		frames:          frames,
   162  		encryptionLevel: encLevel,
   163  	}, err
   164  }
   165  
   166  // PackRetransmission packs a retransmission
   167  // For packets sent after completion of the handshake, it might happen that 2 packets have to be sent.
   168  // This can happen e.g. when a longer packet number is used in the header.
   169  func (p *packetPacker) PackRetransmission(packet *ackhandler.Packet) ([]*packedPacket, error) {
   170  	if packet.EncryptionLevel != protocol.EncryptionForwardSecure {
   171  		p, err := p.packHandshakeRetransmission(packet)
   172  		return []*packedPacket{p}, err
   173  	}
   174  
   175  	var controlFrames []wire.Frame
   176  	var streamFrames []*wire.StreamFrame
   177  	for _, f := range packet.Frames {
   178  		if sf, ok := f.(*wire.StreamFrame); ok {
   179  			sf.DataLenPresent = true
   180  			streamFrames = append(streamFrames, sf)
   181  		} else {
   182  			controlFrames = append(controlFrames, f)
   183  		}
   184  	}
   185  
   186  	var packets []*packedPacket
   187  	encLevel, sealer := p.cryptoSetup.GetSealer()
   188  	for len(controlFrames) > 0 || len(streamFrames) > 0 {
   189  		var frames []wire.Frame
   190  		var length protocol.ByteCount
   191  
   192  		header := p.getHeader(encLevel)
   193  		headerLength, err := header.GetLength(p.version)
   194  		if err != nil {
   195  			return nil, err
   196  		}
   197  		maxSize := p.maxPacketSize - protocol.ByteCount(sealer.Overhead()) - headerLength
   198  
   199  		for len(controlFrames) > 0 {
   200  			frame := controlFrames[0]
   201  			frameLen := frame.Length(p.version)
   202  			if length+frameLen > maxSize {
   203  				break
   204  			}
   205  			length += frameLen
   206  			frames = append(frames, frame)
   207  			controlFrames = controlFrames[1:]
   208  		}
   209  
   210  		for len(streamFrames) > 0 && length+protocol.MinStreamFrameSize < maxSize {
   211  			frame := streamFrames[0]
   212  			frame.DataLenPresent = false
   213  			frameToAdd := frame
   214  
   215  			sf, err := frame.MaybeSplitOffFrame(maxSize-length, p.version)
   216  			if err != nil {
   217  				return nil, err
   218  			}
   219  			if sf != nil {
   220  				frameToAdd = sf
   221  			} else {
   222  				streamFrames = streamFrames[1:]
   223  			}
   224  			frame.DataLenPresent = true
   225  			length += frameToAdd.Length(p.version)
   226  			frames = append(frames, frameToAdd)
   227  		}
   228  		if sf, ok := frames[len(frames)-1].(*wire.StreamFrame); ok {
   229  			sf.DataLenPresent = false
   230  		}
   231  		raw, err := p.writeAndSealPacket(header, frames, sealer)
   232  		if err != nil {
   233  			return nil, err
   234  		}
   235  		packets = append(packets, &packedPacket{
   236  			header:          header,
   237  			raw:             raw,
   238  			frames:          frames,
   239  			encryptionLevel: encLevel,
   240  		})
   241  	}
   242  	return packets, nil
   243  }
   244  
   245  // packHandshakeRetransmission retransmits a handshake packet, that was sent with less than forward-secure encryption
   246  func (p *packetPacker) packHandshakeRetransmission(packet *ackhandler.Packet) (*packedPacket, error) {
   247  	sealer, err := p.cryptoSetup.GetSealerWithEncryptionLevel(packet.EncryptionLevel)
   248  	if err != nil {
   249  		return nil, err
   250  	}
   251  	// make sure that the retransmission for an Initial packet is sent as an Initial packet
   252  	if packet.PacketType == protocol.PacketTypeInitial {
   253  		p.hasSentPacket = false
   254  	}
   255  	header := p.getHeader(packet.EncryptionLevel)
   256  	header.Type = packet.PacketType
   257  	raw, err := p.writeAndSealPacket(header, packet.Frames, sealer)
   258  	return &packedPacket{
   259  		header:          header,
   260  		raw:             raw,
   261  		frames:          packet.Frames,
   262  		encryptionLevel: packet.EncryptionLevel,
   263  	}, err
   264  }
   265  
   266  // PackPacket packs a new packet
   267  // the other controlFrames are sent in the next packet, but might be queued and sent in the next packet if the packet would overflow MaxPacketSize otherwise
   268  func (p *packetPacker) PackPacket() (*packedPacket, error) {
   269  	packet, err := p.maybePackCryptoPacket()
   270  	if err != nil {
   271  		return nil, err
   272  	}
   273  	if packet != nil {
   274  		return packet, nil
   275  	}
   276  	// if this is the first packet to be send, make sure it contains stream data
   277  	if !p.hasSentPacket && packet == nil {
   278  		return nil, nil
   279  	}
   280  
   281  	encLevel, sealer := p.cryptoSetup.GetSealer()
   282  
   283  	header := p.getHeader(encLevel)
   284  	headerLength, err := header.GetLength(p.version)
   285  	if err != nil {
   286  		return nil, err
   287  	}
   288  
   289  	maxSize := p.maxPacketSize - protocol.ByteCount(sealer.Overhead()) - headerLength
   290  	frames, err := p.composeNextPacket(maxSize, p.canSendData(encLevel))
   291  	if err != nil {
   292  		return nil, err
   293  	}
   294  
   295  	// Check if we have enough frames to send
   296  	if len(frames) == 0 {
   297  		return nil, nil
   298  	}
   299  	// check if this packet only contains an ACK
   300  	if !ackhandler.HasRetransmittableFrames(frames) {
   301  		if p.numNonRetransmittableAcks >= protocol.MaxNonRetransmittableAcks {
   302  			frames = append(frames, &wire.PingFrame{})
   303  			p.numNonRetransmittableAcks = 0
   304  		} else {
   305  			p.numNonRetransmittableAcks++
   306  		}
   307  	} else {
   308  		p.numNonRetransmittableAcks = 0
   309  	}
   310  
   311  	raw, err := p.writeAndSealPacket(header, frames, sealer)
   312  	if err != nil {
   313  		return nil, err
   314  	}
   315  	return &packedPacket{
   316  		header:          header,
   317  		raw:             raw,
   318  		frames:          frames,
   319  		encryptionLevel: encLevel,
   320  	}, nil
   321  }
   322  
   323  func (p *packetPacker) maybePackCryptoPacket() (*packedPacket, error) {
   324  	if !p.cryptoStream.hasData() {
   325  		return nil, nil
   326  	}
   327  	encLevel, sealer := p.cryptoSetup.GetSealerForCryptoStream()
   328  	header := p.getHeader(encLevel)
   329  	headerLength, err := header.GetLength(p.version)
   330  	if err != nil {
   331  		return nil, err
   332  	}
   333  	maxLen := p.maxPacketSize - protocol.ByteCount(sealer.Overhead()) - protocol.NonForwardSecurePacketSizeReduction - headerLength
   334  	sf, _ := p.cryptoStream.popStreamFrame(maxLen)
   335  	sf.DataLenPresent = false
   336  	frames := []wire.Frame{sf}
   337  	raw, err := p.writeAndSealPacket(header, frames, sealer)
   338  	if err != nil {
   339  		return nil, err
   340  	}
   341  	return &packedPacket{
   342  		header:          header,
   343  		raw:             raw,
   344  		frames:          frames,
   345  		encryptionLevel: encLevel,
   346  	}, nil
   347  }
   348  
   349  func (p *packetPacker) composeNextPacket(
   350  	maxFrameSize protocol.ByteCount,
   351  	canSendStreamFrames bool,
   352  ) ([]wire.Frame, error) {
   353  	var length protocol.ByteCount
   354  	var frames []wire.Frame
   355  
   356  	// ACKs need to go first, so that the sentPacketHandler will recognize them
   357  	if ack := p.acks.GetAckFrame(); ack != nil {
   358  		frames = append(frames, ack)
   359  		length += ack.Length(p.version)
   360  	}
   361  
   362  	var lengthAdded protocol.ByteCount
   363  	frames, lengthAdded = p.framer.AppendControlFrames(frames, maxFrameSize-length)
   364  	length += lengthAdded
   365  
   366  	if !canSendStreamFrames {
   367  		return frames, nil
   368  	}
   369  
   370  	// temporarily increase the maxFrameSize by the (minimum) length of the DataLen field
   371  	// this leads to a properly sized packet in all cases, since we do all the packet length calculations with STREAM frames that have the DataLen set
   372  	// however, for the last STREAM frame in the packet, we can omit the DataLen, thus yielding a packet of exactly the correct size
   373  	// the length is encoded to either 1 or 2 bytes
   374  	maxFrameSize++
   375  
   376  	frames = p.framer.AppendStreamFrames(frames, maxFrameSize-length)
   377  	if len(frames) > 0 {
   378  		lastFrame := frames[len(frames)-1]
   379  		if sf, ok := lastFrame.(*wire.StreamFrame); ok {
   380  			sf.DataLenPresent = false
   381  		}
   382  	}
   383  	return frames, nil
   384  }
   385  
   386  func (p *packetPacker) getHeader(encLevel protocol.EncryptionLevel) *wire.Header {
   387  	pnum := p.packetNumberGenerator.Peek()
   388  	packetNumberLen := p.getPacketNumberLen(pnum)
   389  
   390  	header := &wire.Header{
   391  		PacketNumber:     pnum,
   392  		PacketNumberLen:  packetNumberLen,
   393  		Version:          p.version,
   394  		DestConnectionID: p.destConnID,
   395  	}
   396  
   397  	if encLevel != protocol.EncryptionForwardSecure {
   398  		header.IsLongHeader = true
   399  		header.SrcConnectionID = p.srcConnID
   400  		// Set the payload len to maximum size.
   401  		// Since it is encoded as a varint, this guarantees us that the header will end up at most as big as GetLength() returns.
   402  		header.PayloadLen = p.maxPacketSize
   403  		if !p.hasSentPacket && p.perspective == protocol.PerspectiveClient {
   404  			header.Type = protocol.PacketTypeInitial
   405  			header.Token = p.token
   406  		} else {
   407  			header.Type = protocol.PacketTypeHandshake
   408  		}
   409  	}
   410  
   411  	return header
   412  }
   413  
   414  func (p *packetPacker) writeAndSealPacket(
   415  	header *wire.Header,
   416  	frames []wire.Frame,
   417  	sealer handshake.Sealer,
   418  ) ([]byte, error) {
   419  	raw := *getPacketBuffer()
   420  	buffer := bytes.NewBuffer(raw[:0])
   421  
   422  	// the payload length is only needed for Long Headers
   423  	if header.IsLongHeader {
   424  		if header.Type == protocol.PacketTypeInitial {
   425  			headerLen, _ := header.GetLength(p.version)
   426  			header.PayloadLen = protocol.ByteCount(protocol.MinInitialPacketSize) - headerLen
   427  		} else {
   428  			payloadLen := protocol.ByteCount(sealer.Overhead())
   429  			for _, frame := range frames {
   430  				payloadLen += frame.Length(p.version)
   431  			}
   432  			header.PayloadLen = payloadLen
   433  		}
   434  	}
   435  
   436  	if err := header.Write(buffer, p.perspective, p.version); err != nil {
   437  		return nil, err
   438  	}
   439  	payloadStartIndex := buffer.Len()
   440  
   441  	// the Initial packet needs to be padded, so the last STREAM frame must have the data length present
   442  	if header.Type == protocol.PacketTypeInitial {
   443  		lastFrame := frames[len(frames)-1]
   444  		if sf, ok := lastFrame.(*wire.StreamFrame); ok {
   445  			sf.DataLenPresent = true
   446  		}
   447  	}
   448  	for _, frame := range frames {
   449  		if err := frame.Write(buffer, p.version); err != nil {
   450  			return nil, err
   451  		}
   452  	}
   453  	// if this is an Initial packet, we need to pad it to fulfill the minimum size requirement
   454  	if header.Type == protocol.PacketTypeInitial {
   455  		paddingLen := protocol.MinInitialPacketSize - sealer.Overhead() - buffer.Len()
   456  		if paddingLen > 0 {
   457  			buffer.Write(bytes.Repeat([]byte{0}, paddingLen))
   458  		}
   459  	}
   460  
   461  	if size := protocol.ByteCount(buffer.Len() + sealer.Overhead()); size > p.maxPacketSize {
   462  		return nil, fmt.Errorf("PacketPacker BUG: packet too large (%d bytes, allowed %d bytes)", size, p.maxPacketSize)
   463  	}
   464  
   465  	raw = raw[0:buffer.Len()]
   466  	_ = sealer.Seal(raw[payloadStartIndex:payloadStartIndex], raw[payloadStartIndex:], header.PacketNumber, raw[:payloadStartIndex])
   467  	raw = raw[0 : buffer.Len()+sealer.Overhead()]
   468  
   469  	num := p.packetNumberGenerator.Pop()
   470  	if num != header.PacketNumber {
   471  		return nil, errors.New("packetPacker BUG: Peeked and Popped packet numbers do not match")
   472  	}
   473  	p.hasSentPacket = true
   474  	return raw, nil
   475  }
   476  
   477  func (p *packetPacker) canSendData(encLevel protocol.EncryptionLevel) bool {
   478  	if p.perspective == protocol.PerspectiveClient {
   479  		return encLevel >= protocol.EncryptionSecure
   480  	}
   481  	return encLevel == protocol.EncryptionForwardSecure
   482  }
   483  
   484  func (p *packetPacker) ChangeDestConnectionID(connID protocol.ConnectionID) error {
   485  	p.destConnID = connID
   486  
   487  	// [Psiphon]
   488  	return nil
   489  	// [Psiphon]
   490  }
   491  
   492  func (p *packetPacker) HandleTransportParameters(params *handshake.TransportParameters) {
   493  	if params.MaxPacketSize != 0 {
   494  		p.maxPacketSize = utils.MinByteCount(p.maxPacketSize, params.MaxPacketSize)
   495  	}
   496  }