github.com/ipfans/trojan-go@v0.11.0/tunnel/transport/server.go (about)

     1  package transport
     2  
import (
	"bufio"
	"context"
	"errors"
	"net"
	"net/http"
	"os"
	"os/exec"
	"strconv"
	"sync"
	"time"

	"github.com/ipfans/trojan-go/common"
	"github.com/ipfans/trojan-go/config"
	"github.com/ipfans/trojan-go/log"
	"github.com/ipfans/trojan-go/tunnel"
)
    19  
    20  // Server is a server of transport layer
    21  type Server struct {
    22  	tcpListener net.Listener
    23  	cmd         *exec.Cmd
    24  	connChan    chan tunnel.Conn
    25  	wsChan      chan tunnel.Conn
    26  	httpLock    sync.RWMutex
    27  	nextHTTP    bool
    28  	ctx         context.Context
    29  	cancel      context.CancelFunc
    30  }
    31  
    32  func (s *Server) Close() error {
    33  	s.cancel()
    34  	if s.cmd != nil && s.cmd.Process != nil {
    35  		s.cmd.Process.Kill()
    36  	}
    37  	return s.tcpListener.Close()
    38  }
    39  
    40  func (s *Server) acceptLoop() {
    41  	for {
    42  		tcpConn, err := s.tcpListener.Accept()
    43  		if err != nil {
    44  			select {
    45  			case <-s.ctx.Done():
    46  			default:
    47  				log.Error(common.NewError("transport accept error").Base(err))
    48  				time.Sleep(time.Millisecond * 100)
    49  			}
    50  			return
    51  		}
    52  
    53  		go func(tcpConn net.Conn) {
    54  			log.Info("tcp connection from", tcpConn.RemoteAddr())
    55  			s.httpLock.RLock()
    56  			if s.nextHTTP { // plaintext mode enabled
    57  				s.httpLock.RUnlock()
    58  				// we use real http header parser to mimic a real http server
    59  				rewindConn := common.NewRewindConn(tcpConn)
    60  				rewindConn.SetBufferSize(512)
    61  				defer rewindConn.StopBuffering()
    62  
    63  				r := bufio.NewReader(rewindConn)
    64  				httpReq, err := http.ReadRequest(r)
    65  				rewindConn.Rewind()
    66  				rewindConn.StopBuffering()
    67  				if err != nil {
    68  					// this is not a http request, pass it to trojan protocol layer for further inspection
    69  					s.connChan <- &Conn{
    70  						Conn: rewindConn,
    71  					}
    72  				} else {
    73  					// this is a http request, pass it to websocket protocol layer
    74  					log.Debug("plaintext http request: ", httpReq)
    75  					s.wsChan <- &Conn{
    76  						Conn: rewindConn,
    77  					}
    78  				}
    79  			} else {
    80  				s.httpLock.RUnlock()
    81  				s.connChan <- &Conn{
    82  					Conn: tcpConn,
    83  				}
    84  			}
    85  		}(tcpConn)
    86  	}
    87  }
    88  
    89  func (s *Server) AcceptConn(overlay tunnel.Tunnel) (tunnel.Conn, error) {
    90  	// TODO fix import cycle
    91  	if overlay != nil && (overlay.Name() == "WEBSOCKET" || overlay.Name() == "HTTP") {
    92  		s.httpLock.Lock()
    93  		s.nextHTTP = true
    94  		s.httpLock.Unlock()
    95  		select {
    96  		case conn := <-s.wsChan:
    97  			return conn, nil
    98  		case <-s.ctx.Done():
    99  			return nil, common.NewError("transport server closed")
   100  		}
   101  	}
   102  	select {
   103  	case conn := <-s.connChan:
   104  		return conn, nil
   105  	case <-s.ctx.Done():
   106  		return nil, common.NewError("transport server closed")
   107  	}
   108  }
   109  
   110  func (s *Server) AcceptPacket(tunnel.Tunnel) (tunnel.PacketConn, error) {
   111  	panic("not supported")
   112  }
   113  
   114  // NewServer creates a transport layer server
   115  func NewServer(ctx context.Context, _ tunnel.Server) (*Server, error) {
   116  	cfg := config.FromContext(ctx, Name).(*Config)
   117  	listenAddress := tunnel.NewAddressFromHostPort("tcp", cfg.LocalHost, cfg.LocalPort)
   118  
   119  	var cmd *exec.Cmd
   120  	if cfg.TransportPlugin.Enabled {
   121  		log.Warn("transport server will use plugin and work in plain text mode")
   122  		switch cfg.TransportPlugin.Type {
   123  		case "shadowsocks":
   124  			trojanHost := "127.0.0.1"
   125  			trojanPort := common.PickPort("tcp", trojanHost)
   126  			cfg.TransportPlugin.Env = append(
   127  				cfg.TransportPlugin.Env,
   128  				"SS_REMOTE_HOST="+cfg.LocalHost,
   129  				"SS_REMOTE_PORT="+strconv.FormatInt(int64(cfg.LocalPort), 10),
   130  				"SS_LOCAL_HOST="+trojanHost,
   131  				"SS_LOCAL_PORT="+strconv.FormatInt(int64(trojanPort), 10),
   132  				"SS_PLUGIN_OPTIONS="+cfg.TransportPlugin.Option,
   133  			)
   134  
   135  			cfg.LocalHost = trojanHost
   136  			cfg.LocalPort = trojanPort
   137  			listenAddress = tunnel.NewAddressFromHostPort("tcp", cfg.LocalHost, cfg.LocalPort)
   138  			log.Debug("new listen address", listenAddress)
   139  			log.Debug("plugin env", cfg.TransportPlugin.Env)
   140  
   141  			cmd = exec.Command(cfg.TransportPlugin.Command, cfg.TransportPlugin.Arg...)
   142  			cmd.Env = append(cmd.Env, cfg.TransportPlugin.Env...)
   143  			cmd.Stdout = os.Stdout
   144  			cmd.Stderr = os.Stdout
   145  			cmd.Start()
   146  		case "other":
   147  			cmd = exec.Command(cfg.TransportPlugin.Command, cfg.TransportPlugin.Arg...)
   148  			cmd.Env = append(cmd.Env, cfg.TransportPlugin.Env...)
   149  			cmd.Stdout = os.Stdout
   150  			cmd.Stderr = os.Stdout
   151  			cmd.Start()
   152  		case "plaintext":
   153  			// do nothing
   154  		default:
   155  			return nil, common.NewError("invalid plugin type: " + cfg.TransportPlugin.Type)
   156  		}
   157  	}
   158  	tcpListener, err := net.Listen("tcp", listenAddress.String())
   159  	if err != nil {
   160  		return nil, err
   161  	}
   162  
   163  	ctx, cancel := context.WithCancel(ctx)
   164  	server := &Server{
   165  		tcpListener: tcpListener,
   166  		cmd:         cmd,
   167  		ctx:         ctx,
   168  		cancel:      cancel,
   169  		connChan:    make(chan tunnel.Conn, 32),
   170  		wsChan:      make(chan tunnel.Conn, 32),
   171  	}
   172  	go server.acceptLoop()
   173  	return server, nil
   174  }