golang.org/x/build@v0.0.0-20240506185731-218518f32b70/revdial/v2/revdial.go (about)

     1  // Copyright 2019 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // Package revdial implements a Dialer and Listener which work together
     6  // to turn an accepted connection (for instance, a Hijacked HTTP request) into
     7  // a Dialer which can then create net.Conns connecting back to the original
     8  // dialer, which then gets a net.Listener accepting those conns.
     9  //
    10  // This is basically a very minimal SOCKS5 client & server.
    11  //
    12  // The motivation is that sometimes you want to run a server on a
    13  // machine deep inside a NAT. Rather than connecting to the machine
    14  // directly (which you can't, because of the NAT), you have the
    15  // sequestered machine connect out to a public machine. Both sides
    16  // then use revdial and the public machine can become a client for the
    17  // NATed machine.
    18  package revdial
    19  
    20  import (
    21  	"bufio"
    22  	"context"
    23  	"crypto/rand"
    24  	"encoding/json"
    25  	"errors"
    26  	"fmt"
    27  	"io"
    28  	"log"
    29  	"net"
    30  	"net/http"
    31  	"net/url"
    32  	"strings"
    33  	"sync"
    34  	"time"
    35  )
    36  
    // dialerUniqParam is the query parameter, in the URL a Listener requests to
// pick up a connection, that carries the Dialer's random unique ID.
const dialerUniqParam = "revdial.dialer"

// Dialer creates new connections back to the peer that runs the matching
// Listener, using conn to exchange control messages with it.
type Dialer struct {
	conn       net.Conn // hijacked client conn carrying control messages
	path       string   // connPath given to NewDialer, e.g. "/revdial"
	uniqID     string   // random ID under which d is registered in dialers
	pickupPath string   // path plus the ID query parameter: "/revdial?revdial.dialer="+uniqID

	incomingConn chan net.Conn // hijacked pick-up conns, from connHandler to Dial
	pickupFailed chan error    // pick-up failures reported by the peer, to Dial
	connReady    chan bool     // Dial asks serve to announce a new connection
	donec        chan struct{} // closed by Close
	closeOnce    sync.Once     // makes Close run its teardown only once
}

// dialers maps each live Dialer's uniqID to the Dialer so that connHandler
// can route pick-up requests to it. dmapMu guards dialers.
var (
	dmapMu  sync.Mutex
	dialers = make(map[string]*Dialer)
)
    59  
    60  // NewDialer returns the side of the connection which will initiate
    61  // new connections. This will typically be the side which did the HTTP
    62  // Hijack. The connection is (typically) the hijacked HTTP client
    63  // connection. The connPath is the HTTP path and optional query (but
    64  // without scheme or host) on the dialer where the ConnHandler is
    65  // mounted.
    66  func NewDialer(c net.Conn, connPath string) *Dialer {
    67  	d := &Dialer{
    68  		path:         connPath,
    69  		uniqID:       newUniqID(),
    70  		conn:         c,
    71  		donec:        make(chan struct{}),
    72  		connReady:    make(chan bool),
    73  		incomingConn: make(chan net.Conn),
    74  		pickupFailed: make(chan error),
    75  	}
    76  
    77  	join := "?"
    78  	if strings.Contains(connPath, "?") {
    79  		join = "&"
    80  	}
    81  	d.pickupPath = connPath + join + dialerUniqParam + "=" + d.uniqID
    82  	d.register()
    83  	go d.serve()
    84  	return d
    85  }
    86  
    // newUniqID returns 16 random bytes from crypto/rand, formatted as 32
