github.com/psiphon-Labs/psiphon-tunnel-core@v2.0.28+incompatible/psiphon/common/quic/gquic-go/packet_handler_map.go (about)

     1  package gquic
     2  
     3  import (
     4  	"bytes"
     5  	"fmt"
     6  	"net"
     7  	"sync"
     8  	"time"
     9  
    10  	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/quic/gquic-go/internal/protocol"
    11  	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/quic/gquic-go/internal/utils"
    12  	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/quic/gquic-go/internal/wire"
    13  )
    14  
    15  // The packetHandlerMap stores packetHandlers, identified by connection ID.
    16  // It is used:
    17  // * by the server to store sessions
    18  // * when multiplexing outgoing connections to store clients
    19  type packetHandlerMap struct {
    20  	mutex sync.RWMutex
    21  
    22  	conn      net.PacketConn
    23  	connIDLen int
    24  
    25  	handlers map[string] /* string(ConnectionID)*/ packetHandler
    26  	server   unknownPacketHandler
    27  	closed   bool
    28  
    29  	deleteClosedSessionsAfter time.Duration
    30  
    31  	logger utils.Logger
    32  }
    33  
    34  var _ packetHandlerManager = &packetHandlerMap{}
    35  
    36  func newPacketHandlerMap(conn net.PacketConn, connIDLen int, logger utils.Logger) packetHandlerManager {
    37  	m := &packetHandlerMap{
    38  		conn:                      conn,
    39  		connIDLen:                 connIDLen,
    40  		handlers:                  make(map[string]packetHandler),
    41  		deleteClosedSessionsAfter: protocol.ClosedSessionDeleteTimeout,
    42  		logger:                    logger,
    43  	}
    44  	go m.listen()
    45  	return m
    46  }
    47  
    48  func (h *packetHandlerMap) Add(id protocol.ConnectionID, handler packetHandler) {
    49  	h.mutex.Lock()
    50  	h.handlers[string(id)] = handler
    51  	h.mutex.Unlock()
    52  }
    53  
    54  func (h *packetHandlerMap) Remove(id protocol.ConnectionID) {
    55  	h.removeByConnectionIDAsString(string(id))
    56  }
    57  
    58  func (h *packetHandlerMap) removeByConnectionIDAsString(id string) {
    59  	h.mutex.Lock()
    60  	h.handlers[id] = nil
    61  	h.mutex.Unlock()
    62  
    63  	time.AfterFunc(h.deleteClosedSessionsAfter, func() {
    64  		h.mutex.Lock()
    65  		delete(h.handlers, id)
    66  		h.mutex.Unlock()
    67  	})
    68  }
    69  
    70  func (h *packetHandlerMap) SetServer(s unknownPacketHandler) {
    71  	h.mutex.Lock()
    72  	h.server = s
    73  	h.mutex.Unlock()
    74  }
    75  
    76  func (h *packetHandlerMap) CloseServer() {
    77  	h.mutex.Lock()
    78  	h.server = nil
    79  	var wg sync.WaitGroup
    80  	for id, handler := range h.handlers {
    81  		if handler != nil && handler.GetPerspective() == protocol.PerspectiveServer {
    82  			wg.Add(1)
    83  			go func(id string, handler packetHandler) {
    84  				// session.Close() blocks until the CONNECTION_CLOSE has been sent and the run-loop has stopped
    85  				_ = handler.Close()
    86  				h.removeByConnectionIDAsString(id)
    87  				wg.Done()
    88  			}(id, handler)
    89  		}
    90  	}
    91  	h.mutex.Unlock()
    92  	wg.Wait()
    93  }
    94  
    95  func (h *packetHandlerMap) close(e error) error {
    96  	h.mutex.Lock()
    97  	if h.closed {
    98  		h.mutex.Unlock()
    99  		return nil
   100  	}
   101  	h.closed = true
   102  
   103  	var wg sync.WaitGroup
   104  	for _, handler := range h.handlers {
   105  		if handler != nil {
   106  			wg.Add(1)
   107  			go func(handler packetHandler) {
   108  				handler.destroy(e)
   109  				wg.Done()
   110  			}(handler)
   111  		}
   112  	}
   113  
   114  	// [Psiphon]
   115  	// Call h.server.setCloseError(e) outside of mutex to prevent deadlock
   116  	//
   117  	//    sync.(*RWMutex).Lock
   118  	//    [...]/lucas-clemente/quic-go.(*packetHandlerMap).CloseServer
   119  	//    [...]/lucas-clemente/quic-go.(*server).closeWithMutex
   120  	//    [...]/lucas-clemente/quic-go.(*server).closeWithError
   121  	//    [...]/lucas-clemente/quic-go.(*packetHandlerMap).close
   122  	//    [...]/lucas-clemente/quic-go.(*packetHandlerMap).listen
   123  	//
   124  	//    packetHandlerMap.CloseServer is attempting to lock the same mutex that
   125  	//    is already locked in packetHandlerMap.close, which deadlocks. As
   126  	//    packetHandlerMap and its mutex are used by all client sessions, this
   127  	//    effectively hangs the entire server.
   128  
   129  	var server unknownPacketHandler
   130  	if h.server != nil {
   131  		server = h.server
   132  	}
   133  
   134  	h.mutex.Unlock()
   135  
   136  	if server != nil {
   137  		server.closeWithError(e)
   138  	}
   139  
   140  	wg.Wait()
   141  	return nil
   142  }
   143  
   144  func (h *packetHandlerMap) listen() {
   145  	for {
   146  		data := *getPacketBuffer()
   147  		data = data[:protocol.MaxReceivePacketSize]
   148  		// The packet size should not exceed protocol.MaxReceivePacketSize bytes
   149  		// If it does, we only read a truncated packet, which will then end up undecryptable
   150  		n, addr, err := h.conn.ReadFrom(data)
   151  		if err != nil {
   152  
   153  			// [Psiphon]
   154  			// Do not unconditionally shutdown
   155  			if netErr, ok := err.(net.Error); !ok || !netErr.Temporary() {
   156  				h.close(err)
   157  				return
   158  			}
   159  
   160  		}
   161  		data = data[:n]
   162  
   163  		if err := h.handlePacket(addr, data); err != nil {
   164  			h.logger.Debugf("error handling packet from %s: %s", addr, err)
   165  		}
   166  	}
   167  }
   168  
   169  func (h *packetHandlerMap) handlePacket(addr net.Addr, data []byte) error {
   170  	rcvTime := time.Now()
   171  
   172  	r := bytes.NewReader(data)
   173  	iHdr, err := wire.ParseInvariantHeader(r, h.connIDLen)
   174  	// drop the packet if we can't parse the header
   175  	if err != nil {
   176  		return fmt.Errorf("error parsing invariant header: %s", err)
   177  	}
   178  
   179  	h.mutex.RLock()
   180  	handler, ok := h.handlers[string(iHdr.DestConnectionID)]
   181  	server := h.server
   182  	h.mutex.RUnlock()
   183  
   184  	var sentBy protocol.Perspective
   185  	var version protocol.VersionNumber
   186  	var handlePacket func(*receivedPacket)
   187  	if ok && handler == nil {
   188  		// Late packet for closed session
   189  		return nil
   190  	}
   191  	if !ok {
   192  		if server == nil { // no server set
   193  			return fmt.Errorf("received a packet with an unexpected connection ID %s", iHdr.DestConnectionID)
   194  		}
   195  		handlePacket = server.handlePacket
   196  		sentBy = protocol.PerspectiveClient
   197  		version = iHdr.Version
   198  	} else {
   199  		sentBy = handler.GetPerspective().Opposite()
   200  		version = handler.GetVersion()
   201  		handlePacket = handler.handlePacket
   202  	}
   203  
   204  	hdr, err := iHdr.Parse(r, sentBy, version)
   205  	if err != nil {
   206  		return fmt.Errorf("error parsing header: %s", err)
   207  	}
   208  	hdr.Raw = data[:len(data)-r.Len()]
   209  	packetData := data[len(data)-r.Len():]
   210  
   211  	if hdr.IsLongHeader && hdr.Version.UsesLengthInHeader() {
   212  		if protocol.ByteCount(len(packetData)) < hdr.PayloadLen {
   213  			return fmt.Errorf("packet payload (%d bytes) is smaller than the expected payload length (%d bytes)", len(packetData), hdr.PayloadLen)
   214  		}
   215  		packetData = packetData[:int(hdr.PayloadLen)]
   216  		// TODO(#1312): implement parsing of compound packets
   217  	}
   218  
   219  	handlePacket(&receivedPacket{
   220  		remoteAddr: addr,
   221  		header:     hdr,
   222  		data:       packetData,
   223  		rcvTime:    rcvTime,
   224  	})
   225  	return nil
   226  }