github.com/ipfans/trojan-go@v0.11.0/tunnel/trojan/server.go (about)

     1  package trojan
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"io"
     7  	"net"
     8  	"sync/atomic"
     9  
    10  	"github.com/ipfans/trojan-go/api"
    11  	"github.com/ipfans/trojan-go/common"
    12  	"github.com/ipfans/trojan-go/config"
    13  	"github.com/ipfans/trojan-go/log"
    14  	"github.com/ipfans/trojan-go/redirector"
    15  	"github.com/ipfans/trojan-go/statistic"
    16  	"github.com/ipfans/trojan-go/statistic/memory"
    17  	"github.com/ipfans/trojan-go/statistic/mysql"
    18  	"github.com/ipfans/trojan-go/tunnel"
    19  	"github.com/ipfans/trojan-go/tunnel/mux"
    20  )
    21  
    22  // InboundConn is a trojan inbound connection
    23  type InboundConn struct {
    24  	// WARNING: do not change the order of these fields.
    25  	// 64-bit fields that use `sync/atomic` package functions
    26  	// must be 64-bit aligned on 32-bit systems.
    27  	// Reference: https://github.com/golang/go/issues/599
    28  	// Solution: https://github.com/golang/go/issues/11891#issuecomment-433623786
    29  	sent uint64
    30  	recv uint64
    31  
    32  	net.Conn
    33  	auth     statistic.Authenticator
    34  	user     statistic.User
    35  	hash     string
    36  	metadata *tunnel.Metadata
    37  	ip       string
    38  }
    39  
    40  func (c *InboundConn) Metadata() *tunnel.Metadata {
    41  	return c.metadata
    42  }
    43  
    44  func (c *InboundConn) Write(p []byte) (int, error) {
    45  	n, err := c.Conn.Write(p)
    46  	atomic.AddUint64(&c.sent, uint64(n))
    47  	c.user.AddTraffic(n, 0)
    48  	return n, err
    49  }
    50  
    51  func (c *InboundConn) Read(p []byte) (int, error) {
    52  	n, err := c.Conn.Read(p)
    53  	atomic.AddUint64(&c.recv, uint64(n))
    54  	c.user.AddTraffic(0, n)
    55  	return n, err
    56  }
    57  
    58  func (c *InboundConn) Close() error {
    59  	log.Info("user", c.hash, "from", c.Conn.RemoteAddr(), "tunneling to", c.metadata.Address, "closed",
    60  		"sent:", common.HumanFriendlyTraffic(atomic.LoadUint64(&c.sent)), "recv:", common.HumanFriendlyTraffic(atomic.LoadUint64(&c.recv)))
    61  	c.user.DelIP(c.ip)
    62  	return c.Conn.Close()
    63  }
    64  
    65  func (c *InboundConn) Auth() error {
    66  	userHash := [56]byte{}
    67  	n, err := c.Conn.Read(userHash[:])
    68  	if err != nil || n != 56 {
    69  		return common.NewError("failed to read hash").Base(err)
    70  	}
    71  
    72  	valid, user := c.auth.AuthUser(string(userHash[:]))
    73  	if !valid {
    74  		return common.NewError("invalid hash:" + string(userHash[:]))
    75  	}
    76  	c.hash = string(userHash[:])
    77  	c.user = user
    78  
    79  	ip, _, err := net.SplitHostPort(c.Conn.RemoteAddr().String())
    80  	if err != nil {
    81  		return common.NewError("failed to parse host:" + c.Conn.RemoteAddr().String()).Base(err)
    82  	}
    83  
    84  	c.ip = ip
    85  	ok := user.AddIP(ip)
    86  	if !ok {
    87  		return common.NewError("ip limit reached")
    88  	}
    89  
    90  	crlf := [2]byte{}
    91  	_, err = io.ReadFull(c.Conn, crlf[:])
    92  	if err != nil {
    93  		return err
    94  	}
    95  
    96  	c.metadata = &tunnel.Metadata{}
    97  	if err := c.metadata.ReadFrom(c.Conn); err != nil {
    98  		return err
    99  	}
   100  
   101  	_, err = io.ReadFull(c.Conn, crlf[:])
   102  	if err != nil {
   103  		return err
   104  	}
   105  	return nil
   106  }
   107  
   108  // Server is a trojan tunnel server
   109  type Server struct {
   110  	auth       statistic.Authenticator
   111  	redir      *redirector.Redirector
   112  	redirAddr  *tunnel.Address
   113  	underlay   tunnel.Server
   114  	connChan   chan tunnel.Conn
   115  	muxChan    chan tunnel.Conn
   116  	packetChan chan tunnel.PacketConn
   117  	ctx        context.Context
   118  	cancel     context.CancelFunc
   119  }
   120  
   121  func (s *Server) Close() error {
   122  	s.cancel()
   123  	return s.underlay.Close()
   124  }
   125  
   126  func (s *Server) acceptLoop() {
   127  	for {
   128  		conn, err := s.underlay.AcceptConn(&Tunnel{})
   129  		if err != nil { // Closing
   130  			log.Error(common.NewError("trojan failed to accept conn").Base(err))
   131  			select {
   132  			case <-s.ctx.Done():
   133  				return
   134  			default:
   135  			}
   136  			continue
   137  		}
   138  		go func(conn tunnel.Conn) {
   139  			rewindConn := common.NewRewindConn(conn)
   140  			rewindConn.SetBufferSize(128)
   141  			defer rewindConn.StopBuffering()
   142  
   143  			inboundConn := &InboundConn{
   144  				Conn: rewindConn,
   145  				auth: s.auth,
   146  			}
   147  
   148  			if err := inboundConn.Auth(); err != nil {
   149  				rewindConn.Rewind()
   150  				rewindConn.StopBuffering()
   151  				log.Warn(common.NewError("connection with invalid trojan header from " + rewindConn.RemoteAddr().String()).Base(err))
   152  				s.redir.Redirect(&redirector.Redirection{
   153  					RedirectTo:  s.redirAddr,
   154  					InboundConn: rewindConn,
   155  				})
   156  				return
   157  			}
   158  
   159  			rewindConn.StopBuffering()
   160  			switch inboundConn.metadata.Command {
   161  			case Connect:
   162  				if inboundConn.metadata.DomainName == "MUX_CONN" {
   163  					s.muxChan <- inboundConn
   164  					log.Debug("mux(r) connection")
   165  				} else {
   166  					s.connChan <- inboundConn
   167  					log.Debug("normal trojan connection")
   168  				}
   169  
   170  			case Associate:
   171  				s.packetChan <- &PacketConn{
   172  					Conn: inboundConn,
   173  				}
   174  				log.Debug("trojan udp connection")
   175  			case Mux:
   176  				s.muxChan <- inboundConn
   177  				log.Debug("mux connection")
   178  			default:
   179  				log.Error(common.NewError(fmt.Sprintf("unknown trojan command %d", inboundConn.metadata.Command)))
   180  			}
   181  		}(conn)
   182  	}
   183  }
   184  
   185  func (s *Server) AcceptConn(nextTunnel tunnel.Tunnel) (tunnel.Conn, error) {
   186  	switch nextTunnel.(type) {
   187  	case *mux.Tunnel:
   188  		select {
   189  		case t := <-s.muxChan:
   190  			return t, nil
   191  		case <-s.ctx.Done():
   192  			return nil, common.NewError("trojan client closed")
   193  		}
   194  	default:
   195  		select {
   196  		case t := <-s.connChan:
   197  			return t, nil
   198  		case <-s.ctx.Done():
   199  			return nil, common.NewError("trojan client closed")
   200  		}
   201  	}
   202  }
   203  
   204  func (s *Server) AcceptPacket(tunnel.Tunnel) (tunnel.PacketConn, error) {
   205  	select {
   206  	case t := <-s.packetChan:
   207  		return t, nil
   208  	case <-s.ctx.Done():
   209  		return nil, common.NewError("trojan client closed")
   210  	}
   211  }
   212  
   213  func NewServer(ctx context.Context, underlay tunnel.Server) (*Server, error) {
   214  	cfg := config.FromContext(ctx, Name).(*Config)
   215  	ctx, cancel := context.WithCancel(ctx)
   216  
   217  	// TODO replace this dirty code
   218  	var auth statistic.Authenticator
   219  	var err error
   220  	if cfg.MySQL.Enabled {
   221  		log.Debug("mysql enabled")
   222  		auth, err = statistic.NewAuthenticator(ctx, mysql.Name)
   223  	} else {
   224  		log.Debug("auth by config file")
   225  		auth, err = statistic.NewAuthenticator(ctx, memory.Name)
   226  	}
   227  	if err != nil {
   228  		cancel()
   229  		return nil, common.NewError("trojan failed to create authenticator")
   230  	}
   231  
   232  	if cfg.API.Enabled {
   233  		go api.RunService(ctx, Name+"_SERVER", auth)
   234  	}
   235  
   236  	redirAddr := tunnel.NewAddressFromHostPort("tcp", cfg.RemoteHost, cfg.RemotePort)
   237  	s := &Server{
   238  		underlay:   underlay,
   239  		auth:       auth,
   240  		redirAddr:  redirAddr,
   241  		connChan:   make(chan tunnel.Conn, 32),
   242  		muxChan:    make(chan tunnel.Conn, 32),
   243  		packetChan: make(chan tunnel.PacketConn, 32),
   244  		ctx:        ctx,
   245  		cancel:     cancel,
   246  		redir:      redirector.NewRedirector(ctx),
   247  	}
   248  
   249  	if !cfg.DisableHTTPCheck {
   250  		redirConn, err := net.Dial("tcp", redirAddr.String())
   251  		if err != nil {
   252  			cancel()
   253  			return nil, common.NewError("invalid redirect address. check your http server: " + redirAddr.String()).Base(err)
   254  		}
   255  		redirConn.Close()
   256  	}
   257  
   258  	go s.acceptLoop()
   259  	log.Debug("trojan server created")
   260  	return s, nil
   261  }