// lowercase hex digits.
//
// The ID is what a pick-up request presents to connHandler to claim a
// Dialer's connection, and it keys the dialers map. It therefore must never
// be predictable or repeated. If the system CSPRNG fails, newUniqID panics
// instead of returning a partly or wholly zero ID. Go 1.24+ makes such a
// failure fatal inside crypto/rand.Read itself.
func newUniqID() string {
	buf := make([]byte, 16)
	if _, err := rand.Read(buf); err != nil {
		panic(fmt.Sprintf("revdial: crypto/rand failed: %v", err))
	}
	return fmt.Sprintf("%x", buf)
}
    92  
    93  func (d *Dialer) register() {
    94  	dmapMu.Lock()
    95  	defer dmapMu.Unlock()
    96  	dialers[d.uniqID] = d
    97  }
    98  
    99  func (d *Dialer) unregister() {
   100  	dmapMu.Lock()
   101  	defer dmapMu.Unlock()
   102  	delete(dialers, d.uniqID)
   103  }
   104  
   105  // Done returns a channel which is closed when d is closed (either by
   106  // this process on purpose, by a local error, or close or error from
   107  // the peer).
   108  func (d *Dialer) Done() <-chan struct{} { return d.donec }
   109  
   110  // Close closes the Dialer.
   111  func (d *Dialer) Close() error {
   112  	d.closeOnce.Do(d.close)
   113  	return nil
   114  }
   115  
   116  func (d *Dialer) close() {
   117  	d.unregister()
   118  	d.conn.Close()
   119  	close(d.donec)
   120  }
   121  
   122  // Dial creates a new connection back to the Listener.
   123  func (d *Dialer) Dial(ctx context.Context) (net.Conn, error) {
   124  	// First, tell serve that we want a connection:
   125  	select {
   126  	case d.connReady <- true:
   127  	case <-d.donec:
   128  		return nil, errors.New("revdial.Dialer closed")
   129  	case <-ctx.Done():
   130  		return nil, ctx.Err()
   131  	}
   132  
   133  	// Then pick it up:
   134  	select {
   135  	case c := <-d.incomingConn:
   136  		return c, nil
   137  	case err := <-d.pickupFailed:
   138  		return nil, err
   139  	case <-d.donec:
   140  		return nil, errors.New("revdial.Dialer closed")
   141  	case <-ctx.Done():
   142  		return nil, ctx.Err()
   143  	}
   144  }
   145  
   146  func (d *Dialer) matchConn(c net.Conn) {
   147  	select {
   148  	case d.incomingConn <- c:
   149  	case <-d.donec:
   150  	}
   151  }
   152  
   153  // serve blocks and runs the control message loop, keeping the peer
   154  // alive and notifying the peer when new connections are available.
   155  func (d *Dialer) serve() error {
   156  	defer d.Close()
   157  	go func() {
   158  		defer d.Close()
   159  		br := bufio.NewReader(d.conn)
   160  		for {
   161  			line, err := br.ReadSlice('\n')
   162  			if err != nil {
   163  				return
   164  			}
   165  			var msg controlMsg
   166  			if err := json.Unmarshal(line, &msg); err != nil {
   167  				log.Printf("revdial.Dialer read invalid JSON: %q: %v", line, err)
   168  				return
   169  			}
   170  			switch msg.Command {
   171  			case "pickup-failed":
   172  				err := fmt.Errorf("revdial listener failed to pick up connection: %v", msg.Err)
   173  				select {
   174  				case d.pickupFailed <- err:
   175  				case <-d.donec:
   176  					return
   177  				}
   178  			}
   179  		}
   180  	}()
   181  	for {
   182  		if err := d.sendMessage(controlMsg{Command: "keep-alive"}); err != nil {
   183  			return err
   184  		}
   185  
   186  		t := time.NewTimer(30 * time.Second)
   187  		select {
   188  		case <-t.C:
   189  			continue
   190  		case <-d.connReady:
   191  			t.Stop()
   192  			if err := d.sendMessage(controlMsg{
   193  				Command:  "conn-ready",
   194  				ConnPath: d.pickupPath,
   195  			}); err != nil {
   196  				return err
   197  			}
   198  		case <-d.donec:
   199  			t.Stop()
   200  			return errors.New("revdial.Dialer closed")
   201  		}
   202  	}
   203  }
   204  
   205  func (d *Dialer) sendMessage(m controlMsg) error {
   206  	j, _ := json.Marshal(m)
   207  	d.conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
   208  	j = append(j, '\n')
   209  	_, err := d.conn.Write(j)
   210  	d.conn.SetWriteDeadline(time.Time{})
   211  	return err
   212  }
   213  
   214  // NewListener returns a new Listener, accepting connections which
   215  // arrive from the provided server connection, which should be after
   216  // any necessary authentication (usually after an HTTP exchange).
   217  //
   218  // The provided dialServer func is responsible for connecting back to
   219  // the server and doing TLS setup.
   220  func NewListener(serverConn net.Conn, dialServer func(context.Context) (net.Conn, error)) *Listener {
   221  	ln := &Listener{
   222  		sc:    serverConn,
   223  		dial:  dialServer,
   224  		connc: make(chan net.Conn, 8), // arbitrary
   225  		donec: make(chan struct{}),
   226  	}
   227  	go ln.run()
   228  	return ln
   229  }
   230  
   231  var _ net.Listener = (*Listener)(nil)
   232  
   233  // Listener is a net.Listener, returning new connections which arrive
   234  // from a corresponding Dialer.
   235  type Listener struct {
   236  	sc     net.Conn
   237  	connc  chan net.Conn
   238  	donec  chan struct{}
   239  	dial   func(context.Context) (net.Conn, error)
   240  	writec chan<- []byte
   241  
   242  	mu      sync.Mutex // guards below, closing connc, and writing to rw
   243  	readErr error
   244  	closed  bool
   245  }
   246  
   // controlMsg is one newline-terminated JSON message exchanged over the
