golang.org/x/net@v0.25.1-0.20240516223405-c87a5b62e243/internal/sockstest/server.go (about) 1 // Copyright 2018 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 // Package sockstest provides utilities for SOCKS testing. 6 package sockstest 7 8 import ( 9 "errors" 10 "io" 11 "net" 12 13 "golang.org/x/net/internal/socks" 14 "golang.org/x/net/nettest" 15 ) 16 17 // An AuthRequest represents an authentication request. 18 type AuthRequest struct { 19 Version int 20 Methods []socks.AuthMethod 21 } 22 23 // ParseAuthRequest parses an authentication request. 24 func ParseAuthRequest(b []byte) (*AuthRequest, error) { 25 if len(b) < 2 { 26 return nil, errors.New("short auth request") 27 } 28 if b[0] != socks.Version5 { 29 return nil, errors.New("unexpected protocol version") 30 } 31 if len(b)-2 < int(b[1]) { 32 return nil, errors.New("short auth request") 33 } 34 req := &AuthRequest{Version: int(b[0])} 35 if b[1] > 0 { 36 req.Methods = make([]socks.AuthMethod, b[1]) 37 for i, m := range b[2 : 2+b[1]] { 38 req.Methods[i] = socks.AuthMethod(m) 39 } 40 } 41 return req, nil 42 } 43 44 // MarshalAuthReply returns an authentication reply in wire format. 45 func MarshalAuthReply(ver int, m socks.AuthMethod) ([]byte, error) { 46 return []byte{byte(ver), byte(m)}, nil 47 } 48 49 // A CmdRequest represents a command request. 50 type CmdRequest struct { 51 Version int 52 Cmd socks.Command 53 Addr socks.Addr 54 } 55 56 // ParseCmdRequest parses a command request. 57 func ParseCmdRequest(b []byte) (*CmdRequest, error) { 58 if len(b) < 7 { 59 return nil, errors.New("short cmd request") 60 } 61 if b[0] != socks.Version5 { 62 return nil, errors.New("unexpected protocol version") 63 } 64 if socks.Command(b[1]) != socks.CmdConnect { 65 return nil, errors.New("unexpected command") 66 } 67 if b[2] != 0 { 68 return nil, errors.New("non-zero reserved field") 69 } 70 req := &CmdRequest{Version: int(b[0]), Cmd: socks.Command(b[1])} 71 l := 2 72 off := 4 73 switch b[3] { 74 case socks.AddrTypeIPv4: 75 l += net.IPv4len 76 req.Addr.IP = make(net.IP, net.IPv4len) 77 case socks.AddrTypeIPv6: 78 l += net.IPv6len 79 req.Addr.IP = make(net.IP, net.IPv6len) 80 case socks.AddrTypeFQDN: 81 l += int(b[4]) 82 off = 5 83 default: 84 return nil, errors.New("unknown address type") 85 } 86 if len(b[off:]) < l { 87 return nil, errors.New("short cmd request") 88 } 89 if req.Addr.IP != nil { 90 copy(req.Addr.IP, b[off:]) 91 } else { 92 req.Addr.Name = string(b[off : off+l-2]) 93 } 94 req.Addr.Port = int(b[off+l-2])<<8 | int(b[off+l-1]) 95 return req, nil 96 } 97 98 // MarshalCmdReply returns a command reply in wire format. 99 func MarshalCmdReply(ver int, reply socks.Reply, a *socks.Addr) ([]byte, error) { 100 b := make([]byte, 4) 101 b[0] = byte(ver) 102 b[1] = byte(reply) 103 if a.Name != "" { 104 if len(a.Name) > 255 { 105 return nil, errors.New("fqdn too long") 106 } 107 b[3] = socks.AddrTypeFQDN 108 b = append(b, byte(len(a.Name))) 109 b = append(b, a.Name...) 110 } else if ip4 := a.IP.To4(); ip4 != nil { 111 b[3] = socks.AddrTypeIPv4 112 b = append(b, ip4...) 113 } else if ip6 := a.IP.To16(); ip6 != nil { 114 b[3] = socks.AddrTypeIPv6 115 b = append(b, ip6...) 116 } else { 117 return nil, errors.New("unknown address type") 118 } 119 b = append(b, byte(a.Port>>8), byte(a.Port)) 120 return b, nil 121 } 122 123 // A Server represents a server for handshake testing. 124 type Server struct { 125 ln net.Listener 126 } 127 128 // Addr returns a server address. 129 func (s *Server) Addr() net.Addr { 130 return s.ln.Addr() 131 } 132 133 // TargetAddr returns a fake final destination address. 134 // 135 // The returned address is only valid for testing with Server. 136 func (s *Server) TargetAddr() net.Addr { 137 a := s.ln.Addr() 138 switch a := a.(type) { 139 case *net.TCPAddr: 140 if a.IP.To4() != nil { 141 return &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 5963} 142 } 143 if a.IP.To16() != nil && a.IP.To4() == nil { 144 return &net.TCPAddr{IP: net.IPv6loopback, Port: 5963} 145 } 146 } 147 return nil 148 } 149 150 // Close closes the server. 151 func (s *Server) Close() error { 152 return s.ln.Close() 153 } 154 155 func (s *Server) serve(authFunc, cmdFunc func(io.ReadWriter, []byte) error) { 156 c, err := s.ln.Accept() 157 if err != nil { 158 return 159 } 160 defer c.Close() 161 go s.serve(authFunc, cmdFunc) 162 b := make([]byte, 512) 163 n, err := c.Read(b) 164 if err != nil { 165 return 166 } 167 if err := authFunc(c, b[:n]); err != nil { 168 return 169 } 170 n, err = c.Read(b) 171 if err != nil { 172 return 173 } 174 if err := cmdFunc(c, b[:n]); err != nil { 175 return 176 } 177 } 178 179 // NewServer returns a new server. 180 // 181 // The provided authFunc and cmdFunc must parse requests and return 182 // appropriate replies to clients. 183 func NewServer(authFunc, cmdFunc func(io.ReadWriter, []byte) error) (*Server, error) { 184 var err error 185 s := new(Server) 186 s.ln, err = nettest.NewLocalListener("tcp") 187 if err != nil { 188 return nil, err 189 } 190 go s.serve(authFunc, cmdFunc) 191 return s, nil 192 } 193 194 // NoAuthRequired handles a no-authentication-required signaling. 195 func NoAuthRequired(rw io.ReadWriter, b []byte) error { 196 req, err := ParseAuthRequest(b) 197 if err != nil { 198 return err 199 } 200 b, err = MarshalAuthReply(req.Version, socks.AuthMethodNotRequired) 201 if err != nil { 202 return err 203 } 204 n, err := rw.Write(b) 205 if err != nil { 206 return err 207 } 208 if n != len(b) { 209 return errors.New("short write") 210 } 211 return nil 212 } 213 214 // NoProxyRequired handles a command signaling without constructing a 215 // proxy connection to the final destination. 216 func NoProxyRequired(rw io.ReadWriter, b []byte) error { 217 req, err := ParseCmdRequest(b) 218 if err != nil { 219 return err 220 } 221 req.Addr.Port += 1 222 if req.Addr.Name != "" { 223 req.Addr.Name = "boundaddr.doesnotexist" 224 } else if req.Addr.IP.To4() != nil { 225 req.Addr.IP = net.IPv4(127, 0, 0, 1) 226 } else { 227 req.Addr.IP = net.IPv6loopback 228 } 229 b, err = MarshalCmdReply(socks.Version5, socks.StatusSucceeded, &req.Addr) 230 if err != nil { 231 return err 232 } 233 n, err := rw.Write(b) 234 if err != nil { 235 return err 236 } 237 if n != len(b) { 238 return errors.New("short write") 239 } 240 return nil 241 }