github.com/ipfans/trojan-go@v0.11.0/proxy/proxy.go (about)

     1  package proxy
     2  
     3  import (
     4  	"context"
     5  	"io"
     6  	"math/rand"
     7  	"net"
     8  	"os"
     9  	"strings"
    10  
    11  	"github.com/ipfans/trojan-go/common"
    12  	"github.com/ipfans/trojan-go/config"
    13  	"github.com/ipfans/trojan-go/log"
    14  	"github.com/ipfans/trojan-go/tunnel"
    15  )
    16  
    17  const Name = "PROXY"
    18  
    19  const (
    20  	MaxPacketSize = 1024 * 8
    21  )
    22  
    23  // Proxy relay connections and packets
    24  type Proxy struct {
    25  	sources []tunnel.Server
    26  	sink    tunnel.Client
    27  	ctx     context.Context
    28  	cancel  context.CancelFunc
    29  }
    30  
    31  func (p *Proxy) Run() error {
    32  	p.relayConnLoop()
    33  	p.relayPacketLoop()
    34  	<-p.ctx.Done()
    35  	return nil
    36  }
    37  
    38  func (p *Proxy) Close() error {
    39  	p.cancel()
    40  	p.sink.Close()
    41  	for _, source := range p.sources {
    42  		source.Close()
    43  	}
    44  	return nil
    45  }
    46  
    47  func (p *Proxy) relayConnLoop() {
    48  	for _, source := range p.sources {
    49  		go func(source tunnel.Server) {
    50  			for {
    51  				inbound, err := source.AcceptConn(nil)
    52  				if err != nil {
    53  					select {
    54  					case <-p.ctx.Done():
    55  						log.Debug("exiting")
    56  						return
    57  					default:
    58  					}
    59  					log.Error(common.NewError("failed to accept connection").Base(err))
    60  					continue
    61  				}
    62  				go func(inbound tunnel.Conn) {
    63  					defer inbound.Close()
    64  					outbound, err := p.sink.DialConn(inbound.Metadata().Address, nil)
    65  					if err != nil {
    66  						log.Error(common.NewError("proxy failed to dial connection").Base(err))
    67  						return
    68  					}
    69  					defer outbound.Close()
    70  					errChan := make(chan error, 2)
    71  					copyConn := func(a, b net.Conn) {
    72  						_, err := io.Copy(a, b)
    73  						errChan <- err
    74  					}
    75  					go copyConn(inbound, outbound)
    76  					go copyConn(outbound, inbound)
    77  					select {
    78  					case err = <-errChan:
    79  						if err != nil {
    80  							log.Error(err)
    81  						}
    82  					case <-p.ctx.Done():
    83  						log.Debug("shutting down conn relay")
    84  						return
    85  					}
    86  					log.Debug("conn relay ends")
    87  				}(inbound)
    88  			}
    89  		}(source)
    90  	}
    91  }
    92  
    93  func (p *Proxy) relayPacketLoop() {
    94  	for _, source := range p.sources {
    95  		go func(source tunnel.Server) {
    96  			for {
    97  				inbound, err := source.AcceptPacket(nil)
    98  				if err != nil {
    99  					select {
   100  					case <-p.ctx.Done():
   101  						log.Debug("exiting")
   102  						return
   103  					default:
   104  					}
   105  					log.Error(common.NewError("failed to accept packet").Base(err))
   106  					continue
   107  				}
   108  				go func(inbound tunnel.PacketConn) {
   109  					defer inbound.Close()
   110  					outbound, err := p.sink.DialPacket(nil)
   111  					if err != nil {
   112  						log.Error(common.NewError("proxy failed to dial packet").Base(err))
   113  						return
   114  					}
   115  					defer outbound.Close()
   116  					errChan := make(chan error, 2)
   117  					copyPacket := func(a, b tunnel.PacketConn) {
   118  						for {
   119  							buf := make([]byte, MaxPacketSize)
   120  							n, metadata, err := a.ReadWithMetadata(buf)
   121  							if err != nil {
   122  								errChan <- err
   123  								return
   124  							}
   125  							if n == 0 {
   126  								errChan <- nil
   127  								return
   128  							}
   129  							_, err = b.WriteWithMetadata(buf[:n], metadata)
   130  							if err != nil {
   131  								errChan <- err
   132  								return
   133  							}
   134  						}
   135  					}
   136  					go copyPacket(inbound, outbound)
   137  					go copyPacket(outbound, inbound)
   138  					select {
   139  					case err = <-errChan:
   140  						if err != nil {
   141  							log.Error(err)
   142  						}
   143  					case <-p.ctx.Done():
   144  						log.Debug("shutting down packet relay")
   145  					}
   146  					log.Debug("packet relay ends")
   147  				}(inbound)
   148  			}
   149  		}(source)
   150  	}
   151  }
   152  
   153  func NewProxy(ctx context.Context, cancel context.CancelFunc, sources []tunnel.Server, sink tunnel.Client) *Proxy {
   154  	return &Proxy{
   155  		sources: sources,
   156  		sink:    sink,
   157  		ctx:     ctx,
   158  		cancel:  cancel,
   159  	}
   160  }
   161  
   162  type Creator func(ctx context.Context) (*Proxy, error)
   163  
   164  var creators = make(map[string]Creator)
   165  
   166  func RegisterProxyCreator(name string, creator Creator) {
   167  	creators[name] = creator
   168  }
   169  
   170  func NewProxyFromConfigData(data []byte, isJSON bool) (*Proxy, error) {
   171  	// create a unique context for each proxy instance to avoid duplicated authenticator
   172  	ctx := context.WithValue(context.Background(), Name+"_ID", rand.Int())
   173  	var err error
   174  	if isJSON {
   175  		ctx, err = config.WithJSONConfig(ctx, data)
   176  		if err != nil {
   177  			return nil, err
   178  		}
   179  	} else {
   180  		ctx, err = config.WithYAMLConfig(ctx, data)
   181  		if err != nil {
   182  			return nil, err
   183  		}
   184  	}
   185  	cfg := config.FromContext(ctx, Name).(*Config)
   186  	create, ok := creators[strings.ToUpper(cfg.RunType)]
   187  	if !ok {
   188  		return nil, common.NewError("unknown proxy type: " + cfg.RunType)
   189  	}
   190  	log.SetLogLevel(log.LogLevel(cfg.LogLevel))
   191  	if cfg.LogFile != "" {
   192  		file, err := os.OpenFile(cfg.LogFile, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0o644)
   193  		if err != nil {
   194  			return nil, common.NewError("failed to open log file").Base(err)
   195  		}
   196  		log.SetOutput(file)
   197  	}
   198  	return create(ctx)
   199  }