github.com/craicoverflow/tyk@v2.9.6-rc3+incompatible/tcp/tcp.go (about) 1 package tcp 2 3 import ( 4 "context" 5 "crypto/tls" 6 "errors" 7 "io" 8 "net" 9 "net/url" 10 "strings" 11 "sync" 12 "sync/atomic" 13 "time" 14 15 logger "github.com/TykTechnologies/tyk/log" 16 ) 17 18 var log = logger.Get().WithField("prefix", "tcp-proxy") 19 20 type ConnState uint 21 22 const ( 23 Active ConnState = iota 24 Open 25 Closed 26 ) 27 28 // Modifier define rules for tranforming incoming and outcoming TCP messages 29 // To filter response set data to empty 30 // To close connection, return error 31 type Modifier struct { 32 ModifyRequest func(src, dst net.Conn, data []byte) ([]byte, error) 33 ModifyResponse func(src, dst net.Conn, data []byte) ([]byte, error) 34 } 35 36 type targetConfig struct { 37 modifier *Modifier 38 target string 39 } 40 41 // Stat defines basic statistics about a tcp connection 42 type Stat struct { 43 State ConnState 44 BytesIn int64 45 BytesOut int64 46 } 47 48 func (s *Stat) Flush() Stat { 49 v := Stat{ 50 BytesIn: atomic.LoadInt64(&s.BytesIn), 51 BytesOut: atomic.LoadInt64(&s.BytesOut), 52 } 53 atomic.StoreInt64(&s.BytesIn, 0) 54 atomic.StoreInt64(&s.BytesOut, 0) 55 return v 56 } 57 58 type Proxy struct { 59 sync.RWMutex 60 61 DialTLS func(network, addr string) (net.Conn, error) 62 Dial func(network, addr string) (net.Conn, error) 63 TLSConfigTarget *tls.Config 64 65 ReadTimeout time.Duration 66 WriteTimeout time.Duration 67 68 // Domain to config mapping 69 muxer map[string]*targetConfig 70 SyncStats func(Stat) 71 // Duration in which connection stats will be flushed. Defaults to one second. 72 StatsSyncInterval time.Duration 73 } 74 75 func (p *Proxy) AddDomainHandler(domain, target string, modifier *Modifier) { 76 p.Lock() 77 defer p.Unlock() 78 79 if p.muxer == nil { 80 p.muxer = make(map[string]*targetConfig) 81 } 82 83 if modifier == nil { 84 modifier = &Modifier{} 85 } 86 87 p.muxer[domain] = &targetConfig{ 88 modifier: modifier, 89 target: target, 90 } 91 } 92 93 func (p *Proxy) Swap(new *Proxy) { 94 p.Lock() 95 defer p.Unlock() 96 97 p.muxer = new.muxer 98 } 99 100 func (p *Proxy) RemoveDomainHandler(domain string) { 101 p.Lock() 102 defer p.Unlock() 103 104 delete(p.muxer, domain) 105 } 106 107 func (p *Proxy) Serve(l net.Listener) error { 108 for { 109 conn, err := l.Accept() 110 if err != nil { 111 log.WithError(err).Warning("Can't accept connection") 112 return err 113 } 114 go func() { 115 if err := p.handleConn(conn); err != nil { 116 log.WithError(err).Warning("Can't handle connection") 117 } 118 }() 119 } 120 } 121 122 func (p *Proxy) getTargetConfig(conn net.Conn) (*targetConfig, error) { 123 p.RLock() 124 defer p.RUnlock() 125 126 if len(p.muxer) == 0 { 127 return nil, errors.New("No services defined") 128 } 129 130 switch v := conn.(type) { 131 case *tls.Conn: 132 if err := v.Handshake(); err != nil { 133 return nil, err 134 } 135 136 state := v.ConnectionState() 137 138 if state.ServerName == "" { 139 // If SNI disabled, and only 1 record defined return it 140 if len(p.muxer) == 1 { 141 for _, config := range p.muxer { 142 return config, nil 143 } 144 } 145 146 return nil, errors.New("Multiple services on different domains running on the same port, but no SNI (domain) information from client") 147 } 148 149 // If SNI supported try to match domain 150 if config, ok := p.muxer[state.ServerName]; ok { 151 return config, nil 152 } 153 154 // If no custom domains are used 155 if config, ok := p.muxer[""]; ok { 156 return config, nil 157 } 158 159 return nil, errors.New("Can't detect service based on provided SNI information: " + state.ServerName) 160 default: 161 if len(p.muxer) > 1 { 162 return nil, errors.New("Running multiple services without TLS and SNI not supported") 163 } 164 165 for _, config := range p.muxer { 166 return config, nil 167 } 168 } 169 170 return nil, errors.New("Can't detect service configuration") 171 } 172 173 func (p *Proxy) handleConn(conn net.Conn) error { 174 var connectionClosed atomic.Value 175 connectionClosed.Store(false) 176 177 stat := Stat{} 178 179 ctx, cancel := context.WithCancel(context.Background()) 180 defer cancel() 181 if p.SyncStats != nil { 182 go func() { 183 duration := p.StatsSyncInterval 184 if duration == 0 { 185 duration = time.Second 186 } 187 tick := time.NewTicker(duration) 188 defer tick.Stop() 189 p.SyncStats(Stat{State: Open}) 190 for { 191 select { 192 case <-ctx.Done(): 193 s := stat.Flush() 194 s.State = Closed 195 p.SyncStats(s) 196 return 197 case <-tick.C: 198 p.SyncStats(stat.Flush()) 199 } 200 } 201 }() 202 } 203 config, err := p.getTargetConfig(conn) 204 if err != nil { 205 conn.Close() 206 return err 207 } 208 u, uErr := url.Parse(config.target) 209 if uErr != nil { 210 u, uErr = url.Parse("tcp://" + config.target) 211 212 if uErr != nil { 213 conn.Close() 214 return uErr 215 } 216 } 217 218 // connects to target server 219 var rconn net.Conn 220 switch u.Scheme { 221 case "tcp": 222 if p.Dial != nil { 223 rconn, err = p.Dial("tcp", u.Host) 224 } else { 225 rconn, err = net.Dial("tcp", u.Host) 226 } 227 case "tls": 228 if p.DialTLS != nil { 229 rconn, err = p.DialTLS("tcp", u.Host) 230 } else { 231 rconn, err = tls.Dial("tcp", u.Host, p.TLSConfigTarget) 232 } 233 default: 234 err = errors.New("Unsupported protocol. Should be empty, `tcp` or `tls`") 235 } 236 if err != nil { 237 conn.Close() 238 return err 239 } 240 defer func() { 241 conn.Close() 242 rconn.Close() 243 }() 244 var wg sync.WaitGroup 245 wg.Add(2) 246 247 r := pipeOpts{ 248 modifier: func(src, dst net.Conn, data []byte) ([]byte, error) { 249 atomic.AddInt64(&stat.BytesIn, int64(len(data))) 250 h := config.modifier.ModifyRequest 251 if h != nil { 252 return h(src, dst, data) 253 } 254 return data, nil 255 }, 256 beforeExit: func() { 257 wg.Done() 258 }, 259 onReadError: func(err error) { 260 if IsSocketClosed(err) && connectionClosed.Load().(bool) { 261 return 262 } 263 if err == io.EOF { 264 // End of stream from the client. 265 connectionClosed.Store(true) 266 log.WithField("conn", clientConn(conn)).Debug("End of client stream") 267 } else { 268 log.WithError(err).Error("Failed to read from client connection") 269 } 270 }, 271 onWriteError: func(err error) { 272 log.WithError(err).Info("Failed to write to upstream socket") 273 }, 274 } 275 w := pipeOpts{ 276 modifier: func(src, dst net.Conn, data []byte) ([]byte, error) { 277 atomic.AddInt64(&stat.BytesOut, int64(len(data))) 278 h := config.modifier.ModifyResponse 279 if h != nil { 280 return h(src, dst, data) 281 } 282 return data, nil 283 }, 284 beforeExit: func() { 285 wg.Done() 286 }, 287 onReadError: func(err error) { 288 if IsSocketClosed(err) && connectionClosed.Load().(bool) { 289 return 290 } 291 if err == io.EOF { 292 // End of stream from upstream 293 connectionClosed.Store(true) 294 log.WithField("conn", upstreamConn(rconn)).Debug("End of upstream stream") 295 } else { 296 log.WithError(err).Error("Failed to read from upstream connection") 297 } 298 }, 299 onWriteError: func(err error) { 300 log.WithError(err).Info("Failed to write to client connection") 301 }, 302 } 303 go p.pipe(conn, rconn, r) 304 go p.pipe(rconn, conn, w) 305 wg.Wait() 306 return nil 307 } 308 309 func upstreamConn(c net.Conn) string { 310 return formatAddress(c.LocalAddr(), c.RemoteAddr()) 311 } 312 313 func clientConn(c net.Conn) string { 314 return formatAddress(c.RemoteAddr(), c.LocalAddr()) 315 } 316 317 func formatAddress(a, b net.Addr) string { 318 return a.String() + "->" + b.String() 319 } 320 321 // IsSocketClosed returns true if err is a result of reading from closed network 322 // connection 323 func IsSocketClosed(err error) bool { 324 return strings.Contains(err.Error(), "use of closed network connection") 325 } 326 327 type pipeOpts struct { 328 modifier func(net.Conn, net.Conn, []byte) ([]byte, error) 329 onReadError func(error) 330 onWriteError func(error) 331 beforeExit func() 332 } 333 334 func (p *Proxy) pipe(src, dst net.Conn, opts pipeOpts) { 335 defer func() { 336 src.Close() 337 dst.Close() 338 if opts.beforeExit != nil { 339 opts.beforeExit() 340 } 341 }() 342 343 buf := make([]byte, 65535) 344 345 for { 346 var readDeadline time.Time 347 if p.ReadTimeout != 0 { 348 readDeadline = time.Now().Add(p.ReadTimeout) 349 } 350 src.SetReadDeadline(readDeadline) 351 n, err := src.Read(buf) 352 if err != nil { 353 if opts.onReadError != nil { 354 opts.onReadError(err) 355 } 356 return 357 } 358 b := buf[:n] 359 360 if opts.modifier != nil { 361 if b, err = opts.modifier(src, dst, b); err != nil { 362 log.WithError(err).Warning("Closing connection") 363 return 364 } 365 } 366 367 if len(b) == 0 { 368 continue 369 } 370 371 var writeDeadline time.Time 372 if p.WriteTimeout != 0 { 373 writeDeadline = time.Now().Add(p.WriteTimeout) 374 } 375 dst.SetWriteDeadline(writeDeadline) 376 _, err = dst.Write(b) 377 if err != nil { 378 if opts.onWriteError != nil { 379 opts.onWriteError(err) 380 } 381 return 382 } 383 } 384 }