golang.org/x/net@v0.25.1-0.20240516223405-c87a5b62e243/quic/conn_id.go (about)

     1  // Copyright 2023 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  //go:build go1.21
     6  
     7  package quic
     8  
     9  import (
    10  	"bytes"
    11  	"crypto/rand"
    12  )
    13  
    14  // connIDState is a conn's connection IDs.
    15  type connIDState struct {
    16  	// The destination connection IDs of packets we receive are local.
    17  	// The destination connection IDs of packets we send are remote.
    18  	//
    19  	// Local IDs are usually issued by us, and remote IDs by the peer.
    20  	// The exception is the transient destination connection ID sent in
    21  	// a client's Initial packets, which is chosen by the client.
    22  	//
    23  	// These are []connID rather than []*connID to minimize allocations.
    24  	local  []connID
    25  	remote []remoteConnID
    26  
    27  	nextLocalSeq          int64
    28  	retireRemotePriorTo   int64 // largest Retire Prior To value sent by the peer
    29  	peerActiveConnIDLimit int64 // peer's active_connection_id_limit transport parameter
    30  
    31  	originalDstConnID []byte // expected original_destination_connection_id param
    32  	retrySrcConnID    []byte // expected retry_source_connection_id param
    33  
    34  	needSend bool
    35  }
    36  
    37  // A connID is a connection ID and associated metadata.
    38  type connID struct {
    39  	// cid is the connection ID itself.
    40  	cid []byte
    41  
    42  	// seq is the connection ID's sequence number:
    43  	// https://www.rfc-editor.org/rfc/rfc9000.html#section-5.1.1-1
    44  	//
    45  	// For the transient destination ID in a client's Initial packet, this is -1.
    46  	seq int64
    47  
    48  	// retired is set when the connection ID is retired.
    49  	retired bool
    50  
    51  	// send is set when the connection ID's state needs to be sent to the peer.
    52  	//
    53  	// For local IDs, this indicates a new ID that should be sent
    54  	// in a NEW_CONNECTION_ID frame.
    55  	//
    56  	// For remote IDs, this indicates a retired ID that should be sent
    57  	// in a RETIRE_CONNECTION_ID frame.
    58  	send sentVal
    59  }
    60  
    61  // A remoteConnID is a connection ID and stateless reset token.
    62  type remoteConnID struct {
    63  	connID
    64  	resetToken statelessResetToken
    65  }
    66  
    67  func (s *connIDState) initClient(c *Conn) error {
    68  	// Client chooses its initial connection ID, and sends it
    69  	// in the Source Connection ID field of the first Initial packet.
    70  	locid, err := c.newConnID(0)
    71  	if err != nil {
    72  		return err
    73  	}
    74  	s.local = append(s.local, connID{
    75  		seq: 0,
    76  		cid: locid,
    77  	})
    78  	s.nextLocalSeq = 1
    79  	c.endpoint.connsMap.updateConnIDs(func(conns *connsMap) {
    80  		conns.addConnID(c, locid)
    81  	})
    82  
    83  	// Client chooses an initial, transient connection ID for the server,
    84  	// and sends it in the Destination Connection ID field of the first Initial packet.
    85  	remid, err := c.newConnID(-1)
    86  	if err != nil {
    87  		return err
    88  	}
    89  	s.remote = append(s.remote, remoteConnID{
    90  		connID: connID{
    91  			seq: -1,
    92  			cid: remid,
    93  		},
    94  	})
    95  	s.originalDstConnID = remid
    96  	return nil
    97  }
    98  
    99  func (s *connIDState) initServer(c *Conn, cids newServerConnIDs) error {
   100  	dstConnID := cloneBytes(cids.dstConnID)
   101  	// Client-chosen, transient connection ID received in the first Initial packet.
   102  	// The server will not use this as the Source Connection ID of packets it sends,
   103  	// but remembers it because it may receive packets sent to this destination.
   104  	s.local = append(s.local, connID{
   105  		seq: -1,
   106  		cid: dstConnID,
   107  	})
   108  
   109  	// Server chooses a connection ID, and sends it in the Source Connection ID of
   110  	// the response to the clent.
   111  	locid, err := c.newConnID(0)
   112  	if err != nil {
   113  		return err
   114  	}
   115  	s.local = append(s.local, connID{
   116  		seq: 0,
   117  		cid: locid,
   118  	})
   119  	s.nextLocalSeq = 1
   120  	c.endpoint.connsMap.updateConnIDs(func(conns *connsMap) {
   121  		conns.addConnID(c, dstConnID)
   122  		conns.addConnID(c, locid)
   123  	})
   124  
   125  	// Client chose its own connection ID.
   126  	s.remote = append(s.remote, remoteConnID{
   127  		connID: connID{
   128  			seq: 0,
   129  			cid: cloneBytes(cids.srcConnID),
   130  		},
   131  	})
   132  	return nil
   133  }
   134  
   135  // srcConnID is the Source Connection ID to use in a sent packet.
   136  func (s *connIDState) srcConnID() []byte {
   137  	if s.local[0].seq == -1 && len(s.local) > 1 {
   138  		// Don't use the transient connection ID if another is available.
   139  		return s.local[1].cid
   140  	}
   141  	return s.local[0].cid
   142  }
   143  
   144  // dstConnID is the Destination Connection ID to use in a sent packet.
   145  func (s *connIDState) dstConnID() (cid []byte, ok bool) {
   146  	for i := range s.remote {
   147  		if !s.remote[i].retired {
   148  			return s.remote[i].cid, true
   149  		}
   150  	}
   151  	return nil, false
   152  }
   153  
   154  // isValidStatelessResetToken reports whether the given reset token is
   155  // associated with a non-retired connection ID which we have used.
   156  func (s *connIDState) isValidStatelessResetToken(resetToken statelessResetToken) bool {
   157  	for i := range s.remote {
   158  		// We currently only use the first available remote connection ID,
   159  		// so any other reset token is not valid.
   160  		if !s.remote[i].retired {
   161  			return s.remote[i].resetToken == resetToken
   162  		}
   163  	}
   164  	return false
   165  }
   166  
   167  // setPeerActiveConnIDLimit sets the active_connection_id_limit
   168  // transport parameter received from the peer.
   169  func (s *connIDState) setPeerActiveConnIDLimit(c *Conn, lim int64) error {
   170  	s.peerActiveConnIDLimit = lim
   171  	return s.issueLocalIDs(c)
   172  }
   173  
   174  func (s *connIDState) issueLocalIDs(c *Conn) error {
   175  	toIssue := min(int(s.peerActiveConnIDLimit), maxPeerActiveConnIDLimit)
   176  	for i := range s.local {
   177  		if s.local[i].seq != -1 && !s.local[i].retired {
   178  			toIssue--
   179  		}
   180  	}
   181  	var newIDs [][]byte
   182  	for toIssue > 0 {
   183  		cid, err := c.newConnID(s.nextLocalSeq)
   184  		if err != nil {
   185  			return err
   186  		}
   187  		newIDs = append(newIDs, cid)
   188  		s.local = append(s.local, connID{
   189  			seq: s.nextLocalSeq,
   190  			cid: cid,
   191  		})
   192  		s.local[len(s.local)-1].send.setUnsent()
   193  		s.nextLocalSeq++
   194  		s.needSend = true
   195  		toIssue--
   196  	}
   197  	c.endpoint.connsMap.updateConnIDs(func(conns *connsMap) {
   198  		for _, cid := range newIDs {
   199  			conns.addConnID(c, cid)
   200  		}
   201  	})
   202  	return nil
   203  }
   204  
   205  // validateTransportParameters verifies the original_destination_connection_id and
   206  // initial_source_connection_id transport parameters match the expected values.
   207  func (s *connIDState) validateTransportParameters(c *Conn, isRetry bool, p transportParameters) error {
   208  	// TODO: Consider returning more detailed errors, for debugging.
   209  	// Verify original_destination_connection_id matches
   210  	// the transient remote connection ID we chose (client)
   211  	// or is empty (server).
   212  	if !bytes.Equal(s.originalDstConnID, p.originalDstConnID) {
   213  		return localTransportError{
   214  			code:   errTransportParameter,
   215  			reason: "original_destination_connection_id mismatch",
   216  		}
   217  	}
   218  	s.originalDstConnID = nil // we have no further need for this
   219  	// Verify retry_source_connection_id matches the value from
   220  	// the server's Retry packet (when one was sent), or is empty.
   221  	if !bytes.Equal(p.retrySrcConnID, s.retrySrcConnID) {
   222  		return localTransportError{
   223  			code:   errTransportParameter,
   224  			reason: "retry_source_connection_id mismatch",
   225  		}
   226  	}
   227  	s.retrySrcConnID = nil // we have no further need for this
   228  	// Verify initial_source_connection_id matches the first remote connection ID.
   229  	if len(s.remote) == 0 || s.remote[0].seq != 0 {
   230  		return localTransportError{
   231  			code:   errInternal,
   232  			reason: "remote connection id missing",
   233  		}
   234  	}
   235  	if !bytes.Equal(p.initialSrcConnID, s.remote[0].cid) {
   236  		return localTransportError{
   237  			code:   errTransportParameter,
   238  			reason: "initial_source_connection_id mismatch",
   239  		}
   240  	}
   241  	if len(p.statelessResetToken) > 0 {
   242  		if c.side == serverSide {
   243  			return localTransportError{
   244  				code:   errTransportParameter,
   245  				reason: "client sent stateless_reset_token",
   246  			}
   247  		}
   248  		token := statelessResetToken(p.statelessResetToken)
   249  		s.remote[0].resetToken = token
   250  		c.endpoint.connsMap.updateConnIDs(func(conns *connsMap) {
   251  			conns.addResetToken(c, token)
   252  		})
   253  	}
   254  	return nil
   255  }
   256  
   257  // handlePacket updates the connection ID state during the handshake
   258  // (Initial and Handshake packets).
   259  func (s *connIDState) handlePacket(c *Conn, ptype packetType, srcConnID []byte) {
   260  	switch {
   261  	case ptype == packetTypeInitial && c.side == clientSide:
   262  		if len(s.remote) == 1 && s.remote[0].seq == -1 {
   263  			// We're a client connection processing the first Initial packet
   264  			// from the server. Replace the transient remote connection ID
   265  			// with the Source Connection ID from the packet.
   266  			s.remote[0] = remoteConnID{
   267  				connID: connID{
   268  					seq: 0,
   269  					cid: cloneBytes(srcConnID),
   270  				},
   271  			}
   272  		}
   273  	case ptype == packetTypeHandshake && c.side == serverSide:
   274  		if len(s.local) > 0 && s.local[0].seq == -1 && !s.local[0].retired {
   275  			// We're a server connection processing the first Handshake packet from
   276  			// the client. Discard the transient, client-chosen connection ID used
   277  			// for Initial packets; the client will never send it again.
   278  			cid := s.local[0].cid
   279  			c.endpoint.connsMap.updateConnIDs(func(conns *connsMap) {
   280  				conns.retireConnID(c, cid)
   281  			})
   282  			s.local = append(s.local[:0], s.local[1:]...)
   283  		}
   284  	}
   285  }
   286  
   287  func (s *connIDState) handleRetryPacket(srcConnID []byte) {
   288  	if len(s.remote) != 1 || s.remote[0].seq != -1 {
   289  		panic("BUG: handling retry with non-transient remote conn id")
   290  	}
   291  	s.retrySrcConnID = cloneBytes(srcConnID)
   292  	s.remote[0].cid = s.retrySrcConnID
   293  }
   294  
   295  func (s *connIDState) handleNewConnID(c *Conn, seq, retire int64, cid []byte, resetToken statelessResetToken) error {
   296  	if len(s.remote[0].cid) == 0 {
   297  		// "An endpoint that is sending packets with a zero-length
   298  		// Destination Connection ID MUST treat receipt of a NEW_CONNECTION_ID
   299  		// frame as a connection error of type PROTOCOL_VIOLATION."
   300  		// https://www.rfc-editor.org/rfc/rfc9000.html#section-19.15-6
   301  		return localTransportError{
   302  			code:   errProtocolViolation,
   303  			reason: "NEW_CONNECTION_ID from peer with zero-length DCID",
   304  		}
   305  	}
   306  
   307  	if retire > s.retireRemotePriorTo {
   308  		s.retireRemotePriorTo = retire
   309  	}
   310  
   311  	have := false // do we already have this connection ID?
   312  	active := 0
   313  	for i := range s.remote {
   314  		rcid := &s.remote[i]
   315  		if !rcid.retired && rcid.seq >= 0 && rcid.seq < s.retireRemotePriorTo {
   316  			s.retireRemote(rcid)
   317  			c.endpoint.connsMap.updateConnIDs(func(conns *connsMap) {
   318  				conns.retireResetToken(c, rcid.resetToken)
   319  			})
   320  		}
   321  		if !rcid.retired {
   322  			active++
   323  		}
   324  		if rcid.seq == seq {
   325  			if !bytes.Equal(rcid.cid, cid) {
   326  				return localTransportError{
   327  					code:   errProtocolViolation,
   328  					reason: "NEW_CONNECTION_ID does not match prior id",
   329  				}
   330  			}
   331  			have = true // yes, we've seen this sequence number
   332  		}
   333  	}
   334  
   335  	if !have {
   336  		// This is a new connection ID that we have not seen before.
   337  		//
   338  		// We could take steps to keep the list of remote connection IDs
   339  		// sorted by sequence number, but there's no particular need
   340  		// so we don't bother.
   341  		s.remote = append(s.remote, remoteConnID{
   342  			connID: connID{
   343  				seq: seq,
   344  				cid: cloneBytes(cid),
   345  			},
   346  			resetToken: resetToken,
   347  		})
   348  		if seq < s.retireRemotePriorTo {
   349  			// This ID was already retired by a previous NEW_CONNECTION_ID frame.
   350  			s.retireRemote(&s.remote[len(s.remote)-1])
   351  		} else {
   352  			active++
   353  			c.endpoint.connsMap.updateConnIDs(func(conns *connsMap) {
   354  				conns.addResetToken(c, resetToken)
   355  			})
   356  		}
   357  	}
   358  
   359  	if active > activeConnIDLimit {
   360  		// Retired connection IDs (including newly-retired ones) do not count
   361  		// against the limit.
   362  		// https://www.rfc-editor.org/rfc/rfc9000.html#section-5.1.1-5
   363  		return localTransportError{
   364  			code:   errConnectionIDLimit,
   365  			reason: "active_connection_id_limit exceeded",
   366  		}
   367  	}
   368  
   369  	// "An endpoint SHOULD limit the number of connection IDs it has retired locally
   370  	// for which RETIRE_CONNECTION_ID frames have not yet been acknowledged."
   371  	// https://www.rfc-editor.org/rfc/rfc9000#section-5.1.2-6
   372  	//
   373  	// Set a limit of four times the active_connection_id_limit for
   374  	// the total number of remote connection IDs we keep state for locally.
   375  	if len(s.remote) > 4*activeConnIDLimit {
   376  		return localTransportError{
   377  			code:   errConnectionIDLimit,
   378  			reason: "too many unacknowledged RETIRE_CONNECTION_ID frames",
   379  		}
   380  	}
   381  
   382  	return nil
   383  }
   384  
   385  // retireRemote marks a remote connection ID as retired.
   386  func (s *connIDState) retireRemote(rcid *remoteConnID) {
   387  	rcid.retired = true
   388  	rcid.send.setUnsent()
   389  	s.needSend = true
   390  }
   391  
   392  func (s *connIDState) handleRetireConnID(c *Conn, seq int64) error {
   393  	if seq >= s.nextLocalSeq {
   394  		return localTransportError{
   395  			code:   errProtocolViolation,
   396  			reason: "RETIRE_CONNECTION_ID for unissued sequence number",
   397  		}
   398  	}
   399  	for i := range s.local {
   400  		if s.local[i].seq == seq {
   401  			cid := s.local[i].cid
   402  			c.endpoint.connsMap.updateConnIDs(func(conns *connsMap) {
   403  				conns.retireConnID(c, cid)
   404  			})
   405  			s.local = append(s.local[:i], s.local[i+1:]...)
   406  			break
   407  		}
   408  	}
   409  	s.issueLocalIDs(c)
   410  	return nil
   411  }
   412  
   413  func (s *connIDState) ackOrLossNewConnectionID(pnum packetNumber, seq int64, fate packetFate) {
   414  	for i := range s.local {
   415  		if s.local[i].seq != seq {
   416  			continue
   417  		}
   418  		s.local[i].send.ackOrLoss(pnum, fate)
   419  		if fate != packetAcked {
   420  			s.needSend = true
   421  		}
   422  		return
   423  	}
   424  }
   425  
   426  func (s *connIDState) ackOrLossRetireConnectionID(pnum packetNumber, seq int64, fate packetFate) {
   427  	for i := 0; i < len(s.remote); i++ {
   428  		if s.remote[i].seq != seq {
   429  			continue
   430  		}
   431  		if fate == packetAcked {
   432  			// We have retired this connection ID, and the peer has acked.
   433  			// Discard its state completely.
   434  			s.remote = append(s.remote[:i], s.remote[i+1:]...)
   435  		} else {
   436  			// RETIRE_CONNECTION_ID frame was lost, mark for retransmission.
   437  			s.needSend = true
   438  			s.remote[i].send.ackOrLoss(pnum, fate)
   439  		}
   440  		return
   441  	}
   442  }
   443  
   444  // appendFrames appends NEW_CONNECTION_ID and RETIRE_CONNECTION_ID frames
   445  // to the current packet.
   446  //
   447  // It returns true if no more frames need appending,
   448  // false if not everything fit in the current packet.
   449  func (s *connIDState) appendFrames(c *Conn, pnum packetNumber, pto bool) bool {
   450  	if !s.needSend && !pto {
   451  		// Fast path: We don't need to send anything.
   452  		return true
   453  	}
   454  	retireBefore := int64(0)
   455  	if s.local[0].seq != -1 {
   456  		retireBefore = s.local[0].seq
   457  	}
   458  	for i := range s.local {
   459  		if !s.local[i].send.shouldSendPTO(pto) {
   460  			continue
   461  		}
   462  		if !c.w.appendNewConnectionIDFrame(
   463  			s.local[i].seq,
   464  			retireBefore,
   465  			s.local[i].cid,
   466  			c.endpoint.resetGen.tokenForConnID(s.local[i].cid),
   467  		) {
   468  			return false
   469  		}
   470  		s.local[i].send.setSent(pnum)
   471  	}
   472  	for i := range s.remote {
   473  		if !s.remote[i].send.shouldSendPTO(pto) {
   474  			continue
   475  		}
   476  		if !c.w.appendRetireConnectionIDFrame(s.remote[i].seq) {
   477  			return false
   478  		}
   479  		s.remote[i].send.setSent(pnum)
   480  	}
   481  	s.needSend = false
   482  	return true
   483  }
   484  
   485  func cloneBytes(b []byte) []byte {
   486  	n := make([]byte, len(b))
   487  	copy(n, b)
   488  	return n
   489  }
   490  
   491  func (c *Conn) newConnID(seq int64) ([]byte, error) {
   492  	if c.testHooks != nil {
   493  		return c.testHooks.newConnID(seq)
   494  	}
   495  	return newRandomConnID(seq)
   496  }
   497  
   498  func newRandomConnID(_ int64) ([]byte, error) {
   499  	// It is not necessary for connection IDs to be cryptographically secure,
   500  	// but it doesn't hurt.
   501  	id := make([]byte, connIDLen)
   502  	if _, err := rand.Read(id); err != nil {
   503  		// TODO: Surface this error as a metric or log event or something.
   504  		// rand.Read really shouldn't ever fail, but if it does, we should
   505  		// have a way to inform the user.
   506  		return nil, err
   507  	}
   508  	return id, nil
   509  }