// control connection between a Dialer and its Listener.
type controlMsg struct {
	Command  string `json:"command,omitempty"`  // "keep-alive", "conn-ready", "pickup-failed"
	ConnPath string `json:"connPath,omitempty"` // conn pick-up URL path for "conn-ready", "pickup-failed"
	Err      string `json:"err,omitempty"`      // failure description for "pickup-failed"
}
   252  
   253  // run reads control messages from the public server forever until the connection dies, which
   254  // then closes the listener.
   255  func (ln *Listener) run() {
   256  	defer ln.Close()
   257  
   258  	// Write loop
   259  	writec := make(chan []byte, 8)
   260  	ln.writec = writec
   261  	go func() {
   262  		for {
   263  			select {
   264  			case <-ln.donec:
   265  				return
   266  			case msg := <-writec:
   267  				if _, err := ln.sc.Write(msg); err != nil {
   268  					log.Printf("revdial.Listener: error writing message to server: %v", err)
   269  					ln.Close()
   270  					return
   271  				}
   272  			}
   273  		}
   274  	}()
   275  
   276  	// Read loop
   277  	br := bufio.NewReader(ln.sc)
   278  	for {
   279  		line, err := br.ReadSlice('\n')
   280  		if err != nil {
   281  			return
   282  		}
   283  		var msg controlMsg
   284  		if err := json.Unmarshal(line, &msg); err != nil {
   285  			log.Printf("revdial.Listener read invalid JSON: %q: %v", line, err)
   286  			return
   287  		}
   288  		switch msg.Command {
   289  		case "keep-alive":
   290  			// Occasional no-op message from server to keep
   291  			// us alive through NAT timeouts.
   292  		case "conn-ready":
   293  			go ln.grabConn(msg.ConnPath)
   294  		default:
   295  			// Ignore unknown messages
   296  		}
   297  	}
   298  }
   299  
   300  func (ln *Listener) sendMessage(m controlMsg) {
   301  	j, _ := json.Marshal(m)
   302  	j = append(j, '\n')
   303  	ln.writec <- j
   304  }
   305  
   306  func (ln *Listener) grabConn(path string) {
   307  	ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second)
   308  	defer cancel()
   309  	c, err := ln.dial(ctx)
   310  	if err != nil {
   311  		ln.sendMessage(controlMsg{Command: "pickup-failed", ConnPath: path, Err: err.Error()})
   312  		return
   313  	}
   314  	failPickup := func(err error) {
   315  		c.Close()
   316  		log.Printf("revdial.Listener: failed to pick up connection to %s: %v", path, err)
   317  		ln.sendMessage(controlMsg{Command: "pickup-failed", ConnPath: path, Err: err.Error()})
   318  	}
   319  	bufr := bufio.NewReader(c)
   320  
   321  	success := false
   322  	const maxRedirects = 2
   323  	for i := 0; i < maxRedirects; i++ {
   324  		req, _ := http.NewRequest("GET", path, nil)
   325  		if err := req.Write(c); err != nil {
   326  			failPickup(err)
   327  			return
   328  		}
   329  		path, err = ReadProtoSwitchOrRedirect(bufr, req)
   330  		if err != nil {
   331  			failPickup(fmt.Errorf("switch failed: %v", err))
   332  			return
   333  		}
   334  		if path == "" {
   335  			success = true
   336  			break
   337  		}
   338  	}
   339  	if !success {
   340  		failPickup(errors.New("too many redirects"))
   341  		return
   342  	}
   343  
   344  	select {
   345  	case ln.connc <- c:
   346  	case <-ln.donec:
   347  	}
   348  }
   349  
   350  // Closed reports whether the listener has been closed.
   351  func (ln *Listener) Closed() bool {
   352  	ln.mu.Lock()
   353  	defer ln.mu.Unlock()
   354  	return ln.closed
   355  }
   356  
   357  // Accept blocks and returns a new connection, or an error.
   358  func (ln *Listener) Accept() (net.Conn, error) {
   359  	c, ok := <-ln.connc
   360  	if !ok {
   361  		ln.mu.Lock()
   362  		err, closed := ln.readErr, ln.closed
   363  		ln.mu.Unlock()
   364  		if err != nil && !closed {
   365  			return nil, fmt.Errorf("revdial: Listener closed; %v", err)
   366  		}
   367  		return nil, ErrListenerClosed
   368  	}
   369  	return c, nil
   370  }
   371  
   372  // ErrListenerClosed is returned by Accept after Close has been called.
   373  var ErrListenerClosed = errors.New("revdial: Listener closed")
   374  
   375  // Close closes the Listener, making future Accept calls return an
   376  // error.
   377  func (ln *Listener) Close() error {
   378  	ln.mu.Lock()
   379  	defer ln.mu.Unlock()
   380  	if ln.closed {
   381  		return nil
   382  	}
   383  	go ln.sc.Close()
   384  	ln.closed = true
   385  	close(ln.connc)
   386  	close(ln.donec)
   387  	return nil
   388  }
   389  
   390  // Addr returns a dummy address. This exists only to conform to the
   391  // net.Listener interface.
   392  func (ln *Listener) Addr() net.Addr { return fakeAddr{} }
   393  
   // fakeAddr is the placeholder net.Addr reported by Listener.Addr.
