github.com/sagernet/sing@v0.4.0-beta.19.0.20240518125136-f67a0988a636/protocol/socks/socks4/protocol.go (about)

     1  package socks4
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/binary"
     6  	"io"
     7  	"net/netip"
     8  
     9  	"github.com/sagernet/sing/common"
    10  	"github.com/sagernet/sing/common/buf"
    11  	E "github.com/sagernet/sing/common/exceptions"
    12  	M "github.com/sagernet/sing/common/metadata"
    13  	"github.com/sagernet/sing/common/rw"
    14  )
    15  
    16  const (
    17  	Version byte = 4
    18  
    19  	CommandConnect byte = 1
    20  	CommandBind    byte = 2
    21  
    22  	ReplyCodeGranted                     byte = 90
    23  	ReplyCodeRejectedOrFailed            byte = 91
    24  	ReplyCodeCannotConnectToIdentd       byte = 92
    25  	ReplyCodeIdentdReportDifferentUserID byte = 93
    26  )
    27  
// Request is a SOCKS4 / SOCKS4a client request.
type Request struct {
	// Command is the CD byte of the request (e.g. CommandConnect, CommandBind).
	Command     byte
	// Destination is the target port plus either an IPv4 address (DSTIP) or,
	// for SOCKS4a requests, the address parsed from the trailing hostname field.
	Destination M.Socksaddr
	// Username is the NUL-terminated USERID field; it may be empty.
	Username    string
}
    33  
    34  func ReadRequest(reader io.Reader) (request Request, err error) {
    35  	version, err := rw.ReadByte(reader)
    36  	if err != nil {
    37  		return
    38  	}
    39  	if version != 4 {
    40  		err = E.New("excepted socks version 4, got ", version)
    41  		return
    42  	}
    43  	return ReadRequest0(reader)
    44  }
    45  
    46  func ReadRequest0(reader io.Reader) (request Request, err error) {
    47  	request.Command, err = rw.ReadByte(reader)
    48  	if err != nil {
    49  		return
    50  	}
    51  	err = binary.Read(reader, binary.BigEndian, &request.Destination.Port)
    52  	if err != nil {
    53  		return
    54  	}
    55  	var dstIP [4]byte
    56  	_, err = io.ReadFull(reader, dstIP[:])
    57  	if err != nil {
    58  		return
    59  	}
    60  	var readHostName bool
    61  	if dstIP[0] == 0 && dstIP[1] == 0 && dstIP[2] == 0 && dstIP[3] != 0 {
    62  		readHostName = true
    63  	} else {
    64  		request.Destination.Addr = netip.AddrFrom4(dstIP)
    65  	}
    66  	request.Username, err = readString(reader)
    67  	if readHostName {
    68  		request.Destination.Fqdn, err = readString(reader)
    69  		request.Destination = M.ParseSocksaddrHostPort(request.Destination.Fqdn, request.Destination.Port)
    70  	}
    71  	return
    72  }
    73  
    74  func WriteRequest(writer io.Writer, request Request) error {
    75  	var requestLen int
    76  	requestLen += 1 // version
    77  	requestLen += 1 // command
    78  	requestLen += 2 // port
    79  	requestLen += 4 // ip
    80  	requestLen += 1 // NUL
    81  	if !request.Destination.IsIPv4() {
    82  		requestLen += len(request.Destination.AddrString()) + 1
    83  	}
    84  	if request.Username != "" {
    85  		requestLen += len(request.Username)
    86  	}
    87  
    88  	buffer := buf.NewSize(requestLen)
    89  	defer buffer.Release()
    90  
    91  	common.Must(
    92  		buffer.WriteByte(Version),
    93  		buffer.WriteByte(request.Command),
    94  		binary.Write(buffer, binary.BigEndian, request.Destination.Port),
    95  	)
    96  	if request.Destination.IsIPv4() {
    97  		common.Must1(buffer.Write(request.Destination.Addr.AsSlice()))
    98  	} else {
    99  		// 0.0.0.X
   100  		common.Must(buffer.WriteZeroN(3))
   101  		common.Must(buffer.WriteByte(1))
   102  	}
   103  	if request.Username != "" {
   104  		common.Must1(buffer.WriteString(request.Username))
   105  	}
   106  	common.Must(buffer.WriteZero())
   107  	if !request.Destination.IsIPv4() {
   108  		common.Must1(buffer.WriteString(request.Destination.AddrString()))
   109  		common.Must(buffer.WriteZero())
   110  	}
   111  	return rw.WriteBytes(writer, buffer.Bytes())
   112  }
   113  
// Response is a SOCKS4 reply: a reply code plus the DSTPORT/DSTIP fields.
type Response struct {
	// ReplyCode is the CD byte of the reply (e.g. ReplyCodeGranted).
	ReplyCode   byte
	// Destination carries DSTPORT and DSTIP; DSTIP is a 4-byte (IPv4) field
	// on the wire.
	Destination M.Socksaddr
}
   118  
   119  func ReadResponse(reader io.Reader) (response Response, err error) {
   120  	version, err := rw.ReadByte(reader)
   121  	if err != nil {
   122  		return
   123  	}
   124  	if version != 0 {
   125  		err = E.New("excepted socks4 response version 0, got ", version)
   126  		return
   127  	}
   128  	response.ReplyCode, err = rw.ReadByte(reader)
   129  	if err != nil {
   130  		return
   131  	}
   132  	err = binary.Read(reader, binary.BigEndian, &response.Destination.Port)
   133  	if err != nil {
   134  		return
   135  	}
   136  	var dstIP [4]byte
   137  	_, err = io.ReadFull(reader, dstIP[:])
   138  	if err != nil {
   139  		return
   140  	}
   141  	response.Destination.Addr = netip.AddrFrom4(dstIP)
   142  	return
   143  }
   144  
   145  func WriteResponse(writer io.Writer, response Response) error {
   146  	buffer := buf.NewSize(8)
   147  	defer buffer.Release()
   148  	common.Must(
   149  		buffer.WriteByte(0),
   150  		buffer.WriteByte(response.ReplyCode),
   151  		binary.Write(buffer, binary.BigEndian, response.Destination.Port),
   152  		common.Error(buffer.Write(response.Destination.Addr.AsSlice())),
   153  	)
   154  	return rw.WriteBytes(writer, buffer.Bytes())
   155  }
   156  
   157  func readString(reader io.Reader) (string, error) {
   158  	buffer := bytes.Buffer{}
   159  	for {
   160  		b, err := rw.ReadByte(reader)
   161  		if err != nil {
   162  			return "", err
   163  		}
   164  		if b == 0 {
   165  			break
   166  		}
   167  		buffer.WriteByte(b)
   168  	}
   169  	return buffer.String(), nil
   170  }