github.com/AntonOrnatskyi/goproxy@v0.0.0-20190205095733-4526a9fa18b4/core/dst/mux.go (about)

     1  // Copyright 2014 The DST Authors. All rights reserved.
     2  // Use of this source code is governed by an MIT-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package dst
     6  
     7  import (
     8  	"fmt"
     9  	"runtime/debug"
    10  
    11  	"net"
    12  	"sync"
    13  	"time"
    14  )
    15  
    16  const (
    17  	maxIncomingRequests = 1024
    18  	maxPacketSize       = 500
    19  	handshakeTimeout    = 5 * time.Second
    20  	handshakeInterval   = 1 * time.Second
    21  )
    22  
    23  // Mux is a UDP multiplexer of DST connections.
    24  type Mux struct {
    25  	conn       net.PacketConn
    26  	packetSize int
    27  
    28  	conns      map[connectionID]*Conn
    29  	handshakes map[connectionID]chan packet
    30  	connsMut   sync.Mutex
    31  
    32  	incoming  chan *Conn
    33  	closed    chan struct{}
    34  	closeOnce sync.Once
    35  
    36  	buffers *sync.Pool
    37  }
    38  
    39  // NewMux creates a new DST Mux on top of a packet connection.
    40  func NewMux(conn net.PacketConn, packetSize int) *Mux {
    41  	if packetSize <= 0 {
    42  		packetSize = maxPacketSize
    43  	}
    44  	m := &Mux{
    45  		conn:       conn,
    46  		packetSize: packetSize,
    47  		conns:      map[connectionID]*Conn{},
    48  		handshakes: make(map[connectionID]chan packet),
    49  		incoming:   make(chan *Conn, maxIncomingRequests),
    50  		closed:     make(chan struct{}),
    51  		buffers: &sync.Pool{
    52  			New: func() interface{} {
    53  				return make([]byte, packetSize)
    54  			},
    55  		},
    56  	}
    57  
    58  	// Attempt to maximize buffer space. Start at 16 MB and work downwards 0.5
    59  	// MB at a time.
    60  
    61  	if conn, ok := conn.(*net.UDPConn); ok {
    62  		for buf := 16384 * 1024; buf >= 512*1024; buf -= 512 * 1024 {
    63  			err := conn.SetReadBuffer(buf)
    64  			if err == nil {
    65  				if debugMux {
    66  					log.Println(m, "read buffer is", buf)
    67  				}
    68  				break
    69  			}
    70  		}
    71  		for buf := 16384 * 1024; buf >= 512*1024; buf -= 512 * 1024 {
    72  			err := conn.SetWriteBuffer(buf)
    73  			if err == nil {
    74  				if debugMux {
    75  					log.Println(m, "write buffer is", buf)
    76  				}
    77  				break
    78  			}
    79  		}
    80  	}
    81  
    82  	go func() {
    83  		defer func() {
    84  			if e := recover(); e != nil {
    85  				fmt.Printf("crashed, err: %s\nstack:%s", e, string(debug.Stack()))
    86  			}
    87  		}()
    88  		m.readerLoop()
    89  	}()
    90  	return m
    91  }
    92  
    93  // Accept waits for and returns the next connection to the listener.
    94  func (m *Mux) Accept() (net.Conn, error) {
    95  	return m.AcceptDST()
    96  }
    97  
    98  // AcceptDST waits for and returns the next connection to the listener.
    99  func (m *Mux) AcceptDST() (*Conn, error) {
   100  	conn, ok := <-m.incoming
   101  	if !ok {
   102  		return nil, ErrClosedMux
   103  	}
   104  	return conn, nil
   105  }
   106  
   107  // Close closes the listener.
   108  // Any blocked Accept operations will be unblocked and return errors.
   109  func (m *Mux) Close() error {
   110  	var err error = ErrClosedMux
   111  	m.closeOnce.Do(func() {
   112  		err = m.conn.Close()
   113  		close(m.incoming)
   114  		close(m.closed)
   115  	})
   116  	return err
   117  }
   118  
   119  // Addr returns the listener's network address.
   120  func (m *Mux) Addr() net.Addr {
   121  	return m.conn.LocalAddr()
   122  }
   123  
   124  // Dial connects to the address on the named network.
   125  //
   126  // Network must be "dst".
   127  //
   128  // Addresses have the form host:port. If host is a literal IPv6 address or
   129  // host name, it must be enclosed in square brackets as in "[::1]:80",
   130  // "[ipv6-host]:http" or "[ipv6-host%zone]:80". The functions JoinHostPort and
   131  // SplitHostPort manipulate addresses in this form.
   132  //
   133  // Examples:
   134  //	Dial("dst", "12.34.56.78:80")
   135  //	Dial("dst", "google.com:http")
   136  //	Dial("dst", "[2001:db8::1]:http")
   137  //	Dial("dst", "[fe80::1%lo0]:80")
   138  func (m *Mux) Dial(network, addr string) (net.Conn, error) {
   139  	return m.DialDST(network, addr)
   140  }
   141  
   142  // Dial connects to the address on the named network.
   143  //
   144  // Network must be "dst".
   145  //
   146  // Addresses have the form host:port. If host is a literal IPv6 address or
   147  // host name, it must be enclosed in square brackets as in "[::1]:80",
   148  // "[ipv6-host]:http" or "[ipv6-host%zone]:80". The functions JoinHostPort and
   149  // SplitHostPort manipulate addresses in this form.
   150  //
   151  // Examples:
   152  //	Dial("dst", "12.34.56.78:80")
   153  //	Dial("dst", "google.com:http")
   154  //	Dial("dst", "[2001:db8::1]:http")
   155  //	Dial("dst", "[fe80::1%lo0]:80")
   156  func (m *Mux) DialDST(network, addr string) (*Conn, error) {
   157  	if network != "dst" {
   158  		return nil, ErrNotDST
   159  	}
   160  
   161  	dst, err := net.ResolveUDPAddr("udp", addr)
   162  	if err != nil {
   163  		return nil, err
   164  	}
   165  
   166  	resp := make(chan packet)
   167  
   168  	m.connsMut.Lock()
   169  	connID := m.newConnID()
   170  	m.handshakes[connID] = resp
   171  	m.connsMut.Unlock()
   172  
   173  	conn, err := m.clientHandshake(dst, connID, resp)
   174  
   175  	m.connsMut.Lock()
   176  	defer m.connsMut.Unlock()
   177  	delete(m.handshakes, connID)
   178  
   179  	if err != nil {
   180  		return nil, err
   181  	}
   182  
   183  	m.conns[connID] = conn
   184  	return conn, nil
   185  }
   186  
   187  // handshake performs the client side handshake (i.e. Dial)
   188  func (m *Mux) clientHandshake(dst net.Addr, connID connectionID, resp chan packet) (*Conn, error) {
   189  	if debugMux {
   190  		log.Printf("%v dial %v connID %v", m, dst, connID)
   191  	}
   192  
   193  	nextHandshake := time.NewTimer(0)
   194  	defer nextHandshake.Stop()
   195  
   196  	handshakeTimeout := time.NewTimer(handshakeTimeout)
   197  	defer handshakeTimeout.Stop()
   198  
   199  	var remoteCookie uint32
   200  	seqNo := randomSeqNo()
   201  
   202  	for {
   203  		select {
   204  		case <-m.closed:
   205  			// Failure. The mux has been closed.
   206  			return nil, ErrClosedConn
   207  
   208  		case <-handshakeTimeout.C:
   209  			// Handshake timeout. Close and abort.
   210  			return nil, ErrHandshakeTimeout
   211  
   212  		case <-nextHandshake.C:
   213  			// Send a handshake request.
   214  
   215  			m.write(packet{
   216  				src: connID,
   217  				dst: dst,
   218  				hdr: header{
   219  					packetType: typeHandshake,
   220  					flags:      flagRequest,
   221  					connID:     0,
   222  					sequenceNo: seqNo,
   223  					timestamp:  timestampMicros(),
   224  				},
   225  				data: handshakeData{uint32(m.packetSize), connID, remoteCookie}.marshal(),
   226  			})
   227  			nextHandshake.Reset(handshakeInterval)
   228  
   229  		case pkt := <-resp:
   230  			hd := unmarshalHandshakeData(pkt.data)
   231  
   232  			if pkt.hdr.flags&flagCookie == flagCookie {
   233  				// We should resend the handshake request with a different cookie value.
   234  				remoteCookie = hd.cookie
   235  				nextHandshake.Reset(0)
   236  			} else if pkt.hdr.flags&flagResponse == flagResponse {
   237  				// Successfull handshake response.
   238  				conn := newConn(m, dst)
   239  
   240  				conn.connID = connID
   241  				conn.remoteConnID = hd.connID
   242  				conn.nextRecvSeqNo = pkt.hdr.sequenceNo + 1
   243  				conn.packetSize = int(hd.packetSize)
   244  				if conn.packetSize > m.packetSize {
   245  					conn.packetSize = m.packetSize
   246  				}
   247  
   248  				conn.nextSeqNo = seqNo + 1
   249  
   250  				conn.start()
   251  
   252  				return conn, nil
   253  			}
   254  		}
   255  	}
   256  }
   257  
   258  func (m *Mux) readerLoop() {
   259  	buf := make([]byte, m.packetSize)
   260  	for {
   261  		buf = buf[:cap(buf)]
   262  		n, from, err := m.conn.ReadFrom(buf)
   263  		if err != nil {
   264  			m.Close()
   265  			return
   266  		}
   267  		buf = buf[:n]
   268  
   269  		hdr := unmarshalHeader(buf)
   270  
   271  		var bufCopy []byte
   272  		if len(buf) > dstHeaderLen {
   273  			bufCopy = m.buffers.Get().([]byte)[:len(buf)-dstHeaderLen]
   274  			copy(bufCopy, buf[dstHeaderLen:])
   275  		}
   276  
   277  		pkt := packet{hdr: hdr, data: bufCopy}
   278  		if debugMux {
   279  			log.Println(m, "read", pkt)
   280  		}
   281  
   282  		if hdr.packetType == typeHandshake {
   283  			m.incomingHandshake(from, hdr, bufCopy)
   284  		} else {
   285  			m.connsMut.Lock()
   286  			conn, ok := m.conns[hdr.connID]
   287  			m.connsMut.Unlock()
   288  
   289  			if ok {
   290  				conn.in <- packet{
   291  					dst:  nil,
   292  					hdr:  hdr,
   293  					data: bufCopy,
   294  				}
   295  			} else if debugMux && hdr.packetType != typeShutdown {
   296  				log.Printf("packet %v for unknown conn %v", hdr, hdr.connID)
   297  			}
   298  		}
   299  	}
   300  }
   301  
   302  func (m *Mux) incomingHandshake(from net.Addr, hdr header, data []byte) {
   303  	if hdr.connID == 0 {
   304  		// A new incoming handshake request.
   305  		m.incomingHandshakeRequest(from, hdr, data)
   306  	} else {
   307  		// A response to an ongoing handshake.
   308  		m.incomingHandshakeResponse(from, hdr, data)
   309  	}
   310  }
   311  
   312  func (m *Mux) incomingHandshakeRequest(from net.Addr, hdr header, data []byte) {
   313  	if hdr.flags&flagRequest != flagRequest {
   314  		log.Printf("Handshake pattern with flags 0x%x to connID zero", hdr.flags)
   315  		return
   316  	}
   317  
   318  	hd := unmarshalHandshakeData(data)
   319  
   320  	correctCookie := cookie(from)
   321  	if hd.cookie != correctCookie {
   322  		// Incorrect or missing SYN cookie. Send back a handshake
   323  		// with the expected one.
   324  		m.write(packet{
   325  			dst: from,
   326  			hdr: header{
   327  				packetType: typeHandshake,
   328  				flags:      flagResponse | flagCookie,
   329  				connID:     hd.connID,
   330  				timestamp:  timestampMicros(),
   331  			},
   332  			data: handshakeData{
   333  				packetSize: uint32(m.packetSize),
   334  				cookie:     correctCookie,
   335  			}.marshal(),
   336  		})
   337  		return
   338  	}
   339  
   340  	seqNo := randomSeqNo()
   341  
   342  	m.connsMut.Lock()
   343  	connID := m.newConnID()
   344  
   345  	conn := newConn(m, from)
   346  	conn.connID = connID
   347  	conn.remoteConnID = hd.connID
   348  	conn.nextSeqNo = seqNo + 1
   349  	conn.nextRecvSeqNo = hdr.sequenceNo + 1
   350  	conn.packetSize = int(hd.packetSize)
   351  	if conn.packetSize > m.packetSize {
   352  		conn.packetSize = m.packetSize
   353  	}
   354  	conn.start()
   355  
   356  	m.conns[connID] = conn
   357  	m.connsMut.Unlock()
   358  
   359  	m.write(packet{
   360  		dst: from,
   361  		hdr: header{
   362  			packetType: typeHandshake,
   363  			flags:      flagResponse,
   364  			connID:     hd.connID,
   365  			sequenceNo: seqNo,
   366  			timestamp:  timestampMicros(),
   367  		},
   368  		data: handshakeData{
   369  			connID:     conn.connID,
   370  			packetSize: uint32(conn.packetSize),
   371  		}.marshal(),
   372  	})
   373  
   374  	m.incoming <- conn
   375  }
   376  
   377  func (m *Mux) incomingHandshakeResponse(from net.Addr, hdr header, data []byte) {
   378  	m.connsMut.Lock()
   379  	handShake, ok := m.handshakes[hdr.connID]
   380  	m.connsMut.Unlock()
   381  
   382  	if ok {
   383  		// This is a response to a handshake in progress.
   384  		handShake <- packet{
   385  			dst:  nil,
   386  			hdr:  hdr,
   387  			data: data,
   388  		}
   389  	} else if debugMux && hdr.packetType != typeShutdown {
   390  		log.Printf("Handshake packet %v for unknown conn %v", hdr, hdr.connID)
   391  	}
   392  }
   393  
   394  func (m *Mux) write(pkt packet) (int, error) {
   395  	buf := m.buffers.Get().([]byte)
   396  	buf = buf[:dstHeaderLen+len(pkt.data)]
   397  	pkt.hdr.marshal(buf)
   398  	copy(buf[dstHeaderLen:], pkt.data)
   399  	if debugMux {
   400  		log.Println(m, "write", pkt)
   401  	}
   402  	n, err := m.conn.WriteTo(buf, pkt.dst)
   403  	m.buffers.Put(buf)
   404  	return n, err
   405  }
   406  
   407  func (m *Mux) String() string {
   408  	return fmt.Sprintf("Mux-%v", m.Addr())
   409  }
   410  
   411  // Find a unique connection ID
   412  func (m *Mux) newConnID() connectionID {
   413  	for {
   414  		connID := randomConnID()
   415  		if _, ok := m.conns[connID]; ok {
   416  			continue
   417  		}
   418  		if _, ok := m.handshakes[connID]; ok {
   419  			continue
   420  		}
   421  		return connID
   422  	}
   423  }
   424  
   425  func (m *Mux) removeConn(c *Conn) {
   426  	m.connsMut.Lock()
   427  	delete(m.conns, c.connID)
   428  	m.connsMut.Unlock()
   429  }