github.com/craicoverflow/tyk@v2.9.6-rc3+incompatible/tcp/tcp.go (about)

     1  package tcp
     2  
     3  import (
     4  	"context"
     5  	"crypto/tls"
     6  	"errors"
     7  	"io"
     8  	"net"
     9  	"net/url"
    10  	"strings"
    11  	"sync"
    12  	"sync/atomic"
    13  	"time"
    14  
    15  	logger "github.com/TykTechnologies/tyk/log"
    16  )
    17  
    18  var log = logger.Get().WithField("prefix", "tcp-proxy")
    19  
    20  type ConnState uint
    21  
    22  const (
    23  	Active ConnState = iota
    24  	Open
    25  	Closed
    26  )
    27  
    28  // Modifier define rules for tranforming incoming and outcoming TCP messages
    29  // To filter response set data to empty
    30  // To close connection, return error
    31  type Modifier struct {
    32  	ModifyRequest  func(src, dst net.Conn, data []byte) ([]byte, error)
    33  	ModifyResponse func(src, dst net.Conn, data []byte) ([]byte, error)
    34  }
    35  
    36  type targetConfig struct {
    37  	modifier *Modifier
    38  	target   string
    39  }
    40  
    41  // Stat defines basic statistics about a tcp connection
    42  type Stat struct {
    43  	State    ConnState
    44  	BytesIn  int64
    45  	BytesOut int64
    46  }
    47  
    48  func (s *Stat) Flush() Stat {
    49  	v := Stat{
    50  		BytesIn:  atomic.LoadInt64(&s.BytesIn),
    51  		BytesOut: atomic.LoadInt64(&s.BytesOut),
    52  	}
    53  	atomic.StoreInt64(&s.BytesIn, 0)
    54  	atomic.StoreInt64(&s.BytesOut, 0)
    55  	return v
    56  }
    57  
    58  type Proxy struct {
    59  	sync.RWMutex
    60  
    61  	DialTLS         func(network, addr string) (net.Conn, error)
    62  	Dial            func(network, addr string) (net.Conn, error)
    63  	TLSConfigTarget *tls.Config
    64  
    65  	ReadTimeout  time.Duration
    66  	WriteTimeout time.Duration
    67  
    68  	// Domain to config mapping
    69  	muxer     map[string]*targetConfig
    70  	SyncStats func(Stat)
    71  	// Duration in which connection stats will be flushed. Defaults to one second.
    72  	StatsSyncInterval time.Duration
    73  }
    74  
    75  func (p *Proxy) AddDomainHandler(domain, target string, modifier *Modifier) {
    76  	p.Lock()
    77  	defer p.Unlock()
    78  
    79  	if p.muxer == nil {
    80  		p.muxer = make(map[string]*targetConfig)
    81  	}
    82  
    83  	if modifier == nil {
    84  		modifier = &Modifier{}
    85  	}
    86  
    87  	p.muxer[domain] = &targetConfig{
    88  		modifier: modifier,
    89  		target:   target,
    90  	}
    91  }
    92  
    93  func (p *Proxy) Swap(new *Proxy) {
    94  	p.Lock()
    95  	defer p.Unlock()
    96  
    97  	p.muxer = new.muxer
    98  }
    99  
   100  func (p *Proxy) RemoveDomainHandler(domain string) {
   101  	p.Lock()
   102  	defer p.Unlock()
   103  
   104  	delete(p.muxer, domain)
   105  }
   106  
   107  func (p *Proxy) Serve(l net.Listener) error {
   108  	for {
   109  		conn, err := l.Accept()
   110  		if err != nil {
   111  			log.WithError(err).Warning("Can't accept connection")
   112  			return err
   113  		}
   114  		go func() {
   115  			if err := p.handleConn(conn); err != nil {
   116  				log.WithError(err).Warning("Can't handle connection")
   117  			}
   118  		}()
   119  	}
   120  }
   121  
   122  func (p *Proxy) getTargetConfig(conn net.Conn) (*targetConfig, error) {
   123  	p.RLock()
   124  	defer p.RUnlock()
   125  
   126  	if len(p.muxer) == 0 {
   127  		return nil, errors.New("No services defined")
   128  	}
   129  
   130  	switch v := conn.(type) {
   131  	case *tls.Conn:
   132  		if err := v.Handshake(); err != nil {
   133  			return nil, err
   134  		}
   135  
   136  		state := v.ConnectionState()
   137  
   138  		if state.ServerName == "" {
   139  			// If SNI disabled, and only 1 record defined return it
   140  			if len(p.muxer) == 1 {
   141  				for _, config := range p.muxer {
   142  					return config, nil
   143  				}
   144  			}
   145  
   146  			return nil, errors.New("Multiple services on different domains running on the same port, but no SNI (domain) information from client")
   147  		}
   148  
   149  		// If SNI supported try to match domain
   150  		if config, ok := p.muxer[state.ServerName]; ok {
   151  			return config, nil
   152  		}
   153  
   154  		// If no custom domains are used
   155  		if config, ok := p.muxer[""]; ok {
   156  			return config, nil
   157  		}
   158  
   159  		return nil, errors.New("Can't detect service based on provided SNI information: " + state.ServerName)
   160  	default:
   161  		if len(p.muxer) > 1 {
   162  			return nil, errors.New("Running multiple services without TLS and SNI not supported")
   163  		}
   164  
   165  		for _, config := range p.muxer {
   166  			return config, nil
   167  		}
   168  	}
   169  
   170  	return nil, errors.New("Can't detect service configuration")
   171  }
   172  
   173  func (p *Proxy) handleConn(conn net.Conn) error {
   174  	var connectionClosed atomic.Value
   175  	connectionClosed.Store(false)
   176  
   177  	stat := Stat{}
   178  
   179  	ctx, cancel := context.WithCancel(context.Background())
   180  	defer cancel()
   181  	if p.SyncStats != nil {
   182  		go func() {
   183  			duration := p.StatsSyncInterval
   184  			if duration == 0 {
   185  				duration = time.Second
   186  			}
   187  			tick := time.NewTicker(duration)
   188  			defer tick.Stop()
   189  			p.SyncStats(Stat{State: Open})
   190  			for {
   191  				select {
   192  				case <-ctx.Done():
   193  					s := stat.Flush()
   194  					s.State = Closed
   195  					p.SyncStats(s)
   196  					return
   197  				case <-tick.C:
   198  					p.SyncStats(stat.Flush())
   199  				}
   200  			}
   201  		}()
   202  	}
   203  	config, err := p.getTargetConfig(conn)
   204  	if err != nil {
   205  		conn.Close()
   206  		return err
   207  	}
   208  	u, uErr := url.Parse(config.target)
   209  	if uErr != nil {
   210  		u, uErr = url.Parse("tcp://" + config.target)
   211  
   212  		if uErr != nil {
   213  			conn.Close()
   214  			return uErr
   215  		}
   216  	}
   217  
   218  	// connects to target server
   219  	var rconn net.Conn
   220  	switch u.Scheme {
   221  	case "tcp":
   222  		if p.Dial != nil {
   223  			rconn, err = p.Dial("tcp", u.Host)
   224  		} else {
   225  			rconn, err = net.Dial("tcp", u.Host)
   226  		}
   227  	case "tls":
   228  		if p.DialTLS != nil {
   229  			rconn, err = p.DialTLS("tcp", u.Host)
   230  		} else {
   231  			rconn, err = tls.Dial("tcp", u.Host, p.TLSConfigTarget)
   232  		}
   233  	default:
   234  		err = errors.New("Unsupported protocol. Should be empty, `tcp` or `tls`")
   235  	}
   236  	if err != nil {
   237  		conn.Close()
   238  		return err
   239  	}
   240  	defer func() {
   241  		conn.Close()
   242  		rconn.Close()
   243  	}()
   244  	var wg sync.WaitGroup
   245  	wg.Add(2)
   246  
   247  	r := pipeOpts{
   248  		modifier: func(src, dst net.Conn, data []byte) ([]byte, error) {
   249  			atomic.AddInt64(&stat.BytesIn, int64(len(data)))
   250  			h := config.modifier.ModifyRequest
   251  			if h != nil {
   252  				return h(src, dst, data)
   253  			}
   254  			return data, nil
   255  		},
   256  		beforeExit: func() {
   257  			wg.Done()
   258  		},
   259  		onReadError: func(err error) {
   260  			if IsSocketClosed(err) && connectionClosed.Load().(bool) {
   261  				return
   262  			}
   263  			if err == io.EOF {
   264  				// End of stream from the client.
   265  				connectionClosed.Store(true)
   266  				log.WithField("conn", clientConn(conn)).Debug("End of client stream")
   267  			} else {
   268  				log.WithError(err).Error("Failed to read from client connection")
   269  			}
   270  		},
   271  		onWriteError: func(err error) {
   272  			log.WithError(err).Info("Failed to write to upstream socket")
   273  		},
   274  	}
   275  	w := pipeOpts{
   276  		modifier: func(src, dst net.Conn, data []byte) ([]byte, error) {
   277  			atomic.AddInt64(&stat.BytesOut, int64(len(data)))
   278  			h := config.modifier.ModifyResponse
   279  			if h != nil {
   280  				return h(src, dst, data)
   281  			}
   282  			return data, nil
   283  		},
   284  		beforeExit: func() {
   285  			wg.Done()
   286  		},
   287  		onReadError: func(err error) {
   288  			if IsSocketClosed(err) && connectionClosed.Load().(bool) {
   289  				return
   290  			}
   291  			if err == io.EOF {
   292  				// End of stream from upstream
   293  				connectionClosed.Store(true)
   294  				log.WithField("conn", upstreamConn(rconn)).Debug("End of upstream stream")
   295  			} else {
   296  				log.WithError(err).Error("Failed to read from upstream connection")
   297  			}
   298  		},
   299  		onWriteError: func(err error) {
   300  			log.WithError(err).Info("Failed to write to client connection")
   301  		},
   302  	}
   303  	go p.pipe(conn, rconn, r)
   304  	go p.pipe(rconn, conn, w)
   305  	wg.Wait()
   306  	return nil
   307  }
   308  
   309  func upstreamConn(c net.Conn) string {
   310  	return formatAddress(c.LocalAddr(), c.RemoteAddr())
   311  }
   312  
   313  func clientConn(c net.Conn) string {
   314  	return formatAddress(c.RemoteAddr(), c.LocalAddr())
   315  }
   316  
   317  func formatAddress(a, b net.Addr) string {
   318  	return a.String() + "->" + b.String()
   319  }
   320  
   321  // IsSocketClosed returns true if err is a result of reading from closed network
   322  // connection
   323  func IsSocketClosed(err error) bool {
   324  	return strings.Contains(err.Error(), "use of closed network connection")
   325  }
   326  
   327  type pipeOpts struct {
   328  	modifier     func(net.Conn, net.Conn, []byte) ([]byte, error)
   329  	onReadError  func(error)
   330  	onWriteError func(error)
   331  	beforeExit   func()
   332  }
   333  
   334  func (p *Proxy) pipe(src, dst net.Conn, opts pipeOpts) {
   335  	defer func() {
   336  		src.Close()
   337  		dst.Close()
   338  		if opts.beforeExit != nil {
   339  			opts.beforeExit()
   340  		}
   341  	}()
   342  
   343  	buf := make([]byte, 65535)
   344  
   345  	for {
   346  		var readDeadline time.Time
   347  		if p.ReadTimeout != 0 {
   348  			readDeadline = time.Now().Add(p.ReadTimeout)
   349  		}
   350  		src.SetReadDeadline(readDeadline)
   351  		n, err := src.Read(buf)
   352  		if err != nil {
   353  			if opts.onReadError != nil {
   354  				opts.onReadError(err)
   355  			}
   356  			return
   357  		}
   358  		b := buf[:n]
   359  
   360  		if opts.modifier != nil {
   361  			if b, err = opts.modifier(src, dst, b); err != nil {
   362  				log.WithError(err).Warning("Closing connection")
   363  				return
   364  			}
   365  		}
   366  
   367  		if len(b) == 0 {
   368  			continue
   369  		}
   370  
   371  		var writeDeadline time.Time
   372  		if p.WriteTimeout != 0 {
   373  			writeDeadline = time.Now().Add(p.WriteTimeout)
   374  		}
   375  		dst.SetWriteDeadline(writeDeadline)
   376  		_, err = dst.Write(b)
   377  		if err != nil {
   378  			if opts.onWriteError != nil {
   379  				opts.onWriteError(err)
   380  			}
   381  			return
   382  		}
   383  	}
   384  }