github.com/PurpleSec/switchproxy@v1.6.2/switch.go (about)

     1  // Copyright 2021 - 2022 PurpleSec Team
     2  //
     3  // This program is free software: you can redistribute it and/or modify
     4  // it under the terms of the GNU Affero General Public License as published
     5  // by the Free Software Foundation, either version 3 of the License, or
     6  // (at your option) any later version.
     7  //
     8  // This program is distributed in the hope that it will be useful,
     9  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    10  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    11  // GNU Affero General Public License for more details.
    12  //
    13  // You should have received a copy of the GNU Affero General Public License
    14  // along with this program.  If not, see <https://www.gnu.org/licenses/>.
    15  //
    16  
    17  package switchproxy
    18  
    19  import (
    20  	"context"
    21  	"errors"
    22  	"io"
    23  	"net"
    24  	"net/http"
    25  	"net/url"
    26  	"path"
    27  	"strings"
    28  	"time"
    29  
    30  	// Import unsafe to use "fastrand" function
    31  	_ "unsafe"
    32  )
    33  
// table holds the uppercase hexadecimal digits, indexed by nibble value (0-15).
// It is used by newUUID to hex-encode random bytes.
const table = "0123456789ABCDEF"
    35  
// Result is a struct that contains the data of the resulting Switch
// operation to be passed to Handlers.
//
// Within this file a Result is built up to twice per proxied request: once
// before the request is forwarded (handed to Switch.Pre) and once after the
// upstream response body has been read (handed to Switch.Post). Both carry the
// same UUID so the two can be correlated.
//
// Fields:
//   - Headers: the incoming request headers (Pre) or the upstream response
//     headers (Post).
//   - IP: the RemoteAddr of the incoming request.
//   - UUID: the identifier generated for this request/response pair.
//   - Path: the target URL path after any rewrites were applied.
//   - Method: the HTTP method of the incoming request.
//   - URL: the full target URL the request is forwarded to.
//   - Content: the transfer's captured data (Pre, presumably the request
//     body; verify against the transfer type) or the buffered upstream
//     response body (Post).
//   - Status: the upstream response status code; left at zero for Pre.
type Result struct {
	Headers http.Header `json:"headers"`
	IP      string      `json:"ip"`
	UUID    string      `json:"uuid"`
	Path    string      `json:"path"`
	Method  string      `json:"method"`
	URL     string      `json:"url"`
	Content []byte      `json:"content"`
	Status  uint16      `json:"status"`
}
    48  
// Switch is a struct that represents a connection between proxy services.
// This struct contains mapping and functions to capture input and output.
//
// Pre, when non-nil, is called before a request is forwarded and Post, when
// non-nil, is called after the upstream response body has been read. The
// embedded url.URL is the target that requests are forwarded to.
//
// The unexported fields hold the HTTP client used for forwarding, the path
// prefix rewrites (prefix -> replacement) and the per-request timeout. They are
// populated by NewSwitch and NewSwitchTimeout.
type Switch struct {
	Pre     Handler
	Post    Handler
	client  *http.Client
	rewrite map[string]string
	url.URL
	timeout time.Duration
}
    59  
// Handler is a function alias that can be passed a Result for processing.
//
// Handlers assigned to Switch.Pre and Switch.Post are called inline while a
// request is being processed and receive the Result by value.
type Handler func(Result)
    62  
    63  //go:linkname fastRand runtime.fastrand
    64  func fastRand() uint32
    65  func newUUID() string {
    66  	var b [64]byte
    67  	for i := 0; i < 64; i += 2 {
    68  		v := byte(fastRand() & 0xFF)
    69  		if v < 16 {
    70  			b[i], b[i+1] = '0', table[v&0x0F]
    71  		}
    72  		b[i], b[i+1] = table[v>>4], table[v&0x0F]
    73  	}
    74  	return string(b[:])
    75  }
    76  
    77  // IsResponse is a function that returns true if the Result is for a response.
    78  func (r Result) IsResponse() bool {
    79  	return len(r.Method) > 0 && r.Status > 0
    80  }
    81  
    82  // Rewrite adds a URL rewrite from the Switch.
    83  //
    84  // If a URL starts with the 'from' parameter, it will be replaced with the 'to'
    85  // parameter, only if starting with on the URL path.
    86  func (s *Switch) Rewrite(from, to string) {
    87  	s.rewrite[from] = to
    88  }
    89  
