// github.com/PurpleSec/switchproxy@v1.6.2/switch.go

// Copyright 2021 - 2022 PurpleSec Team
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published
// by the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
//

package switchproxy

import (
	"context"
	"errors"
	"io"
	"net"
	"net/http"
	"net/url"
	"path"
	"strings"
	"time"

	// The blank "unsafe" import is required by the //go:linkname directive
	// that binds fastRand to runtime.fastrand in this file.
	_ "unsafe"
)

// table maps a 4-bit value (0-15) to its uppercase hexadecimal digit. It is
// used by newUUID to hex-encode random bytes.
const table = "0123456789ABCDEF"

// Result describes one side of a Switch operation and is passed to the Pre
// and Post Handlers.
//
// For a Pre result (sent before forwarding) Status is zero, Headers are the
// incoming request headers and Content is the buffered request body. For a
// Post result (sent after the upstream response body is read) Status is the
// upstream status code, Headers are the response headers and Content is the
// response body. Both results of one forwarded request share the same UUID.
type Result struct {
	// Headers are the request headers (Pre) or the response headers (Post).
	Headers http.Header `json:"headers"`
	// IP is the RemoteAddr of the incoming request.
	IP string `json:"ip"`
	// UUID is a random ID shared by the Pre and Post results of one request.
	UUID string `json:"uuid"`
	// Path is the path forwarded upstream, after any rewrite was applied.
	Path string `json:"path"`
	// Method is the HTTP method of the incoming request.
	Method string `json:"method"`
	// URL is the full upstream URL the request was sent to.
	URL string `json:"url"`
	// Content is the request body (Pre) or the response body (Post).
	Content []byte `json:"content"`
	// Status is the upstream response status code, or zero for a Pre result.
	Status uint16 `json:"status"`
}

// Switch forwards requests to a single upstream target URL. It can rewrite
// path prefixes and report each request and response to the Pre and Post
// Handlers.
//
// Create a Switch with NewSwitch or NewSwitchTimeout. The zero value has no
// HTTP client and cannot forward requests.
type Switch struct {
	// Pre, if non-nil, is called with the request details before the request
	// is sent upstream.
	Pre Handler
	// Post, if non-nil, is called with the response details after the
	// upstream response body has been copied.
	Post Handler
	// client sends the upstream requests. It is built by NewSwitchTimeout in
	// this file.
	client *http.Client
	// rewrite maps a path prefix ("from") to its replacement ("to").
	// NOTE(review): this map has no lock in this type; configure rewrites
	// before the Switch starts serving requests.
	rewrite map[string]string
	// URL is the upstream target. Per request, the path, user info, opaque,
	// fragment and query fields are overwritten on a copy of the Switch, so
	// in practice the scheme and host select the upstream.
	url.URL
	// timeout, when positive, bounds each forwarded request through its
	// context. It also configures the client and transport timeouts.
	timeout time.Duration
}

// Handler is a function type that receives a Result for processing. A Switch
// calls its Pre Handler before forwarding a request and its Post Handler
// after reading the upstream response.
61 type Handler func(Result) 62 63 //go:linkname fastRand runtime.fastrand 64 func fastRand() uint32 65 func newUUID() string { 66 var b [64]byte 67 for i := 0; i < 64; i += 2 { 68 v := byte(fastRand() & 0xFF) 69 if v < 16 { 70 b[i], b[i+1] = '0', table[v&0x0F] 71 } 72 b[i], b[i+1] = table[v>>4], table[v&0x0F] 73 } 74 return string(b[:]) 75 } 76 77 // IsResponse is a function that returns true if the Result is for a response. 78 func (r Result) IsResponse() bool { 79 return len(r.Method) > 0 && r.Status > 0 80 } 81 82 // Rewrite adds a URL rewrite from the Switch. 83 // 84 // If a URL starts with the 'from' parameter, it will be replaced with the 'to' 85 // parameter, only if starting with on the URL path. 86 func (s *Switch) Rewrite(from, to string) { 87 s.rewrite[from] = to 88 } 89 90 // RemoveRewrite removes the URL rewrite from the Switch. 91 func (s *Switch) RemoveRewrite(from string) { 92 delete(s.rewrite, from) 93 } 94 95 // NewSwitch creates a switching context that allows the connection to be proxied 96 // to the specified server. 97 func NewSwitch(target string) (*Switch, error) { 98 return NewSwitchTimeout(target, DefaultTimeout) 99 } 100 101 // NewSwitchTimeout creates a switching context that allows the connection to be 102 // proxied to the specified server. 103 // 104 // This function will set the specified timeout. 
105 func NewSwitchTimeout(target string, t time.Duration) (*Switch, error) { 106 u, err := url.Parse(target) 107 if err != nil { 108 return nil, errors.New("unable to resolve URL: " + err.Error()) 109 } 110 if !u.IsAbs() { 111 u.Scheme = "http" 112 } 113 s := &Switch{ 114 URL: *u, 115 client: &http.Client{ 116 Timeout: t, 117 Transport: &http.Transport{ 118 Proxy: http.ProxyFromEnvironment, 119 DialContext: (&net.Dialer{ 120 Timeout: t, 121 KeepAlive: t, 122 }).DialContext, 123 IdleConnTimeout: t, 124 TLSHandshakeTimeout: t, 125 ExpectContinueTimeout: t, 126 ResponseHeaderTimeout: t, 127 }, 128 }, 129 timeout: t, 130 rewrite: make(map[string]string), 131 } 132 return s, nil 133 } 134 func (s Switch) process(x context.Context, r *http.Request, t *transfer) (int, http.Header, error) { 135 s.Path = r.URL.Path 136 s.User = r.URL.User 137 s.Opaque = r.URL.Opaque 138 s.Fragment = r.URL.Fragment 139 s.RawQuery = r.URL.RawQuery 140 s.ForceQuery = r.URL.ForceQuery 141 for k, v := range s.rewrite { 142 if strings.HasPrefix(s.Path, k) { 143 s.Path = path.Join(v, s.Path[len(k):]) 144 } 145 } 146 f := func() {} 147 if s.timeout > 0 { 148 x, f = context.WithTimeout(x, s.timeout) 149 } 150 q, err := http.NewRequestWithContext(x, r.Method, s.String(), t.in) 151 if err != nil { 152 f() 153 return 0, nil, err 154 } 155 u := newUUID() 156 if s.Pre != nil { 157 s.Pre(Result{ 158 IP: r.RemoteAddr, 159 URL: s.String(), 160 UUID: u, 161 Path: s.Path, 162 Method: r.Method, 163 Content: t.data, 164 Headers: r.Header, 165 }) 166 } 167 q.Header, q.Trailer = r.Header, r.Trailer 168 q.TransferEncoding = r.TransferEncoding 169 o, err := s.client.Do(q) 170 if err != nil { 171 f() 172 return 0, nil, err 173 } 174 if _, err := io.Copy(t.out, o.Body); err != nil { 175 f() 176 o.Body.Close() 177 return 0, nil, err 178 } 179 if s.Post != nil { 180 s.Post(Result{ 181 IP: r.RemoteAddr, 182 URL: s.String(), 183 Path: s.Path, 184 UUID: u, 185 Status: uint16(o.StatusCode), 186 Method: r.Method, 187 
Content: t.out.Bytes(), 188 Headers: o.Header, 189 }) 190 } 191 f() 192 o.Body.Close() 193 return o.StatusCode, o.Header, nil 194 }