github.com/ipfans/trojan-go@v0.11.0/redirector/redirector.go (about)

     1  package redirector
     2  
     3  import (
     4  	"context"
     5  	"io"
     6  	"net"
     7  	"reflect"
     8  
     9  	"github.com/ipfans/trojan-go/common"
    10  	"github.com/ipfans/trojan-go/log"
    11  )
    12  
    13  type Dial func(net.Addr) (net.Conn, error)
    14  
    15  func defaultDial(addr net.Addr) (net.Conn, error) {
    16  	return net.Dial("tcp", addr.String())
    17  }
    18  
    19  type Redirection struct {
    20  	Dial
    21  	RedirectTo  net.Addr
    22  	InboundConn net.Conn
    23  }
    24  
    25  type Redirector struct {
    26  	ctx             context.Context
    27  	redirectionChan chan *Redirection
    28  }
    29  
    30  func (r *Redirector) Redirect(redirection *Redirection) {
    31  	select {
    32  	case r.redirectionChan <- redirection:
    33  		log.Debug("redirect request")
    34  	case <-r.ctx.Done():
    35  		log.Debug("exiting")
    36  	}
    37  }
    38  
    39  func (r *Redirector) worker() {
    40  	for {
    41  		select {
    42  		case redirection := <-r.redirectionChan:
    43  			handle := func(redirection *Redirection) {
    44  				if redirection.InboundConn == nil || reflect.ValueOf(redirection.InboundConn).IsNil() {
    45  					log.Error("nil inbound conn")
    46  					return
    47  				}
    48  				defer redirection.InboundConn.Close()
    49  				if redirection.RedirectTo == nil || reflect.ValueOf(redirection.RedirectTo).IsNil() {
    50  					log.Error("nil redirection addr")
    51  					return
    52  				}
    53  				if redirection.Dial == nil {
    54  					redirection.Dial = defaultDial
    55  				}
    56  				log.Warn("redirecting connection from", redirection.InboundConn.RemoteAddr(), "to", redirection.RedirectTo.String())
    57  				outboundConn, err := redirection.Dial(redirection.RedirectTo)
    58  				if err != nil {
    59  					log.Error(common.NewError("failed to redirect to target address").Base(err))
    60  					return
    61  				}
    62  				defer outboundConn.Close()
    63  				errChan := make(chan error, 2)
    64  				copyConn := func(a, b net.Conn) {
    65  					_, err := io.Copy(a, b)
    66  					errChan <- err
    67  				}
    68  				go copyConn(outboundConn, redirection.InboundConn)
    69  				go copyConn(redirection.InboundConn, outboundConn)
    70  				select {
    71  				case err := <-errChan:
    72  					if err != nil {
    73  						log.Error(common.NewError("failed to redirect").Base(err))
    74  					}
    75  					log.Info("redirection done")
    76  				case <-r.ctx.Done():
    77  					log.Debug("exiting")
    78  					return
    79  				}
    80  			}
    81  			go handle(redirection)
    82  		case <-r.ctx.Done():
    83  			log.Debug("shutting down redirector")
    84  			return
    85  		}
    86  	}
    87  }
    88  
    89  func NewRedirector(ctx context.Context) *Redirector {
    90  	r := &Redirector{
    91  		ctx:             ctx,
    92  		redirectionChan: make(chan *Redirection, 64),
    93  	}
    94  	go r.worker()
    95  	return r
    96  }