type fakeAddr struct{}

// Network implements net.Addr.
func (fakeAddr) Network() string {
	return "revdial"
}

// String implements net.Addr.
func (fakeAddr) String() string {
	return "revdialconn"
}
   398  
   399  // ConnHandler returns the HTTP handler that needs to be mounted somewhere
   400  // that the Listeners can dial out and get to. A dialer to connect to it
   401  // is given to NewListener and the path to reach it is given to NewDialer
   402  // to use in messages to the listener.
   403  func ConnHandler() http.Handler {
   404  	return http.HandlerFunc(connHandler)
   405  }
   406  
   407  func connHandler(w http.ResponseWriter, r *http.Request) {
   408  	if r.TLS == nil {
   409  		http.Error(w, "handler requires TLS", http.StatusInternalServerError)
   410  		return
   411  	}
   412  	if r.Method != "GET" {
   413  		w.Header().Set("Allow", "GET")
   414  		http.Error(w, "expected GET request to revdial conn handler", http.StatusMethodNotAllowed)
   415  		return
   416  	}
   417  	dialerUniq := r.FormValue(dialerUniqParam)
   418  
   419  	dmapMu.Lock()
   420  	d, ok := dialers[dialerUniq]
   421  	dmapMu.Unlock()
   422  	if !ok {
   423  		http.Error(w, "unknown dialer", http.StatusBadRequest)
   424  		return
   425  	}
   426  
   427  	conn, _, err := w.(http.Hijacker).Hijack()
   428  	if err != nil {
   429  		http.Error(w, err.Error(), http.StatusInternalServerError)
   430  		return
   431  	}
   432  	(&http.Response{StatusCode: http.StatusSwitchingProtocols, Proto: "HTTP/1.1"}).Write(conn)
   433  	d.matchConn(conn)
   434  }
   435  
   // checkRelativeURL returns an error if s does not parse as a URL or if it
