golang.org/x/net@v0.25.1-0.20240516223405-c87a5b62e243/internal/socks/client.go (about) 1 // Copyright 2018 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package socks 6 7 import ( 8 "context" 9 "errors" 10 "io" 11 "net" 12 "strconv" 13 "time" 14 ) 15 16 var ( 17 noDeadline = time.Time{} 18 aLongTimeAgo = time.Unix(1, 0) 19 ) 20 21 func (d *Dialer) connect(ctx context.Context, c net.Conn, address string) (_ net.Addr, ctxErr error) { 22 host, port, err := splitHostPort(address) 23 if err != nil { 24 return nil, err 25 } 26 if deadline, ok := ctx.Deadline(); ok && !deadline.IsZero() { 27 c.SetDeadline(deadline) 28 defer c.SetDeadline(noDeadline) 29 } 30 if ctx != context.Background() { 31 errCh := make(chan error, 1) 32 done := make(chan struct{}) 33 defer func() { 34 close(done) 35 if ctxErr == nil { 36 ctxErr = <-errCh 37 } 38 }() 39 go func() { 40 select { 41 case <-ctx.Done(): 42 c.SetDeadline(aLongTimeAgo) 43 errCh <- ctx.Err() 44 case <-done: 45 errCh <- nil 46 } 47 }() 48 } 49 50 b := make([]byte, 0, 6+len(host)) // the size here is just an estimate 51 b = append(b, Version5) 52 if len(d.AuthMethods) == 0 || d.Authenticate == nil { 53 b = append(b, 1, byte(AuthMethodNotRequired)) 54 } else { 55 ams := d.AuthMethods 56 if len(ams) > 255 { 57 return nil, errors.New("too many authentication methods") 58 } 59 b = append(b, byte(len(ams))) 60 for _, am := range ams { 61 b = append(b, byte(am)) 62 } 63 } 64 if _, ctxErr = c.Write(b); ctxErr != nil { 65 return 66 } 67 68 if _, ctxErr = io.ReadFull(c, b[:2]); ctxErr != nil { 69 return 70 } 71 if b[0] != Version5 { 72 return nil, errors.New("unexpected protocol version " + strconv.Itoa(int(b[0]))) 73 } 74 am := AuthMethod(b[1]) 75 if am == AuthMethodNoAcceptableMethods { 76 return nil, errors.New("no acceptable authentication methods") 77 } 78 if d.Authenticate != nil { 79 if ctxErr = d.Authenticate(ctx, c, am); ctxErr != nil { 80 return 81 } 82 } 83 84 b = b[:0] 85 b = append(b, Version5, byte(d.cmd), 0) 86 if ip := net.ParseIP(host); ip != nil { 87 if ip4 := ip.To4(); ip4 != nil { 88 b = append(b, AddrTypeIPv4) 89 b = append(b, ip4...) 90 } else if ip6 := ip.To16(); ip6 != nil { 91 b = append(b, AddrTypeIPv6) 92 b = append(b, ip6...) 93 } else { 94 return nil, errors.New("unknown address type") 95 } 96 } else { 97 if len(host) > 255 { 98 return nil, errors.New("FQDN too long") 99 } 100 b = append(b, AddrTypeFQDN) 101 b = append(b, byte(len(host))) 102 b = append(b, host...) 103 } 104 b = append(b, byte(port>>8), byte(port)) 105 if _, ctxErr = c.Write(b); ctxErr != nil { 106 return 107 } 108 109 if _, ctxErr = io.ReadFull(c, b[:4]); ctxErr != nil { 110 return 111 } 112 if b[0] != Version5 { 113 return nil, errors.New("unexpected protocol version " + strconv.Itoa(int(b[0]))) 114 } 115 if cmdErr := Reply(b[1]); cmdErr != StatusSucceeded { 116 return nil, errors.New("unknown error " + cmdErr.String()) 117 } 118 if b[2] != 0 { 119 return nil, errors.New("non-zero reserved field") 120 } 121 l := 2 122 var a Addr 123 switch b[3] { 124 case AddrTypeIPv4: 125 l += net.IPv4len 126 a.IP = make(net.IP, net.IPv4len) 127 case AddrTypeIPv6: 128 l += net.IPv6len 129 a.IP = make(net.IP, net.IPv6len) 130 case AddrTypeFQDN: 131 if _, err := io.ReadFull(c, b[:1]); err != nil { 132 return nil, err 133 } 134 l += int(b[0]) 135 default: 136 return nil, errors.New("unknown address type " + strconv.Itoa(int(b[3]))) 137 } 138 if cap(b) < l { 139 b = make([]byte, l) 140 } else { 141 b = b[:l] 142 } 143 if _, ctxErr = io.ReadFull(c, b); ctxErr != nil { 144 return 145 } 146 if a.IP != nil { 147 copy(a.IP, b) 148 } else { 149 a.Name = string(b[:len(b)-2]) 150 } 151 a.Port = int(b[len(b)-2])<<8 | int(b[len(b)-1]) 152 return &a, nil 153 } 154 155 func splitHostPort(address string) (string, int, error) { 156 host, port, err := net.SplitHostPort(address) 157 if err != nil { 158 return "", 0, err 159 } 160 portnum, err := strconv.Atoi(port) 161 if err != nil { 162 return "", 0, err 163 } 164 if 1 > portnum || portnum > 0xffff { 165 return "", 0, errors.New("port number out of range " + port) 166 } 167 return host, portnum, nil 168 }