github.com/moqsien/xraycore@v1.8.5/proxy/http/server.go (about)

     1  package http
     2  
     3  import (
     4  	"bufio"
     5  	"context"
     6  	"encoding/base64"
     7  	"io"
     8  	"net/http"
     9  	"strings"
    10  	"time"
    11  
    12  	"github.com/moqsien/xraycore/common"
    13  	"github.com/moqsien/xraycore/common/buf"
    14  	"github.com/moqsien/xraycore/common/errors"
    15  	"github.com/moqsien/xraycore/common/log"
    16  	"github.com/moqsien/xraycore/common/net"
    17  	"github.com/moqsien/xraycore/common/protocol"
    18  	http_proto "github.com/moqsien/xraycore/common/protocol/http"
    19  	"github.com/moqsien/xraycore/common/session"
    20  	"github.com/moqsien/xraycore/common/signal"
    21  	"github.com/moqsien/xraycore/common/task"
    22  	"github.com/moqsien/xraycore/core"
    23  	"github.com/moqsien/xraycore/features/policy"
    24  	"github.com/moqsien/xraycore/features/routing"
    25  	"github.com/moqsien/xraycore/transport/internet/stat"
    26  )
    27  
// Server is an HTTP proxy server.
type Server struct {
	config        *ServerConfig  // inbound settings handed to NewServer (accounts, user level, timeout, ...)
	policyManager policy.Manager // source of the per-level session policy returned by policy()
}
    33  
    34  // NewServer creates a new HTTP inbound handler.
    35  func NewServer(ctx context.Context, config *ServerConfig) (*Server, error) {
    36  	v := core.MustFromContext(ctx)
    37  	s := &Server{
    38  		config:        config,
    39  		policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager),
    40  	}
    41  
    42  	return s, nil
    43  }
    44  
    45  func (s *Server) policy() policy.Session {
    46  	config := s.config
    47  	p := s.policyManager.ForLevel(config.UserLevel)
    48  	if config.Timeout > 0 && config.UserLevel == 0 {
    49  		p.Timeouts.ConnectionIdle = time.Duration(config.Timeout) * time.Second
    50  	}
    51  	return p
    52  }
    53  
    54  // Network implements proxy.Inbound.
    55  func (*Server) Network() []net.Network {
    56  	return []net.Network{net.Network_TCP, net.Network_UNIX}
    57  }
    58  
    59  func isTimeout(err error) bool {
    60  	nerr, ok := errors.Cause(err).(net.Error)
    61  	return ok && nerr.Timeout()
    62  }
    63  
    64  func parseBasicAuth(auth string) (username, password string, ok bool) {
    65  	const prefix = "Basic "
    66  	if !strings.HasPrefix(auth, prefix) {
    67  		return
    68  	}
    69  	c, err := base64.StdEncoding.DecodeString(auth[len(prefix):])
    70  	if err != nil {
    71  		return
    72  	}
    73  	cs := string(c)
    74  	s := strings.IndexByte(cs, ':')
    75  	if s < 0 {
    76  		return
    77  	}
    78  	return cs[:s], cs[s+1:], true
    79  }
    80  
    81  type readerOnly struct {
    82  	io.Reader
    83  }
    84  
    85  func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Connection, dispatcher routing.Dispatcher) error {
    86  	inbound := session.InboundFromContext(ctx)
    87  	if inbound != nil {
    88  		inbound.Name = "http"
    89  		inbound.User = &protocol.MemoryUser{
    90  			Level: s.config.UserLevel,
    91  		}
    92  	}
    93  
    94  	reader := bufio.NewReaderSize(readerOnly{conn}, buf.Size)
    95  
    96  Start:
    97  	if err := conn.SetReadDeadline(time.Now().Add(s.policy().Timeouts.Handshake)); err != nil {
    98  		newError("failed to set read deadline").Base(err).WriteToLog(session.ExportIDToError(ctx))
    99  	}
   100  
   101  	request, err := http.ReadRequest(reader)
   102  	if err != nil {
   103  		trace := newError("failed to read http request").Base(err)
   104  		if errors.Cause(err) != io.EOF && !isTimeout(errors.Cause(err)) {
   105  			trace.AtWarning()
   106  		}
   107  		return trace
   108  	}
   109  
   110  	if len(s.config.Accounts) > 0 {
   111  		user, pass, ok := parseBasicAuth(request.Header.Get("Proxy-Authorization"))
   112  		if !ok || !s.config.HasAccount(user, pass) {
   113  			return common.Error2(conn.Write([]byte("HTTP/1.1 407 Proxy Authentication Required\r\nProxy-Authenticate: Basic realm=\"proxy\"\r\n\r\n")))
   114  		}
   115  		if inbound != nil {
   116  			inbound.User.Email = user
   117  		}
   118  	}
   119  
   120  	newError("request to Method [", request.Method, "] Host [", request.Host, "] with URL [", request.URL, "]").WriteToLog(session.ExportIDToError(ctx))
   121  	if err := conn.SetReadDeadline(time.Time{}); err != nil {
   122  		newError("failed to clear read deadline").Base(err).WriteToLog(session.ExportIDToError(ctx))
   123  	}
   124  
   125  	defaultPort := net.Port(80)
   126  	if strings.EqualFold(request.URL.Scheme, "https") {
   127  		defaultPort = net.Port(443)
   128  	}
   129  	host := request.Host
   130  	if host == "" {
   131  		host = request.URL.Host
   132  	}
   133  	dest, err := http_proto.ParseHost(host, defaultPort)
   134  	if err != nil {
   135  		return newError("malformed proxy host: ", host).AtWarning().Base(err)
   136  	}
   137  	ctx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{
   138  		From:   conn.RemoteAddr(),
   139  		To:     request.URL,
   140  		Status: log.AccessAccepted,
   141  		Reason: "",
   142  	})
   143  
   144  	if strings.EqualFold(request.Method, "CONNECT") {
   145  		return s.handleConnect(ctx, request, reader, conn, dest, dispatcher, inbound)
   146  	}
   147  
   148  	keepAlive := (strings.TrimSpace(strings.ToLower(request.Header.Get("Proxy-Connection"))) == "keep-alive")
   149  
   150  	err = s.handlePlainHTTP(ctx, request, conn, dest, dispatcher)
   151  	if err == errWaitAnother {
   152  		if keepAlive {
   153  			goto Start
   154  		}
   155  		err = nil
   156  	}
   157  
   158  	return err
   159  }
   160  
   161  func (s *Server) handleConnect(ctx context.Context, _ *http.Request, reader *bufio.Reader, conn stat.Connection, dest net.Destination, dispatcher routing.Dispatcher, inbound *session.Inbound) error {
   162  	_, err := conn.Write([]byte("HTTP/1.1 200 Connection established\r\n\r\n"))
   163  	if err != nil {
   164  		return newError("failed to write back OK response").Base(err)
   165  	}
   166  
   167  	plcy := s.policy()
   168  	ctx, cancel := context.WithCancel(ctx)
   169  	timer := signal.CancelAfterInactivity(ctx, cancel, plcy.Timeouts.ConnectionIdle)
   170  
   171  	if inbound != nil {
   172  		inbound.Timer = timer
   173  	}
   174  
   175  	ctx = policy.ContextWithBufferPolicy(ctx, plcy.Buffer)
   176  	link, err := dispatcher.Dispatch(ctx, dest)
   177  	if err != nil {
   178  		return err
   179  	}
   180  
   181  	if reader.Buffered() > 0 {
   182  		payload, err := buf.ReadFrom(io.LimitReader(reader, int64(reader.Buffered())))
   183  		if err != nil {
   184  			return err
   185  		}
   186  		if err := link.Writer.WriteMultiBuffer(payload); err != nil {
   187  			return err
   188  		}
   189  		reader = nil
   190  	}
   191  
   192  	requestDone := func() error {
   193  		defer timer.SetTimeout(plcy.Timeouts.DownlinkOnly)
   194  
   195  		return buf.Copy(buf.NewReader(conn), link.Writer, buf.UpdateActivity(timer))
   196  	}
   197  
   198  	responseDone := func() error {
   199  		defer timer.SetTimeout(plcy.Timeouts.UplinkOnly)
   200  
   201  		v2writer := buf.NewWriter(conn)
   202  		if err := buf.Copy(link.Reader, v2writer, buf.UpdateActivity(timer)); err != nil {
   203  			return err
   204  		}
   205  
   206  		return nil
   207  	}
   208  
   209  	closeWriter := task.OnSuccess(requestDone, task.Close(link.Writer))
   210  	if err := task.Run(ctx, closeWriter, responseDone); err != nil {
   211  		common.Interrupt(link.Reader)
   212  		common.Interrupt(link.Writer)
   213  		return newError("connection ends").Base(err)
   214  	}
   215  
   216  	return nil
   217  }
   218  
