golang.org/x/net@v0.25.1-0.20240516223405-c87a5b62e243/internal/socks/client.go (about)

     1  // Copyright 2018 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package socks
     6  
     7  import (
     8  	"context"
     9  	"errors"
    10  	"io"
    11  	"net"
    12  	"strconv"
    13  	"time"
    14  )
    15  
    // noDeadline is the zero time.Time. connect hands it to SetDeadline to
// remove the handshake deadline once the handshake is over (net.Conn treats
// a zero deadline as "no timeout").
var noDeadline time.Time

// aLongTimeAgo is a non-zero instant already in the past (one second after
// the Unix epoch). connect's context watcher installs it as the connection
// deadline so that in-flight reads and writes fail immediately.
var aLongTimeAgo = time.Unix(1, 0)
    20  
    21  func (d *Dialer) connect(ctx context.Context, c net.Conn, address string) (_ net.Addr, ctxErr error) {
    22  	host, port, err := splitHostPort(address)
    23  	if err != nil {
    24  		return nil, err
    25  	}
    26  	if deadline, ok := ctx.Deadline(); ok && !deadline.IsZero() {
    27  		c.SetDeadline(deadline)
    28  		defer c.SetDeadline(noDeadline)
    29  	}
    30  	if ctx != context.Background() {
    31  		errCh := make(chan error, 1)
    32  		done := make(chan struct{})
    33  		defer func() {
    34  			close(done)
    35  			if ctxErr == nil {
    36  				ctxErr = <-errCh
    37  			}
    38  		}()
    39  		go func() {
    40  			select {
    41  			case <-ctx.Done():
    42  				c.SetDeadline(aLongTimeAgo)
    43  				errCh <- ctx.Err()
    44  			case <-done:
    45  				errCh <- nil
    46  			}
    47  		}()
    48  	}
    49  
    50  	b := make([]byte, 0, 6+len(host)) // the size here is just an estimate
    51  	b = append(b, Version5)
    52  	if len(d.AuthMethods) == 0 || d.Authenticate == nil {
    53  		b = append(b, 1, byte(AuthMethodNotRequired))
    54  	} else {
    55  		ams := d.AuthMethods
    56  		if len(ams) > 255 {
    57  			return nil, errors.New("too many authentication methods")
    58  		}
    59  		b = append(b, byte(len(ams)))
    60  		for _, am := range ams {
    61  			b = append(b, byte(am))
    62  		}
    63  	}
    64  	if _, ctxErr = c.Write(b); ctxErr != nil {
    65  		return
    66  	}
    67  
    68  	if _, ctxErr = io.ReadFull(c, b[:2]); ctxErr != nil {
    69  		return
    70  	}
    71  	if b[0] != Version5 {
    72  		return nil, errors.New("unexpected protocol version " + strconv.Itoa(int(b[0])))
    73  	}
    74  	am := AuthMethod(b[1])
    75  	if am == AuthMethodNoAcceptableMethods {
    76  		return nil, errors.New("no acceptable authentication methods")
    77  	}
    78  	if d.Authenticate != nil {
    79  		if ctxErr = d.Authenticate(ctx, c, am); ctxErr != nil {
    80  			return
    81  		}
    82  	}
    83  
    84  	b = b[:0]
    85  	b = append(b, Version5, byte(d.cmd), 0)
    86  	if ip := net.ParseIP(host); ip != nil {
    87  		if ip4 := ip.To4(); ip4 != nil {
    88  			b = append(b, AddrTypeIPv4)
    89  			b = append(b, ip4...)
    90  		} else if ip6 := ip.To16(); ip6 != nil {
    91  			b = append(b, AddrTypeIPv6)
    92  			b = append(b, ip6...)
    93  		} else {
    94  			return nil, errors.New("unknown address type")
    95  		}
    96  	} else {
    97  		if len(host) > 255 {
    98  			return nil, errors.New("FQDN too long")
    99  		}
   100  		b = append(b, AddrTypeFQDN)
   101  		b = append(b, byte(len(host)))
   102  		b = append(b, host...)
   103  	}
   104  	b = append(b, byte(port>>8), byte(port))
   105  	if _, ctxErr = c.Write(b); ctxErr != nil {
   106  		return
   107  	}
   108  
   109  	if _, ctxErr = io.ReadFull(c, b[:4]); ctxErr != nil {
   110  		return
   111  	}
   112  	if b[0] != Version5 {
   113  		return nil, errors.New("unexpected protocol version " + strconv.Itoa(int(b[0])))
   114  	}
   115  	if cmdErr := Reply(b[1]); cmdErr != StatusSucceeded {
   116  		return nil, errors.New("unknown error " + cmdErr.String())
   117  	}
   118  	if b[2] != 0 {
   119  		return nil, errors.New("non-zero reserved field")
   120  	}
   121  	l := 2
   122  	var a Addr
   123  	switch b[3] {
   124  	case AddrTypeIPv4:
   125  		l += net.IPv4len
   126  		a.IP = make(net.IP, net.IPv4len)
   127  	case AddrTypeIPv6:
   128  		l += net.IPv6len
   129  		a.IP = make(net.IP, net.IPv6len)
   130  	case AddrTypeFQDN:
   131  		if _, err := io.ReadFull(c, b[:1]); err != nil {
   132  			return nil, err
   133  		}
   134  		l += int(b[0])
   135  	default:
   136  		return nil, errors.New("unknown address type " + strconv.Itoa(int(b[3])))
   137  	}
   138  	if cap(b) < l {
   139  		b = make([]byte, l)
   140  	} else {
   141  		b = b[:l]
   142  	}
   143  	if _, ctxErr = io.ReadFull(c, b); ctxErr != nil {
   144  		return
   145  	}
   146  	if a.IP != nil {
   147  		copy(a.IP, b)
   148  	} else {
   149  		a.Name = string(b[:len(b)-2])
   150  	}
   151  	a.Port = int(b[len(b)-2])<<8 | int(b[len(b)-1])
   152  	return &a, nil
   153  }
   154  
   // splitHostPort breaks address into its host part and a numeric port.
// address must be in a form accepted by net.SplitHostPort. The port must be
// a decimal integer (as parsed by strconv.Atoi) in the range 1 to 65535.
// On any failure it returns "", 0 and the error.
func splitHostPort(address string) (string, int, error) {
	h, p, err := net.SplitHostPort(address)
	if err != nil {
		return "", 0, err
	}
	n, err := strconv.Atoi(p)
	switch {
	case err != nil:
		return "", 0, err
	case n < 1 || n > 0xffff:
		return "", 0, errors.New("port number out of range " + p)
	default:
		return h, n, nil
	}
}