github.com/psiphon-labs/psiphon-tunnel-core@v2.0.28+incompatible/psiphon/TCPConn.go (about)

     1  /*
     2   * Copyright (c) 2015, Psiphon Inc.
     3   * All rights reserved.
     4   *
     5   * This program is free software: you can redistribute it and/or modify
     6   * it under the terms of the GNU General Public License as published by
     7   * the Free Software Foundation, either version 3 of the License, or
     8   * (at your option) any later version.
     9   *
    10   * This program is distributed in the hope that it will be useful,
    11   * but WITHOUT ANY WARRANTY; without even the implied warranty of
    12   * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    13   * GNU General Public License for more details.
    14   *
    15   * You should have received a copy of the GNU General Public License
    16   * along with this program.  If not, see <http://www.gnu.org/licenses/>.
    17   *
    18   */
    19  
    20  package psiphon
    21  
    22  import (
    23  	"context"
    24  	std_errors "errors"
    25  	"net"
    26  	"sync/atomic"
    27  	"syscall"
    28  
    29  	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
    30  	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
    31  	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/fragmentor"
    32  	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/prng"
    33  	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/upstreamproxy"
    34  )
    35  
// TCPConn is a customized TCP connection that supports the Closer interface
// and which may be created using options in DialConfig, including
// UpstreamProxyURL, DeviceBinder, IPv6Synthesizer, and ResolvedIPCallback.
// DeviceBinder is implemented using SO_BINDTODEVICE/IP_BOUND_IF, which
// requires syscall-level socket code.
//
// TCPConn embeds the connected net.Conn, so net.Conn methods that TCPConn
// does not override are forwarded directly to it; TCPConn overrides Close to
// make it idempotent.
type TCPConn struct {
	net.Conn

	// isClosed is 0 while the conn is open and becomes 1 on the first call
	// to Close. It is read and written with sync/atomic operations (see
	// Close and IsClosed).
	isClosed int32
}
    45  
    46  // NewTCPDialer creates a TCP Dialer.
    47  //
    48  // Note: do not set an UpstreamProxyURL in the config when using NewTCPDialer
    49  // as a custom dialer for NewProxyAuthTransport (or http.Transport with a
    50  // ProxyUrl), as that would result in double proxy chaining.
    51  func NewTCPDialer(config *DialConfig) common.Dialer {
    52  
    53  	// Use config.CustomDialer when set. This ignores all other parameters in
    54  	// DialConfig.
    55  	if config.CustomDialer != nil {
    56  		return config.CustomDialer
    57  	}
    58  
    59  	return func(ctx context.Context, network, addr string) (net.Conn, error) {
    60  		if network != "tcp" {
    61  			return nil, errors.Tracef("%s unsupported", network)
    62  		}
    63  		return DialTCP(ctx, addr, config)
    64  	}
    65  }
    66  
    67  // DialTCP creates a new, connected TCPConn.
    68  func DialTCP(
    69  	ctx context.Context, addr string, config *DialConfig) (net.Conn, error) {
    70  
    71  	var conn net.Conn
    72  	var err error
    73  
    74  	if config.UpstreamProxyURL != "" {
    75  		conn, err = proxiedTcpDial(ctx, addr, config)
    76  	} else {
    77  		conn, err = tcpDial(ctx, addr, config)
    78  	}
    79  
    80  	if err != nil {
    81  		return nil, errors.Trace(err)
    82  	}
    83  
    84  	// Note: when an upstream proxy is used, we don't know what IP address
    85  	// was resolved, by the proxy, for that destination.
    86  	if config.ResolvedIPCallback != nil && config.UpstreamProxyURL == "" {
    87  		ipAddress := common.IPAddressFromAddr(conn.RemoteAddr())
    88  		if ipAddress != "" {
    89  			config.ResolvedIPCallback(ipAddress)
    90  		}
    91  	}
    92  
    93  	if config.FragmentorConfig.MayFragment() {
    94  		conn = fragmentor.NewConn(
    95  			config.FragmentorConfig,
    96  			func(message string) {
    97  				NoticeFragmentor(config.DiagnosticID, message)
    98  			},
    99  			conn)
   100  	}
   101  
   102  	return conn, nil
   103  }
   104  
   105  // proxiedTcpDial wraps a tcpDial call in an upstreamproxy dial.
   106  func proxiedTcpDial(
   107  	ctx context.Context, addr string, config *DialConfig) (net.Conn, error) {
   108  
   109  	interruptConns := common.NewConns()
   110  
   111  	// Note: using interruptConns to interrupt a proxy dial assumes
   112  	// that the underlying proxy code will immediately exit with an
   113  	// error when all underlying conns unexpectedly close; e.g.,
   114  	// the proxy handshake won't keep retrying to dial new conns.
   115  
   116  	dialer := func(network, addr string) (net.Conn, error) {
   117  		conn, err := tcpDial(ctx, addr, config)
   118  		if conn != nil {
   119  			if !interruptConns.Add(conn) {
   120  				err = std_errors.New("already interrupted")
   121  				conn.Close()
   122  				conn = nil
   123  			}
   124  		}
   125  		if err != nil {
   126  			return nil, errors.Trace(err)
   127  		}
   128  		return conn, nil
   129  	}
   130  
   131  	upstreamDialer := upstreamproxy.NewProxyDialFunc(
   132  		&upstreamproxy.UpstreamProxyConfig{
   133  			ForwardDialFunc: dialer,
   134  			ProxyURIString:  config.UpstreamProxyURL,
   135  			CustomHeaders:   config.CustomHeaders,
   136  		})
   137  
   138  	type upstreamDialResult struct {
   139  		conn net.Conn
   140  		err  error
   141  	}
   142  
   143  	resultChannel := make(chan upstreamDialResult)
   144  
   145  	go func() {
   146  		conn, err := upstreamDialer("tcp", addr)
   147  		if _, ok := err.(*upstreamproxy.Error); ok {
   148  			if config.UpstreamProxyErrorCallback != nil {
   149  				config.UpstreamProxyErrorCallback(err)
   150  			}
   151  		}
   152  		resultChannel <- upstreamDialResult{
   153  			conn: conn,
   154  			err:  err,
   155  		}
   156  	}()
   157  
   158  	var result upstreamDialResult
   159  
   160  	select {
   161  	case result = <-resultChannel:
   162  	case <-ctx.Done():
   163  		result.err = ctx.Err()
   164  		// Interrupt the goroutine
   165  		interruptConns.CloseAll()
   166  		<-resultChannel
   167  	}
   168  
   169  	if result.err != nil {
   170  		return nil, errors.Trace(result.err)
   171  	}
   172  
   173  	return result.conn, nil
   174  }
   175  
   176  // Close terminates a connected TCPConn or interrupts a dialing TCPConn.
   177  func (conn *TCPConn) Close() (err error) {
   178  
   179  	if !atomic.CompareAndSwapInt32(&conn.isClosed, 0, 1) {
   180  		return nil
   181  	}
   182  
   183  	return conn.Conn.Close()
   184  }
   185  
   186  // IsClosed implements the Closer iterface. The return value
   187  // indicates whether the TCPConn has been closed.
   188  func (conn *TCPConn) IsClosed() bool {
   189  	return atomic.LoadInt32(&conn.isClosed) == 1
   190  }
   191  
   192  // CloseWrite calls net.TCPConn.CloseWrite when the underlying
   193  // conn is a *net.TCPConn.
   194  func (conn *TCPConn) CloseWrite() (err error) {
   195  
   196  	if conn.IsClosed() {
   197  		return errors.TraceNew("already closed")
   198  	}
   199  
   200  	tcpConn, ok := conn.Conn.(*net.TCPConn)
   201  	if !ok {
   202  		return errors.TraceNew("conn is not a *net.TCPConn")
   203  	}
   204  
   205  	return tcpConn.CloseWrite()
   206  }
   207  
   208  func tcpDial(ctx context.Context, addr string, config *DialConfig) (net.Conn, error) {
   209  
   210  	// Get the remote IP and port, resolving a domain name if necessary
   211  	host, port, err := net.SplitHostPort(addr)
   212  	if err != nil {
   213  		return nil, errors.Trace(err)
   214  	}
   215  	if config.ResolveIP == nil {
   216  		// Fail even if we don't need a resolver for this dial: this is a code
   217  		// misconfiguration.
   218  		return nil, errors.TraceNew("missing resolver")
   219  	}
   220  	ipAddrs, err := config.ResolveIP(ctx, host)
   221  	if err != nil {
   222  		return nil, errors.Trace(err)
   223  	}
   224  	if len(ipAddrs) < 1 {
   225  		return nil, errors.TraceNew("no IP address")
   226  	}
   227  
   228  	// When configured, attempt to synthesize IPv6 addresses from
   229  	// an IPv4 addresses for compatibility on DNS64/NAT64 networks.
   230  	// If synthesize fails, try the original addresses.
   231  	if config.IPv6Synthesizer != nil {
   232  		for i, ipAddr := range ipAddrs {
   233  			if ipAddr.To4() != nil {
   234  				synthesizedIPAddress := config.IPv6Synthesizer.IPv6Synthesize(ipAddr.String())
   235  				if synthesizedIPAddress != "" {
   236  					synthesizedAddr := net.ParseIP(synthesizedIPAddress)
   237  					if synthesizedAddr != nil {
   238  						ipAddrs[i] = synthesizedAddr
   239  					}
   240  				}
   241  			}
   242  		}
   243  	}
   244  
   245  	// Iterate over a pseudorandom permutation of the destination
   246  	// IPs and attempt connections.
   247  	//
   248  	// Only continue retrying as long as the dial context is not
   249  	// done. Unlike net.Dial, we do not fractionalize the context
   250  	// deadline, as the dial is generally intended to apply to a
   251  	// single attempt. So these serial retries are most useful in
   252  	// cases of immediate failure, such as "no route to host"
   253  	// errors when a host resolves to both IPv4 and IPv6 but IPv6
   254  	// addresses are unreachable.
   255  	//
   256  	// Retries at higher levels cover other cases: e.g.,
   257  	// Controller.remoteServerListFetcher will retry its entire
   258  	// operation and tcpDial will try a new permutation; or similarly,
   259  	// Controller.establishCandidateGenerator will retry a candidate
   260  	// tunnel server dials.
   261  
   262  	permutedIndexes := prng.Perm(len(ipAddrs))
   263  
   264  	lastErr := errors.TraceNew("unknown error")
   265  
   266  	for _, index := range permutedIndexes {
   267  
   268  		dialer := &net.Dialer{
   269  			Control: func(_, _ string, c syscall.RawConn) error {
   270  				var controlErr error
   271  				err := c.Control(func(fd uintptr) {
   272  
   273  					socketFD := int(fd)
   274  
   275  					setAdditionalSocketOptions(socketFD)
   276  
   277  					if config.BPFProgramInstructions != nil {
   278  						err := setSocketBPF(config.BPFProgramInstructions, socketFD)
   279  						if err != nil {
   280  							controlErr = errors.Tracef("setSocketBPF failed: %s", err)
   281  							return
   282  						}
   283  					}
   284  
   285  					if config.DeviceBinder != nil {
   286  						_, err := config.DeviceBinder.BindToDevice(socketFD)
   287  						if err != nil {
   288  							controlErr = errors.Tracef("BindToDevice failed: %s", err)
   289  							return
   290  						}
   291  					}
   292  				})
   293  				if controlErr != nil {
   294  					return errors.Trace(controlErr)
   295  				}
   296  				return errors.Trace(err)
   297  			},
   298  		}
   299  
   300  		conn, err := dialer.DialContext(
   301  			ctx, "tcp", net.JoinHostPort(ipAddrs[index].String(), port))
   302  		if err != nil {
   303  			lastErr = errors.Trace(err)
   304  			continue
   305  		}
   306  
   307  		return &TCPConn{Conn: conn}, nil
   308  	}
   309  
   310  	return nil, lastErr
   311  }