// carries a scheme or host. Following such a URL could change the origin.
// A nil result means s is a relative reference such as
// "/revdial?revdial.dialer=ab".
func checkRelativeURL(s string) error {
	u, err := url.Parse(s)
	if err != nil {
		return err
	}

	// A relative URL has neither a scheme nor a host.
	if u.Scheme != "" {
		return fmt.Errorf("URL %q is not relative: contains scheme", s)
	}
	if u.Host != "" {
		return fmt.Errorf("URL %q is not relative: contains host", s)
	}
	return nil
}

// ReadProtoSwitchOrRedirect is a helper for completing revdial protocol
// switch requests. It reads one HTTP response to req from r:
//   - 101 Switching Protocols: it returns ("", nil) and leaves any bytes
//     after the response headers unread in r.
//   - 307 Temporary Redirect: it consumes the body and returns the
//     Location, which must be relative, so the caller can retry on the
//     same connection.
//   - Any other status: it returns an error.
//
// A 307 without a Location also returns an error. Errors from reading the
// response are wrapped, so errors.Is/As see the underlying cause.
func ReadProtoSwitchOrRedirect(r *bufio.Reader, req *http.Request) (location string, err error) {
	// maxErrBody caps how much of an unexpected response body is copied into
	// the returned error. The connection is abandoned then, so the rest of
	// the body does not need to be consumed.
	const maxErrBody = 4 << 10

	resp, err := http.ReadResponse(r, req)
	if err != nil {
		return "", fmt.Errorf("error reading response: %w", err)
	}
	switch resp.StatusCode {
	case http.StatusSwitchingProtocols:
		// Success! Leave r alone: anything after the headers already belongs
		// to the switched protocol, and the caller may want it.
		return "", nil
	case http.StatusTemporaryRedirect:
		// The caller sends its next request on this same connection, so the
		// redirect body must be consumed in full to keep the stream aligned.
		msg, readErr := io.ReadAll(resp.Body)
		if readErr != nil {
			return "", fmt.Errorf("error reading redirect body: %w", readErr)
		}
		loc := resp.Header.Get("Location")
		if loc == "" {
			return "", fmt.Errorf("redirect missing Location header; got %+v:\n\t%s", resp, msg)
		}
		if err := checkRelativeURL(loc); err != nil {
			return "", fmt.Errorf("redirect Location must be relative: %w", err)
		}
		// Retry at new location.
		return loc, nil
	default:
		msg, _ := io.ReadAll(io.LimitReader(resp.Body, maxErrBody))
		return "", fmt.Errorf("want HTTP status 101 or 307; got %v:\n\t%s", resp.Status, msg)
	}
}