// RemoveRewrite removes the URL rewrite registered for the 'from' prefix from
// the Switch.
//
// Removing a prefix that has no rewrite registered is a no-op.
func (s *Switch) RemoveRewrite(from string) {
	delete(s.rewrite, from)
}
    94  
// NewSwitch creates a switching context that allows the connection to be proxied
// to the specified server.
//
// It is shorthand for calling NewSwitchTimeout with DefaultTimeout and returns
// whatever that call returns.
func NewSwitch(target string) (*Switch, error) {
	return NewSwitchTimeout(target, DefaultTimeout)
}
   100  
   101  // NewSwitchTimeout creates a switching context that allows the connection to be
   102  // proxied to the specified server.
   103  //
   104  // This function will set the specified timeout.
   105  func NewSwitchTimeout(target string, t time.Duration) (*Switch, error) {
   106  	u, err := url.Parse(target)
   107  	if err != nil {
   108  		return nil, errors.New("unable to resolve URL: " + err.Error())
   109  	}
   110  	if !u.IsAbs() {
   111  		u.Scheme = "http"
   112  	}
   113  	s := &Switch{
   114  		URL: *u,
   115  		client: &http.Client{
   116  			Timeout: t,
   117  			Transport: &http.Transport{
   118  				Proxy: http.ProxyFromEnvironment,
   119  				DialContext: (&net.Dialer{
   120  					Timeout:   t,
   121  					KeepAlive: t,
   122  				}).DialContext,
   123  				IdleConnTimeout:       t,
   124  				TLSHandshakeTimeout:   t,
   125  				ExpectContinueTimeout: t,
   126  				ResponseHeaderTimeout: t,
   127  			},
   128  		},
   129  		timeout: t,
   130  		rewrite: make(map[string]string),
   131  	}
   132  	return s, nil
   133  }
   134  func (s Switch) process(x context.Context, r *http.Request, t *transfer) (int, http.Header, error) {
   135  	s.Path = r.URL.Path
   136  	s.User = r.URL.User
   137  	s.Opaque = r.URL.Opaque
   138  	s.Fragment = r.URL.Fragment
   139  	s.RawQuery = r.URL.RawQuery
   140  	s.ForceQuery = r.URL.ForceQuery
   141  	for k, v := range s.rewrite {
   142  		if strings.HasPrefix(s.Path, k) {
   143  			s.Path = path.Join(v, s.Path[len(k):])
   144  		}
   145  	}
   146  	f := func() {}
   147  	if s.timeout > 0 {
   148  		x, f = context.WithTimeout(x, s.timeout)
   149  	}
   150  	q, err := http.NewRequestWithContext(x, r.Method, s.String(), t.in)
   151  	if err != nil {
   152  		f()
   153  		return 0, nil, err
   154  	}
   155  	u := newUUID()
   156  	if s.Pre != nil {
   157  		s.Pre(Result{
   158  			IP:      r.RemoteAddr,
   159  			URL:     s.String(),
   160  			UUID:    u,
   161  			Path:    s.Path,
   162  			Method:  r.Method,
   163  			Content: t.data,
   164  			Headers: r.Header,
   165  		})
   166  	}
   167  	q.Header, q.Trailer = r.Header, r.Trailer
   168  	q.TransferEncoding = r.TransferEncoding
   169  	o, err := s.client.Do(q)
   170  	if err != nil {
   171  		f()
   172  		return 0, nil, err
   173  	}
   174  	if _, err := io.Copy(t.out, o.Body); err != nil {
   175  		f()
   176  		o.Body.Close()
   177  		return 0, nil, err
   178  	}
   179  	if s.Post != nil {
   180  		s.Post(Result{
   181  			IP:      r.RemoteAddr,
   182  			URL:     s.String(),
   183  			Path:    s.Path,
   184  			UUID:    u,
   185  			Status:  uint16(o.StatusCode),
   186  			Method:  r.Method,
   187  			Content: t.out.Bytes(),
   188  			Headers: o.Header,
   189  		})
   190  	}
   191  	f()
   192  	o.Body.Close()
   193  	return o.StatusCode, o.Header, nil
   194  }