github.com/Andyfoo/golang/x/net@v0.0.0-20190901054642-57c1bf301704/internal/sockstest/server.go (about)

     1  // Copyright 2018 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // Package sockstest provides utilities for SOCKS testing.
     6  package sockstest
     7  
     8  import (
     9  	"errors"
    10  	"io"
    11  	"net"
    12  
    13  	"github.com/Andyfoo/golang/x/net/internal/socks"
    14  	"github.com/Andyfoo/golang/x/net/nettest"
    15  )
    16  
    17  // An AuthRequest represents an authentication request.
    18  type AuthRequest struct {
    19  	Version int
    20  	Methods []socks.AuthMethod
    21  }
    22  
    23  // ParseAuthRequest parses an authentication request.
    24  func ParseAuthRequest(b []byte) (*AuthRequest, error) {
    25  	if len(b) < 2 {
    26  		return nil, errors.New("short auth request")
    27  	}
    28  	if b[0] != socks.Version5 {
    29  		return nil, errors.New("unexpected protocol version")
    30  	}
    31  	if len(b)-2 < int(b[1]) {
    32  		return nil, errors.New("short auth request")
    33  	}
    34  	req := &AuthRequest{Version: int(b[0])}
    35  	if b[1] > 0 {
    36  		req.Methods = make([]socks.AuthMethod, b[1])
    37  		for i, m := range b[2 : 2+b[1]] {
    38  			req.Methods[i] = socks.AuthMethod(m)
    39  		}
    40  	}
    41  	return req, nil
    42  }
    43  
    44  // MarshalAuthReply returns an authentication reply in wire format.
    45  func MarshalAuthReply(ver int, m socks.AuthMethod) ([]byte, error) {
    46  	return []byte{byte(ver), byte(m)}, nil
    47  }
    48  
    49  // A CmdRequest repesents a command request.
    50  type CmdRequest struct {
    51  	Version int
    52  	Cmd     socks.Command
    53  	Addr    socks.Addr
    54  }
    55  
    56  // ParseCmdRequest parses a command request.
    57  func ParseCmdRequest(b []byte) (*CmdRequest, error) {
    58  	if len(b) < 7 {
    59  		return nil, errors.New("short cmd request")
    60  	}
    61  	if b[0] != socks.Version5 {
    62  		return nil, errors.New("unexpected protocol version")
    63  	}
    64  	if socks.Command(b[1]) != socks.CmdConnect {
    65  		return nil, errors.New("unexpected command")
    66  	}
    67  	if b[2] != 0 {
    68  		return nil, errors.New("non-zero reserved field")
    69  	}
    70  	req := &CmdRequest{Version: int(b[0]), Cmd: socks.Command(b[1])}
    71  	l := 2
    72  	off := 4
    73  	switch b[3] {
    74  	case socks.AddrTypeIPv4:
    75  		l += net.IPv4len
    76  		req.Addr.IP = make(net.IP, net.IPv4len)
    77  	case socks.AddrTypeIPv6:
    78  		l += net.IPv6len
    79  		req.Addr.IP = make(net.IP, net.IPv6len)
    80  	case socks.AddrTypeFQDN:
    81  		l += int(b[4])
    82  		off = 5
    83  	default:
    84  		return nil, errors.New("unknown address type")
    85  	}
    86  	if len(b[off:]) < l {
    87  		return nil, errors.New("short cmd request")
    88  	}
    89  	if req.Addr.IP != nil {
    90  		copy(req.Addr.IP, b[off:])
    91  	} else {
    92  		req.Addr.Name = string(b[off : off+l-2])
    93  	}
    94  	req.Addr.Port = int(b[off+l-2])<<8 | int(b[off+l-1])
    95  	return req, nil
    96  }
    97  
    98  // MarshalCmdReply returns a command reply in wire format.
    99  func MarshalCmdReply(ver int, reply socks.Reply, a *socks.Addr) ([]byte, error) {
   100  	b := make([]byte, 4)
   101  	b[0] = byte(ver)
   102  	b[1] = byte(reply)
   103  	if a.Name != "" {
   104  		if len(a.Name) > 255 {
   105  			return nil, errors.New("fqdn too long")
   106  		}
   107  		b[3] = socks.AddrTypeFQDN
   108  		b = append(b, byte(len(a.Name)))
   109  		b = append(b, a.Name...)
   110  	} else if ip4 := a.IP.To4(); ip4 != nil {
   111  		b[3] = socks.AddrTypeIPv4
   112  		b = append(b, ip4...)
   113  	} else if ip6 := a.IP.To16(); ip6 != nil {
   114  		b[3] = socks.AddrTypeIPv6
   115  		b = append(b, ip6...)
   116  	} else {
   117  		return nil, errors.New("unknown address type")
   118  	}
   119  	b = append(b, byte(a.Port>>8), byte(a.Port))
   120  	return b, nil
   121  }
   122  
   123  // A Server repesents a server for handshake testing.
   124  type Server struct {
   125  	ln net.Listener
   126  }
   127  
   128  // Addr rerurns a server address.
   129  func (s *Server) Addr() net.Addr {
   130  	return s.ln.Addr()
   131  }
   132  
   133  // TargetAddr returns a fake final destination address.
   134  //
   135  // The returned address is only valid for testing with Server.
   136  func (s *Server) TargetAddr() net.Addr {
   137  	a := s.ln.Addr()
   138  	switch a := a.(type) {
   139  	case *net.TCPAddr:
   140  		if a.IP.To4() != nil {
   141  			return &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 5963}
   142  		}
   143  		if a.IP.To16() != nil && a.IP.To4() == nil {
   144  			return &net.TCPAddr{IP: net.IPv6loopback, Port: 5963}
   145  		}
   146  	}
   147  	return nil
   148  }
   149  
   150  // Close closes the server.
   151  func (s *Server) Close() error {
   152  	return s.ln.Close()
   153  }
   154  
   155  func (s *Server) serve(authFunc, cmdFunc func(io.ReadWriter, []byte) error) {
   156  	c, err := s.ln.Accept()
   157  	if err != nil {
   158  		return
   159  	}
   160  	defer c.Close()
   161  	go s.serve(authFunc, cmdFunc)
   162  	b := make([]byte, 512)
   163  	n, err := c.Read(b)
   164  	if err != nil {
   165  		return
   166  	}
   167  	if err := authFunc(c, b[:n]); err != nil {
   168  		return
   169  	}
   170  	n, err = c.Read(b)
   171  	if err != nil {
   172  		return
   173  	}
   174  	if err := cmdFunc(c, b[:n]); err != nil {
   175  		return
   176  	}
   177  }
   178  
   179  // NewServer returns a new server.
   180  //
   181  // The provided authFunc and cmdFunc must parse requests and return
   182  // appropriate replies to clients.
   183  func NewServer(authFunc, cmdFunc func(io.ReadWriter, []byte) error) (*Server, error) {
   184  	var err error
   185  	s := new(Server)
   186  	s.ln, err = nettest.NewLocalListener("tcp")
   187  	if err != nil {
   188  		return nil, err
   189  	}
   190  	go s.serve(authFunc, cmdFunc)
   191  	return s, nil
   192  }
   193  
   194  // NoAuthRequired handles a no-authentication-required signaling.
   195  func NoAuthRequired(rw io.ReadWriter, b []byte) error {
   196  	req, err := ParseAuthRequest(b)
   197  	if err != nil {
   198  		return err
   199  	}
   200  	b, err = MarshalAuthReply(req.Version, socks.AuthMethodNotRequired)
   201  	if err != nil {
   202  		return err
   203  	}
   204  	n, err := rw.Write(b)
   205  	if err != nil {
   206  		return err
   207  	}
   208  	if n != len(b) {
   209  		return errors.New("short write")
   210  	}
   211  	return nil
   212  }
   213  
   214  // NoProxyRequired handles a command signaling without constructing a
   215  // proxy connection to the final destination.
   216  func NoProxyRequired(rw io.ReadWriter, b []byte) error {
   217  	req, err := ParseCmdRequest(b)
   218  	if err != nil {
   219  		return err
   220  	}
   221  	req.Addr.Port += 1
   222  	if req.Addr.Name != "" {
   223  		req.Addr.Name = "boundaddr.doesnotexist"
   224  	} else if req.Addr.IP.To4() != nil {
   225  		req.Addr.IP = net.IPv4(127, 0, 0, 1)
   226  	} else {
   227  		req.Addr.IP = net.IPv6loopback
   228  	}
   229  	b, err = MarshalCmdReply(socks.Version5, socks.StatusSucceeded, &req.Addr)
   230  	if err != nil {
   231  		return err
   232  	}
   233  	n, err := rw.Write(b)
   234  	if err != nil {
   235  		return err
   236  	}
   237  	if n != len(b) {
   238  		return errors.New("short write")
   239  	}
   240  	return nil
   241  }