// errWaitAnother is the sentinel handlePlainHTTP returns when an exchange
// completed and the client connection may carry another request; Process
// compares against it to decide whether to read the next request.
var errWaitAnother = newError("keep alive")
   220  
   221  func (s *Server) handlePlainHTTP(ctx context.Context, request *http.Request, writer io.Writer, dest net.Destination, dispatcher routing.Dispatcher) error {
   222  	if !s.config.AllowTransparent && request.URL.Host == "" {
   223  		// RFC 2068 (HTTP/1.1) requires URL to be absolute URL in HTTP proxy.
   224  		response := &http.Response{
   225  			Status:        "Bad Request",
   226  			StatusCode:    400,
   227  			Proto:         "HTTP/1.1",
   228  			ProtoMajor:    1,
   229  			ProtoMinor:    1,
   230  			Header:        http.Header(make(map[string][]string)),
   231  			Body:          nil,
   232  			ContentLength: 0,
   233  			Close:         true,
   234  		}
   235  		response.Header.Set("Proxy-Connection", "close")
   236  		response.Header.Set("Connection", "close")
   237  		return response.Write(writer)
   238  	}
   239  
   240  	if len(request.URL.Host) > 0 {
   241  		request.Host = request.URL.Host
   242  	}
   243  	http_proto.RemoveHopByHopHeaders(request.Header)
   244  
   245  	// Prevent UA from being set to golang's default ones
   246  	if request.Header.Get("User-Agent") == "" {
   247  		request.Header.Set("User-Agent", "")
   248  	}
   249  
   250  	content := &session.Content{
   251  		Protocol: "http/1.1",
   252  	}
   253  
   254  	content.SetAttribute(":method", strings.ToUpper(request.Method))
   255  	content.SetAttribute(":path", request.URL.Path)
   256  	for key := range request.Header {
   257  		value := request.Header.Get(key)
   258  		content.SetAttribute(strings.ToLower(key), value)
   259  	}
   260  
   261  	ctx = session.ContextWithContent(ctx, content)
   262  
   263  	link, err := dispatcher.Dispatch(ctx, dest)
   264  	if err != nil {
   265  		return err
   266  	}
   267  
   268  	// Plain HTTP request is not a stream. The request always finishes before response. Hense request has to be closed later.
   269  	defer common.Close(link.Writer)
   270  	var result error = errWaitAnother
   271  
   272  	requestDone := func() error {
   273  		request.Header.Set("Connection", "close")
   274  
   275  		requestWriter := buf.NewBufferedWriter(link.Writer)
   276  		common.Must(requestWriter.SetBuffered(false))
   277  		if err := request.Write(requestWriter); err != nil {
   278  			return newError("failed to write whole request").Base(err).AtWarning()
   279  		}
   280  		return nil
   281  	}
   282  
   283  	responseDone := func() error {
   284  		responseReader := bufio.NewReaderSize(&buf.BufferedReader{Reader: link.Reader}, buf.Size)
   285  		response, err := http.ReadResponse(responseReader, request)
   286  		if err == nil {
   287  			http_proto.RemoveHopByHopHeaders(response.Header)
   288  			if response.ContentLength >= 0 {
   289  				response.Header.Set("Proxy-Connection", "keep-alive")
   290  				response.Header.Set("Connection", "keep-alive")
   291  				response.Header.Set("Keep-Alive", "timeout=4")
   292  				response.Close = false
   293  			} else {
   294  				response.Close = true
   295  				result = nil
   296  			}
   297  			defer response.Body.Close()
   298  		} else {
   299  			newError("failed to read response from ", request.Host).Base(err).AtWarning().WriteToLog(session.ExportIDToError(ctx))
   300  			response = &http.Response{
   301  				Status:        "Service Unavailable",
   302  				StatusCode:    503,
   303  				Proto:         "HTTP/1.1",
   304  				ProtoMajor:    1,
   305  				ProtoMinor:    1,
   306  				Header:        http.Header(make(map[string][]string)),
   307  				Body:          nil,
   308  				ContentLength: 0,
   309  				Close:         true,
   310  			}
   311  			response.Header.Set("Connection", "close")
   312  			response.Header.Set("Proxy-Connection", "close")
   313  		}
   314  		if err := response.Write(writer); err != nil {
   315  			return newError("failed to write response").Base(err).AtWarning()
   316  		}
   317  		return nil
   318  	}
   319  
   320  	if err := task.Run(ctx, requestDone, responseDone); err != nil {
   321  		common.Interrupt(link.Reader)
   322  		common.Interrupt(link.Writer)
   323  		return newError("connection ends").Base(err)
   324  	}
   325  
   326  	return result
   327  }
   328  
   329  func init() {
   330  	common.Must(common.RegisterConfig((*ServerConfig)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
   331  		return NewServer(ctx, config.(*ServerConfig))
   332  	}))
   333  }