github.com/mysteriumnetwork/node@v0.0.0-20240516044423-365054f76801/services/wireguard/endpoint/proxyclient/utils.go (about) 1 /* 2 * Copyright (C) 2022 The "MysteriumNetwork/node" Authors. 3 * 4 * This program is free software: you can redistribute it and/or modify 5 * it under the terms of the GNU General Public License as published by 6 * the Free Software Foundation, either version 3 of the License, or 7 * (at your option) any later version. 8 * 9 * This program is distributed in the hope that it will be useful, 10 * but WITHOUT ANY WARRANTY; without even the implied warranty of 11 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 * GNU General Public License for more details. 13 * 14 * You should have received a copy of the GNU General Public License 15 * along with this program. If not, see <http://www.gnu.org/licenses/>. 16 */ 17 18 package proxyclient 19 20 import ( 21 "bufio" 22 "context" 23 "errors" 24 "io" 25 "net" 26 "net/http" 27 "sync" 28 "time" 29 ) 30 31 const copyBufferSize = 128 * 1024 32 33 var bufferPool = NewBufferPool(copyBufferSize) 34 35 func proxyHTTP1(ctx context.Context, left, right net.Conn) { 36 wg := sync.WaitGroup{} 37 38 idleTimeout := 5 * time.Minute 39 timeout := time.AfterFunc(idleTimeout, func() { 40 left.Close() 41 right.Close() 42 }) 43 extend := func() { 44 timeout.Reset(idleTimeout) 45 } 46 47 cpy := func(dst, src net.Conn) { 48 defer wg.Done() 49 50 copyBuffer(dst, src, extend) 51 dst.Close() 52 } 53 wg.Add(2) 54 go cpy(left, right) 55 go cpy(right, left) 56 groupDone := make(chan struct{}, 1) 57 go func() { 58 wg.Wait() 59 groupDone <- struct{}{} 60 }() 61 select { 62 case <-ctx.Done(): 63 left.Close() 64 right.Close() 65 case <-groupDone: 66 return 67 } 68 <-groupDone 69 return 70 } 71 72 func proxyHTTP2(ctx context.Context, leftreader io.ReadCloser, leftwriter io.Writer, right net.Conn) { 73 wg := sync.WaitGroup{} 74 75 idleTimeout := 5 * time.Minute 76 timeout := time.AfterFunc(idleTimeout, func() { 77 leftreader.Close() 78 right.Close() 79 }) 80 extend := func() { 81 timeout.Reset(idleTimeout) 82 } 83 84 ltr := func(dst net.Conn, src io.Reader) { 85 defer wg.Done() 86 copyBuffer(dst, src, extend) 87 dst.Close() 88 } 89 rtl := func(dst io.Writer, src io.Reader) { 90 defer wg.Done() 91 copyBody(dst, src) 92 } 93 wg.Add(2) 94 go ltr(right, leftreader) 95 go rtl(leftwriter, right) 96 groupDone := make(chan struct{}, 1) 97 go func() { 98 wg.Wait() 99 groupDone <- struct{}{} 100 }() 101 select { 102 case <-ctx.Done(): 103 leftreader.Close() 104 right.Close() 105 case <-groupDone: 106 return 107 } 108 <-groupDone 109 return 110 } 111 112 // Hop-by-hop headers. These are removed when sent to the backend. 113 // http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html 114 var hopHeaders = []string{ 115 "Connection", 116 "Keep-Alive", 117 "Proxy-Authenticate", 118 "Proxy-Connection", 119 "Proxy-Authorization", 120 "Te", // canonicalized version of "TE" 121 "Trailers", 122 "Transfer-Encoding", 123 "Upgrade", 124 } 125 126 func copyHeader(dst, src http.Header) { 127 for k, vv := range src { 128 for _, v := range vv { 129 dst.Add(k, v) 130 } 131 } 132 } 133 134 func delHopHeaders(header http.Header) { 135 for _, h := range hopHeaders { 136 header.Del(h) 137 } 138 } 139 140 func hijack(hijackable interface{}) (net.Conn, *bufio.ReadWriter, error) { 141 hj, ok := hijackable.(http.Hijacker) 142 if !ok { 143 return nil, nil, errors.New("connection does not support hijacking") 144 } 145 conn, rw, err := hj.Hijack() 146 if err != nil { 147 return nil, nil, err 148 } 149 var emptyTime time.Time 150 err = conn.SetDeadline(emptyTime) 151 if err != nil { 152 conn.Close() 153 return nil, nil, err 154 } 155 return conn, rw, nil 156 } 157 158 func flush(flusher interface{}) bool { 159 f, ok := flusher.(http.Flusher) 160 if !ok { 161 return false 162 } 163 f.Flush() 164 return true 165 } 166 167 func copyBody(wr io.Writer, body io.Reader) { 168 buf := bufferPool.Get() 169 defer bufferPool.Put(buf) 170 171 for { 172 bread, readErr := body.Read(buf) 173 var writeErr error 174 if bread > 0 { 175 _, writeErr = wr.Write(buf[:bread]) 176 flush(wr) 177 } 178 if readErr != nil || writeErr != nil { 179 break 180 } 181 } 182 } 183 184 func copyBuffer(dst io.Writer, src io.Reader, extend func()) (written int64, err error) { 185 buf := bufferPool.Get() 186 defer bufferPool.Put(buf) 187 188 for { 189 extend() 190 nr, er := src.Read(buf) 191 if nr > 0 { 192 nw, ew := dst.Write(buf[0:nr]) 193 if nw < 0 || nr < nw { 194 nw = 0 195 if ew == nil { 196 ew = errors.New("invalid write result") 197 } 198 } 199 written += int64(nw) 200 if ew != nil { 201 err = ew 202 break 203 } 204 if nr != nw { 205 err = io.ErrShortWrite 206 break 207 } 208 } 209 if er != nil { 210 if er != io.EOF { 211 err = er 212 } 213 break 214 } 215 } 216 return written, err 217 }