github.com/mysteriumnetwork/node@v0.0.0-20240516044423-365054f76801/services/wireguard/endpoint/proxyclient/handler.go (about)

     1  /*
     2   * Copyright (C) 2022 The "MysteriumNetwork/node" Authors.
     3   *
     4   * This program is free software: you can redistribute it and/or modify
     5   * it under the terms of the GNU General Public License as published by
     6   * the Free Software Foundation, either version 3 of the License, or
     7   * (at your option) any later version.
     8   *
     9   * This program is distributed in the hope that it will be useful,
    10   * but WITHOUT ANY WARRANTY; without even the implied warranty of
    11   * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    12   * GNU General Public License for more details.
    13   *
    14   * You should have received a copy of the GNU General Public License
    15   * along with this program.  If not, see <http://www.gnu.org/licenses/>.
    16   */
    17  
    18  package proxyclient
    19  
import (
	"context"
	"fmt"
	"net"
	"net/http"
	"strings"
	"sync"
	"time"

	"github.com/rs/zerolog/log"
	"golang.org/x/net/proxy"
)
    31  
    32  type proxyHandler struct {
    33  	timeout       time.Duration
    34  	httptransport http.RoundTripper
    35  	outbound      map[string]string
    36  	outboundMux   sync.RWMutex
    37  	dialer        proxy.ContextDialer
    38  }
    39  
    40  func newProxyHandler(timeout time.Duration, dialer proxy.ContextDialer) *proxyHandler {
    41  	httptransport := &http.Transport{
    42  		DialContext: dialer.DialContext,
    43  	}
    44  	return &proxyHandler{
    45  		timeout:       timeout,
    46  		httptransport: httptransport,
    47  		outbound:      make(map[string]string),
    48  		dialer:        dialer,
    49  	}
    50  }
    51  
    52  func (s *proxyHandler) handleTunnel(wr http.ResponseWriter, req *http.Request) {
    53  	ctx, cancel := context.WithTimeout(req.Context(), s.timeout)
    54  	defer cancel()
    55  
    56  	conn, err := s.dialer.DialContext(ctx, "tcp", req.RequestURI)
    57  	if err != nil {
    58  		log.Error().Err(err).Msg("Can't satisfy CONNECT request")
    59  		http.Error(wr, "Can't satisfy CONNECT request", http.StatusBadGateway)
    60  		return
    61  	}
    62  
    63  	localAddr := conn.LocalAddr().String()
    64  	s.outboundMux.Lock()
    65  	s.outbound[localAddr] = req.RemoteAddr
    66  	s.outboundMux.Unlock()
    67  	defer func() {
    68  		conn.Close()
    69  		s.outboundMux.Lock()
    70  		delete(s.outbound, localAddr)
    71  		s.outboundMux.Unlock()
    72  	}()
    73  
    74  	if req.ProtoMajor == 0 || req.ProtoMajor == 1 {
    75  		// Upgrade client connection
    76  		localconn, _, err := hijack(wr)
    77  		if err != nil {
    78  			log.Error().Err(err).Msg("Can't hijack client connection")
    79  			http.Error(wr, "Can't hijack client connection", http.StatusInternalServerError)
    80  			return
    81  		}
    82  		defer localconn.Close()
    83  
    84  		// Inform client connection is built
    85  		fmt.Fprintf(localconn, "HTTP/%d.%d 200 OK\r\n\r\n", req.ProtoMajor, req.ProtoMinor)
    86  
    87  		proxyHTTP1(req.Context(), localconn, conn)
    88  	} else if req.ProtoMajor == 2 {
    89  		wr.Header()["Date"] = nil
    90  		wr.WriteHeader(http.StatusOK)
    91  		flush(wr)
    92  		proxyHTTP2(req.Context(), req.Body, wr, conn)
    93  	} else {
    94  		log.Error().Msgf("Unsupported protocol version: %s", req.Proto)
    95  		http.Error(wr, "Unsupported protocol version.", http.StatusBadRequest)
    96  		return
    97  	}
    98  }
    99  
   100  func (s *proxyHandler) handleRequest(wr http.ResponseWriter, req *http.Request) {
   101  	req.RequestURI = ""
   102  	if req.ProtoMajor == 2 {
   103  		req.URL.Scheme = "http" // We can't access :scheme pseudo-header, so assume http
   104  		req.URL.Host = req.Host
   105  	}
   106  	resp, err := s.httptransport.RoundTrip(req)
   107  	if err != nil {
   108  		log.Error().Err(err).Msg("HTTP fetch error")
   109  		http.Error(wr, "Server Error", http.StatusInternalServerError)
   110  		return
   111  	}
   112  	defer resp.Body.Close()
   113  
   114  	delHopHeaders(resp.Header)
   115  	copyHeader(wr.Header(), resp.Header)
   116  	wr.WriteHeader(resp.StatusCode)
   117  	flush(wr)
   118  	copyBody(wr, resp.Body)
   119  }
   120  
   121  func (s *proxyHandler) isLoopback(req *http.Request) (string, bool) {
   122  	s.outboundMux.RLock()
   123  	originator, found := s.outbound[req.RemoteAddr]
   124  	s.outboundMux.RUnlock()
   125  	return originator, found
   126  }
   127  
   128  func (s *proxyHandler) ServeHTTP(wr http.ResponseWriter, req *http.Request) {
   129  	if originator, isLoopback := s.isLoopback(req); isLoopback {
   130  		log.Error().Msgf("Loopback tunnel detected: %s is an outbound "+
   131  			"address for another request from %s", req.RemoteAddr, originator)
   132  		http.Error(wr, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
   133  		return
   134  	}
   135  
   136  	isConnect := strings.ToUpper(req.Method) == "CONNECT"
   137  	if (req.URL.Host == "" || req.URL.Scheme == "" && !isConnect) && req.ProtoMajor < 2 ||
   138  		req.Host == "" && req.ProtoMajor == 2 {
   139  		http.Error(wr, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
   140  		return
   141  	}
   142  
   143  	delHopHeaders(req.Header)
   144  	if isConnect {
   145  		s.handleTunnel(wr, req)
   146  	} else {
   147  		s.handleRequest(wr, req)
   148  	}
   149  }