github.com/ipfans/trojan-go@v0.11.0/api/service/server.go (about)

     1  package service
     2  
     3  import (
     4  	"context"
     5  	"crypto/tls"
     6  	"crypto/x509"
     7  	"io"
     8  	"io/ioutil"
     9  	"net"
    10  
    11  	"google.golang.org/grpc"
    12  	"google.golang.org/grpc/credentials"
    13  
    14  	"github.com/ipfans/trojan-go/api"
    15  	"github.com/ipfans/trojan-go/common"
    16  	"github.com/ipfans/trojan-go/config"
    17  	"github.com/ipfans/trojan-go/log"
    18  	"github.com/ipfans/trojan-go/statistic"
    19  	"github.com/ipfans/trojan-go/tunnel/trojan"
    20  )
    21  
// ServerAPI implements the trojan server's TrojanServerService gRPC API on top
// of a statistic.Authenticator, which holds the users the RPCs read and modify.
//
// The TrojanServerServiceServer interface is embedded so ServerAPI satisfies it
// even for methods not defined in this file; RunServerAPI constructs ServerAPI
// with only auth set, leaving that embedded value nil.
// NOTE(review): calling a service method that ServerAPI does not override would
// go through the nil embedded interface and panic — confirm all RPCs of the
// service are implemented on ServerAPI (or embed a non-nil default instead).
type ServerAPI struct {
	TrojanServerServiceServer
	auth statistic.Authenticator // user store queried/mutated by the RPC handlers
}
    26  
    27  func (s *ServerAPI) GetUsers(stream TrojanServerService_GetUsersServer) error {
    28  	log.Debug("API: GetUsers")
    29  	for {
    30  		req, err := stream.Recv()
    31  		if err == io.EOF {
    32  			return nil
    33  		}
    34  		if err != nil {
    35  			return err
    36  		}
    37  		if req.User == nil {
    38  			return common.NewError("user is unspecified")
    39  		}
    40  		if req.User.Hash == "" {
    41  			req.User.Hash = common.SHA224String(req.User.Password)
    42  		}
    43  		valid, user := s.auth.AuthUser(req.User.Hash)
    44  		if !valid {
    45  			stream.Send(&GetUsersResponse{
    46  				Success: false,
    47  				Info:    "invalid user: " + req.User.Hash,
    48  			})
    49  			continue
    50  		}
    51  		downloadTraffic, uploadTraffic := user.GetTraffic()
    52  		downloadSpeed, uploadSpeed := user.GetSpeed()
    53  		downloadSpeedLimit, uploadSpeedLimit := user.GetSpeedLimit()
    54  		ipLimit := user.GetIPLimit()
    55  		ipCurrent := user.GetIP()
    56  		err = stream.Send(&GetUsersResponse{
    57  			Success: true,
    58  			Status: &UserStatus{
    59  				User: req.User,
    60  				TrafficTotal: &Traffic{
    61  					UploadTraffic:   uploadTraffic,
    62  					DownloadTraffic: downloadTraffic,
    63  				},
    64  				SpeedCurrent: &Speed{
    65  					DownloadSpeed: downloadSpeed,
    66  					UploadSpeed:   uploadSpeed,
    67  				},
    68  				SpeedLimit: &Speed{
    69  					DownloadSpeed: uint64(downloadSpeedLimit),
    70  					UploadSpeed:   uint64(uploadSpeedLimit),
    71  				},
    72  				IpCurrent: int32(ipCurrent),
    73  				IpLimit:   int32(ipLimit),
    74  			},
    75  		})
    76  		if err != nil {
    77  			return err
    78  		}
    79  	}
    80  }
    81  
    82  func (s *ServerAPI) SetUsers(stream TrojanServerService_SetUsersServer) error {
    83  	log.Debug("API: SetUsers")
    84  	for {
    85  		req, err := stream.Recv()
    86  		if err == io.EOF {
    87  			return nil
    88  		}
    89  		if err != nil {
    90  			return err
    91  		}
    92  		if req.Status == nil {
    93  			return common.NewError("status is unspecified")
    94  		}
    95  		if req.Status.User.Hash == "" {
    96  			req.Status.User.Hash = common.SHA224String(req.Status.User.Password)
    97  		}
    98  		switch req.Operation {
    99  		case SetUsersRequest_Add:
   100  			if err = s.auth.AddUser(req.Status.User.Hash); err != nil {
   101  				err = common.NewError("failed to add new user").Base(err)
   102  				break
   103  			}
   104  			if req.Status.SpeedLimit != nil {
   105  				valid, user := s.auth.AuthUser(req.Status.User.Hash)
   106  				if !valid {
   107  					err = common.NewError("failed to auth new user").Base(err)
   108  					continue
   109  				}
   110  				if req.Status.SpeedLimit != nil {
   111  					user.SetSpeedLimit(int(req.Status.SpeedLimit.DownloadSpeed), int(req.Status.SpeedLimit.UploadSpeed))
   112  				}
   113  				if req.Status.TrafficTotal != nil {
   114  					user.SetTraffic(req.Status.TrafficTotal.DownloadTraffic, req.Status.TrafficTotal.UploadTraffic)
   115  				}
   116  				user.SetIPLimit(int(req.Status.IpLimit))
   117  			}
   118  		case SetUsersRequest_Delete:
   119  			err = s.auth.DelUser(req.Status.User.Hash)
   120  		case SetUsersRequest_Modify:
   121  			valid, user := s.auth.AuthUser(req.Status.User.Hash)
   122  			if !valid {
   123  				err = common.NewError("invalid user " + req.Status.User.Hash)
   124  			} else {
   125  				if req.Status.SpeedLimit != nil {
   126  					user.SetSpeedLimit(int(req.Status.SpeedLimit.DownloadSpeed), int(req.Status.SpeedLimit.UploadSpeed))
   127  				}
   128  				if req.Status.TrafficTotal != nil {
   129  					user.SetTraffic(req.Status.TrafficTotal.DownloadTraffic, req.Status.TrafficTotal.UploadTraffic)
   130  				}
   131  				user.SetIPLimit(int(req.Status.IpLimit))
   132  			}
   133  		}
   134  		if err != nil {
   135  			stream.Send(&SetUsersResponse{
   136  				Success: false,
   137  				Info:    err.Error(),
   138  			})
   139  			continue
   140  		}
   141  		stream.Send(&SetUsersResponse{
   142  			Success: true,
   143  		})
   144  	}
   145  }
   146  
   147  func (s *ServerAPI) ListUsers(req *ListUsersRequest, stream TrojanServerService_ListUsersServer) error {
   148  	log.Debug("API: ListUsers")
   149  	users := s.auth.ListUsers()
   150  	for _, user := range users {
   151  		downloadTraffic, uploadTraffic := user.GetTraffic()
   152  		downloadSpeed, uploadSpeed := user.GetSpeed()
   153  		downloadSpeedLimit, uploadSpeedLimit := user.GetSpeedLimit()
   154  		ipLimit := user.GetIPLimit()
   155  		ipCurrent := user.GetIP()
   156  		err := stream.Send(&ListUsersResponse{
   157  			Status: &UserStatus{
   158  				User: &User{
   159  					Hash: user.Hash(),
   160  				},
   161  				TrafficTotal: &Traffic{
   162  					DownloadTraffic: downloadTraffic,
   163  					UploadTraffic:   uploadTraffic,
   164  				},
   165  				SpeedCurrent: &Speed{
   166  					DownloadSpeed: downloadSpeed,
   167  					UploadSpeed:   uploadSpeed,
   168  				},
   169  				SpeedLimit: &Speed{
   170  					DownloadSpeed: uint64(downloadSpeedLimit),
   171  					UploadSpeed:   uint64(uploadSpeedLimit),
   172  				},
   173  				IpLimit:   int32(ipLimit),
   174  				IpCurrent: int32(ipCurrent),
   175  			},
   176  		})
   177  		if err != nil {
   178  			return err
   179  		}
   180  	}
   181  	return nil
   182  }
   183  
   184  func newAPIServer(cfg *Config) (*grpc.Server, error) {
   185  	var server *grpc.Server
   186  	if cfg.API.SSL.Enabled {
   187  		log.Info("api tls enabled")
   188  		keyPair, err := tls.LoadX509KeyPair(cfg.API.SSL.CertPath, cfg.API.SSL.KeyPath)
   189  		if err != nil {
   190  			return nil, common.NewError("failed to load key pair").Base(err)
   191  		}
   192  		tlsConfig := &tls.Config{
   193  			Certificates: []tls.Certificate{keyPair},
   194  		}
   195  		if cfg.API.SSL.VerifyClient {
   196  			tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
   197  			tlsConfig.ClientCAs = x509.NewCertPool()
   198  			for _, path := range cfg.API.SSL.ClientCertPath {
   199  				log.Debug("loading client cert: " + path)
   200  				certBytes, err := ioutil.ReadFile(path)
   201  				if err != nil {
   202  					return nil, common.NewError("failed to load cert file").Base(err)
   203  				}
   204  				ok := tlsConfig.ClientCAs.AppendCertsFromPEM(certBytes)
   205  				if !ok {
   206  					return nil, common.NewError("invalid client cert")
   207  				}
   208  			}
   209  		}
   210  		creds := credentials.NewTLS(tlsConfig)
   211  		server = grpc.NewServer(grpc.Creds(creds))
   212  	} else {
   213  		server = grpc.NewServer()
   214  	}
   215  	return server, nil
   216  }
   217  
   218  func RunServerAPI(ctx context.Context, auth statistic.Authenticator) error {
   219  	cfg := config.FromContext(ctx, Name).(*Config)
   220  	if !cfg.API.Enabled {
   221  		return nil
   222  	}
   223  	service := &ServerAPI{
   224  		auth: auth,
   225  	}
   226  	server, err := newAPIServer(cfg)
   227  	if err != nil {
   228  		return err
   229  	}
   230  	defer server.Stop()
   231  	RegisterTrojanServerServiceServer(server, service)
   232  	addr, err := net.ResolveIPAddr("ip", cfg.API.APIHost)
   233  	if err != nil {
   234  		return common.NewError("api found invalid addr").Base(err)
   235  	}
   236  	listener, err := net.Listen("tcp", (&net.TCPAddr{
   237  		IP:   addr.IP,
   238  		Port: cfg.API.APIPort,
   239  		Zone: addr.Zone,
   240  	}).String())
   241  	if err != nil {
   242  		return common.NewError("server api failed to listen").Base(err)
   243  	}
   244  	defer listener.Close()
   245  	log.Info("server-side api service is listening on", listener.Addr().String())
   246  	errChan := make(chan error, 1)
   247  	go func() {
   248  		errChan <- server.Serve(listener)
   249  	}()
   250  	select {
   251  	case err := <-errChan:
   252  		return err
   253  	case <-ctx.Done():
   254  		log.Debug("closed")
   255  		return nil
   256  	}
   257  }
   258  
   259  func init() {
   260  	api.RegisterHandler(trojan.Name+"_SERVER", RunServerAPI)
   261  }