github.com/mysteriumnetwork/node@v0.0.0-20240516044423-365054f76801/services/wireguard/endpoint/proxyclient/utils.go (about)

     1  /*
     2   * Copyright (C) 2022 The "MysteriumNetwork/node" Authors.
     3   *
     4   * This program is free software: you can redistribute it and/or modify
     5   * it under the terms of the GNU General Public License as published by
     6   * the Free Software Foundation, either version 3 of the License, or
     7   * (at your option) any later version.
     8   *
     9   * This program is distributed in the hope that it will be useful,
    10   * but WITHOUT ANY WARRANTY; without even the implied warranty of
    11   * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    12   * GNU General Public License for more details.
    13   *
    14   * You should have received a copy of the GNU General Public License
    15   * along with this program.  If not, see <http://www.gnu.org/licenses/>.
    16   */
    17  
    18  package proxyclient
    19  
import (
	"bufio"
	"context"
	"errors"
	"io"
	"net"
	"net/http"
	"strings"
	"sync"
	"time"
)
    30  
    31  const copyBufferSize = 128 * 1024
    32  
// bufferPool supplies the scratch buffers used by copyBody and copyBuffer;
// each of them takes a buffer with Get and hands it back with Put when done.
// NOTE(review): presumably every buffer holds copyBufferSize bytes — confirm
// against NewBufferPool.
var bufferPool = NewBufferPool(copyBufferSize)
    34  
    35  func proxyHTTP1(ctx context.Context, left, right net.Conn) {
    36  	wg := sync.WaitGroup{}
    37  
    38  	idleTimeout := 5 * time.Minute
    39  	timeout := time.AfterFunc(idleTimeout, func() {
    40  		left.Close()
    41  		right.Close()
    42  	})
    43  	extend := func() {
    44  		timeout.Reset(idleTimeout)
    45  	}
    46  
    47  	cpy := func(dst, src net.Conn) {
    48  		defer wg.Done()
    49  
    50  		copyBuffer(dst, src, extend)
    51  		dst.Close()
    52  	}
    53  	wg.Add(2)
    54  	go cpy(left, right)
    55  	go cpy(right, left)
    56  	groupDone := make(chan struct{}, 1)
    57  	go func() {
    58  		wg.Wait()
    59  		groupDone <- struct{}{}
    60  	}()
    61  	select {
    62  	case <-ctx.Done():
    63  		left.Close()
    64  		right.Close()
    65  	case <-groupDone:
    66  		return
    67  	}
    68  	<-groupDone
    69  	return
    70  }
    71  
    72  func proxyHTTP2(ctx context.Context, leftreader io.ReadCloser, leftwriter io.Writer, right net.Conn) {
    73  	wg := sync.WaitGroup{}
    74  
    75  	idleTimeout := 5 * time.Minute
    76  	timeout := time.AfterFunc(idleTimeout, func() {
    77  		leftreader.Close()
    78  		right.Close()
    79  	})
    80  	extend := func() {
    81  		timeout.Reset(idleTimeout)
    82  	}
    83  
    84  	ltr := func(dst net.Conn, src io.Reader) {
    85  		defer wg.Done()
    86  		copyBuffer(dst, src, extend)
    87  		dst.Close()
    88  	}
    89  	rtl := func(dst io.Writer, src io.Reader) {
    90  		defer wg.Done()
    91  		copyBody(dst, src)
    92  	}
    93  	wg.Add(2)
    94  	go ltr(right, leftreader)
    95  	go rtl(leftwriter, right)
    96  	groupDone := make(chan struct{}, 1)
    97  	go func() {
    98  		wg.Wait()
    99  		groupDone <- struct{}{}
   100  	}()
   101  	select {
   102  	case <-ctx.Done():
   103  		leftreader.Close()
   104  		right.Close()
   105  	case <-groupDone:
   106  		return
   107  	}
   108  	<-groupDone
   109  	return
   110  }
   111  
   112  // Hop-by-hop headers. These are removed when sent to the backend.
   113  // http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html
   114  var hopHeaders = []string{
   115  	"Connection",
   116  	"Keep-Alive",
   117  	"Proxy-Authenticate",
   118  	"Proxy-Connection",
   119  	"Proxy-Authorization",
   120  	"Te", // canonicalized version of "TE"
   121  	"Trailers",
   122  	"Transfer-Encoding",
   123  	"Upgrade",
   124  }
   125  
   126  func copyHeader(dst, src http.Header) {
   127  	for k, vv := range src {
   128  		for _, v := range vv {
   129  			dst.Add(k, v)
   130  		}
   131  	}
   132  }
   133  
   134  func delHopHeaders(header http.Header) {
   135  	for _, h := range hopHeaders {
   136  		header.Del(h)
   137  	}
   138  }
   139  
   140  func hijack(hijackable interface{}) (net.Conn, *bufio.ReadWriter, error) {
   141  	hj, ok := hijackable.(http.Hijacker)
   142  	if !ok {
   143  		return nil, nil, errors.New("connection does not support hijacking")
   144  	}
   145  	conn, rw, err := hj.Hijack()
   146  	if err != nil {
   147  		return nil, nil, err
   148  	}
   149  	var emptyTime time.Time
   150  	err = conn.SetDeadline(emptyTime)
   151  	if err != nil {
   152  		conn.Close()
   153  		return nil, nil, err
   154  	}
   155  	return conn, rw, nil
   156  }
   157  
   158  func flush(flusher interface{}) bool {
   159  	f, ok := flusher.(http.Flusher)
   160  	if !ok {
   161  		return false
   162  	}
   163  	f.Flush()
   164  	return true
   165  }
   166  
   167  func copyBody(wr io.Writer, body io.Reader) {
   168  	buf := bufferPool.Get()
   169  	defer bufferPool.Put(buf)
   170  
   171  	for {
   172  		bread, readErr := body.Read(buf)
   173  		var writeErr error
   174  		if bread > 0 {
   175  			_, writeErr = wr.Write(buf[:bread])
   176  			flush(wr)
   177  		}
   178  		if readErr != nil || writeErr != nil {
   179  			break
   180  		}
   181  	}
   182  }
   183  
   184  func copyBuffer(dst io.Writer, src io.Reader, extend func()) (written int64, err error) {
   185  	buf := bufferPool.Get()
   186  	defer bufferPool.Put(buf)
   187  
   188  	for {
   189  		extend()
   190  		nr, er := src.Read(buf)
   191  		if nr > 0 {
   192  			nw, ew := dst.Write(buf[0:nr])
   193  			if nw < 0 || nr < nw {
   194  				nw = 0
   195  				if ew == nil {
   196  					ew = errors.New("invalid write result")
   197  				}
   198  			}
   199  			written += int64(nw)
   200  			if ew != nil {
   201  				err = ew
   202  				break
   203  			}
   204  			if nr != nw {
   205  				err = io.ErrShortWrite
   206  				break
   207  			}
   208  		}
   209  		if er != nil {
   210  			if er != io.EOF {
   211  				err = er
   212  			}
   213  			break
   214  		}
   215  	}
   216  	return written, err
   217  }