github.com/xraypb/xray-core@v1.6.6/transport/internet/headers/http/http.go (about)

     1  package http
     2  
     3  //go:generate go run github.com/xraypb/xray-core/common/errors/errorgen
     4  
     5  import (
     6  	"bufio"
     7  	"bytes"
     8  	"context"
     9  	"io"
    10  	"net"
    11  	"net/http"
    12  	"strings"
    13  	"time"
    14  
    15  	"github.com/xraypb/xray-core/common"
    16  	"github.com/xraypb/xray-core/common/buf"
    17  )
    18  
    19  const (
    20  	// CRLF is the line ending in HTTP header
    21  	CRLF = "\r\n"
    22  
    23  	// ENDING is the double line ending between HTTP header and body.
    24  	ENDING = CRLF + CRLF
    25  
    26  	// max length of HTTP header. Safety precaution for DDoS attack.
    27  	maxHeaderLength = 8192
    28  )
    29  
    30  var (
    31  	ErrHeaderToLong = newError("Header too long.")
    32  
    33  	ErrHeaderMisMatch = newError("Header Mismatch.")
    34  )
    35  
    36  type Reader interface {
    37  	Read(io.Reader) (*buf.Buffer, error)
    38  }
    39  
    40  type Writer interface {
    41  	Write(io.Writer) error
    42  }
    43  
    44  type NoOpReader struct{}
    45  
    46  func (NoOpReader) Read(io.Reader) (*buf.Buffer, error) {
    47  	return nil, nil
    48  }
    49  
    50  type NoOpWriter struct{}
    51  
    52  func (NoOpWriter) Write(io.Writer) error {
    53  	return nil
    54  }
    55  
    56  type HeaderReader struct {
    57  	req            *http.Request
    58  	expectedHeader *RequestConfig
    59  }
    60  
    61  func (h *HeaderReader) ExpectThisRequest(expectedHeader *RequestConfig) *HeaderReader {
    62  	h.expectedHeader = expectedHeader
    63  	return h
    64  }
    65  
    66  func (h *HeaderReader) Read(reader io.Reader) (*buf.Buffer, error) {
    67  	buffer := buf.New()
    68  	totalBytes := int32(0)
    69  	endingDetected := false
    70  
    71  	var headerBuf bytes.Buffer
    72  
    73  	for totalBytes < maxHeaderLength {
    74  		_, err := buffer.ReadFrom(reader)
    75  		if err != nil {
    76  			buffer.Release()
    77  			return nil, err
    78  		}
    79  		if n := bytes.Index(buffer.Bytes(), []byte(ENDING)); n != -1 {
    80  			headerBuf.Write(buffer.BytesRange(0, int32(n+len(ENDING))))
    81  			buffer.Advance(int32(n + len(ENDING)))
    82  			endingDetected = true
    83  			break
    84  		}
    85  		lenEnding := int32(len(ENDING))
    86  		if buffer.Len() >= lenEnding {
    87  			totalBytes += buffer.Len() - lenEnding
    88  			headerBuf.Write(buffer.BytesRange(0, buffer.Len()-lenEnding))
    89  			leftover := buffer.BytesFrom(-lenEnding)
    90  			buffer.Clear()
    91  			copy(buffer.Extend(lenEnding), leftover)
    92  
    93  			if _, err := readRequest(bufio.NewReader(bytes.NewReader(headerBuf.Bytes()))); err != io.ErrUnexpectedEOF {
    94  				return nil, err
    95  			}
    96  		}
    97  	}
    98  
    99  	if !endingDetected {
   100  		buffer.Release()
   101  		return nil, ErrHeaderToLong
   102  	}
   103  
   104  	if h.expectedHeader == nil {
   105  		if buffer.IsEmpty() {
   106  			buffer.Release()
   107  			return nil, nil
   108  		}
   109  		return buffer, nil
   110  	}
   111  
   112  	// Parse the request
   113  	if req, err := readRequest(bufio.NewReader(bytes.NewReader(headerBuf.Bytes()))); err != nil {
   114  		return nil, err
   115  	} else {
   116  		h.req = req
   117  	}
   118  
   119  	// Check req
   120  	path := h.req.URL.Path
   121  	hasThisURI := false
   122  	for _, u := range h.expectedHeader.Uri {
   123  		if u == path {
   124  			hasThisURI = true
   125  		}
   126  	}
   127  
   128  	if !hasThisURI {
   129  		return nil, ErrHeaderMisMatch
   130  	}
   131  
   132  	if buffer.IsEmpty() {
   133  		buffer.Release()
   134  		return nil, nil
   135  	}
   136  
   137  	return buffer, nil
   138  }
   139  
   140  type HeaderWriter struct {
   141  	header *buf.Buffer
   142  }
   143  
   144  func NewHeaderWriter(header *buf.Buffer) *HeaderWriter {
   145  	return &HeaderWriter{
   146  		header: header,
   147  	}
   148  }
   149  
   150  func (w *HeaderWriter) Write(writer io.Writer) error {
   151  	if w.header == nil {
   152  		return nil
   153  	}
   154  	err := buf.WriteAllBytes(writer, w.header.Bytes(), nil)
   155  	w.header.Release()
   156  	w.header = nil
   157  	return err
   158  }
   159  
   160  type Conn struct {
   161  	net.Conn
   162  
   163  	readBuffer          *buf.Buffer
   164  	oneTimeReader       Reader
   165  	oneTimeWriter       Writer
   166  	errorWriter         Writer
   167  	errorMismatchWriter Writer
   168  	errorTooLongWriter  Writer
   169  	errReason           error
   170  }
   171  
   172  func NewConn(conn net.Conn, reader Reader, writer Writer, errorWriter Writer, errorMismatchWriter Writer, errorTooLongWriter Writer) *Conn {
   173  	return &Conn{
   174  		Conn:                conn,
   175  		oneTimeReader:       reader,
   176  		oneTimeWriter:       writer,
   177  		errorWriter:         errorWriter,
   178  		errorMismatchWriter: errorMismatchWriter,
   179  		errorTooLongWriter:  errorTooLongWriter,
   180  	}
   181  }
   182  
   183  func (c *Conn) Read(b []byte) (int, error) {
   184  	if c.oneTimeReader != nil {
   185  		buffer, err := c.oneTimeReader.Read(c.Conn)
   186  		if err != nil {
   187  			c.errReason = err
   188  			return 0, err
   189  		}
   190  		c.readBuffer = buffer
   191  		c.oneTimeReader = nil
   192  	}
   193  
   194  	if !c.readBuffer.IsEmpty() {
   195  		nBytes, _ := c.readBuffer.Read(b)
   196  		if c.readBuffer.IsEmpty() {
   197  			c.readBuffer.Release()
   198  			c.readBuffer = nil
   199  		}
   200  		return nBytes, nil
   201  	}
   202  
   203  	return c.Conn.Read(b)
   204  }
   205  
   206  // Write implements io.Writer.
   207  func (c *Conn) Write(b []byte) (int, error) {
   208  	if c.oneTimeWriter != nil {
   209  		err := c.oneTimeWriter.Write(c.Conn)
   210  		c.oneTimeWriter = nil
   211  		if err != nil {
   212  			return 0, err
   213  		}
   214  	}
   215  
   216  	return c.Conn.Write(b)
   217  }
   218  
   219  // Close implements net.Conn.Close().
   220  func (c *Conn) Close() error {
   221  	if c.oneTimeWriter != nil && c.errorWriter != nil {
   222  		// Connection is being closed but header wasn't sent. This means the client request
   223  		// is probably not valid. Sending back a server error header in this case.
   224  
   225  		// Write response based on error reason
   226  		switch c.errReason {
   227  		case ErrHeaderMisMatch:
   228  			c.errorMismatchWriter.Write(c.Conn)
   229  		case ErrHeaderToLong:
   230  			c.errorTooLongWriter.Write(c.Conn)
   231  		default:
   232  			c.errorWriter.Write(c.Conn)
   233  		}
   234  	}
   235  
   236  	return c.Conn.Close()
   237  }
   238  
   239  func formResponseHeader(config *ResponseConfig) *HeaderWriter {
   240  	header := buf.New()
   241  	common.Must2(header.WriteString(strings.Join([]string{config.GetFullVersion(), config.GetStatusValue().Code, config.GetStatusValue().Reason}, " ")))
   242  	common.Must2(header.WriteString(CRLF))
   243  
   244  	headers := config.PickHeaders()
   245  	for _, h := range headers {
   246  		common.Must2(header.WriteString(h))
   247  		common.Must2(header.WriteString(CRLF))
   248  	}
   249  	if !config.HasHeader("Date") {
   250  		common.Must2(header.WriteString("Date: "))
   251  		common.Must2(header.WriteString(time.Now().Format(http.TimeFormat)))
   252  		common.Must2(header.WriteString(CRLF))
   253  	}
   254  	common.Must2(header.WriteString(CRLF))
   255  	return &HeaderWriter{
   256  		header: header,
   257  	}
   258  }
   259  
   260  type Authenticator struct {
   261  	config *Config
   262  }
   263  
   264  func (a Authenticator) GetClientWriter() *HeaderWriter {
   265  	header := buf.New()
   266  	config := a.config.Request
   267  	common.Must2(header.WriteString(strings.Join([]string{config.GetMethodValue(), config.PickURI(), config.GetFullVersion()}, " ")))
   268  	common.Must2(header.WriteString(CRLF))
   269  
   270  	headers := config.PickHeaders()
   271  	for _, h := range headers {
   272  		common.Must2(header.WriteString(h))
   273  		common.Must2(header.WriteString(CRLF))
   274  	}
   275  	common.Must2(header.WriteString(CRLF))
   276  	return &HeaderWriter{
   277  		header: header,
   278  	}
   279  }
   280  
   281  func (a Authenticator) GetServerWriter() *HeaderWriter {
   282  	return formResponseHeader(a.config.Response)
   283  }
   284  
   285  func (a Authenticator) Client(conn net.Conn) net.Conn {
   286  	if a.config.Request == nil && a.config.Response == nil {
   287  		return conn
   288  	}
   289  	var reader Reader = NoOpReader{}
   290  	if a.config.Request != nil {
   291  		reader = new(HeaderReader)
   292  	}
   293  
   294  	var writer Writer = NoOpWriter{}
   295  	if a.config.Response != nil {
   296  		writer = a.GetClientWriter()
   297  	}
   298  	return NewConn(conn, reader, writer, NoOpWriter{}, NoOpWriter{}, NoOpWriter{})
   299  }
   300  
   301  func (a Authenticator) Server(conn net.Conn) net.Conn {
   302  	if a.config.Request == nil && a.config.Response == nil {
   303  		return conn
   304  	}
   305  	return NewConn(conn, new(HeaderReader).ExpectThisRequest(a.config.Request), a.GetServerWriter(),
   306  		formResponseHeader(resp400),
   307  		formResponseHeader(resp404),
   308  		formResponseHeader(resp400))
   309  }
   310  
   311  func NewAuthenticator(ctx context.Context, config *Config) (Authenticator, error) {
   312  	return Authenticator{
   313  		config: config,
   314  	}, nil
   315  }
   316  
   317  func init() {
   318  	common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
   319  		return NewAuthenticator(ctx, config.(*Config))
   320  	}))
   321  }