github.com/ipfans/trojan-go@v0.11.0/api/service/server.go (about) 1 package service 2 3 import ( 4 "context" 5 "crypto/tls" 6 "crypto/x509" 7 "io" 8 "io/ioutil" 9 "net" 10 11 "google.golang.org/grpc" 12 "google.golang.org/grpc/credentials" 13 14 "github.com/ipfans/trojan-go/api" 15 "github.com/ipfans/trojan-go/common" 16 "github.com/ipfans/trojan-go/config" 17 "github.com/ipfans/trojan-go/log" 18 "github.com/ipfans/trojan-go/statistic" 19 "github.com/ipfans/trojan-go/tunnel/trojan" 20 ) 21 22 type ServerAPI struct { 23 TrojanServerServiceServer 24 auth statistic.Authenticator 25 } 26 27 func (s *ServerAPI) GetUsers(stream TrojanServerService_GetUsersServer) error { 28 log.Debug("API: GetUsers") 29 for { 30 req, err := stream.Recv() 31 if err == io.EOF { 32 return nil 33 } 34 if err != nil { 35 return err 36 } 37 if req.User == nil { 38 return common.NewError("user is unspecified") 39 } 40 if req.User.Hash == "" { 41 req.User.Hash = common.SHA224String(req.User.Password) 42 } 43 valid, user := s.auth.AuthUser(req.User.Hash) 44 if !valid { 45 stream.Send(&GetUsersResponse{ 46 Success: false, 47 Info: "invalid user: " + req.User.Hash, 48 }) 49 continue 50 } 51 downloadTraffic, uploadTraffic := user.GetTraffic() 52 downloadSpeed, uploadSpeed := user.GetSpeed() 53 downloadSpeedLimit, uploadSpeedLimit := user.GetSpeedLimit() 54 ipLimit := user.GetIPLimit() 55 ipCurrent := user.GetIP() 56 err = stream.Send(&GetUsersResponse{ 57 Success: true, 58 Status: &UserStatus{ 59 User: req.User, 60 TrafficTotal: &Traffic{ 61 UploadTraffic: uploadTraffic, 62 DownloadTraffic: downloadTraffic, 63 }, 64 SpeedCurrent: &Speed{ 65 DownloadSpeed: downloadSpeed, 66 UploadSpeed: uploadSpeed, 67 }, 68 SpeedLimit: &Speed{ 69 DownloadSpeed: uint64(downloadSpeedLimit), 70 UploadSpeed: uint64(uploadSpeedLimit), 71 }, 72 IpCurrent: int32(ipCurrent), 73 IpLimit: int32(ipLimit), 74 }, 75 }) 76 if err != nil { 77 return err 78 } 79 } 80 } 81 82 func (s *ServerAPI) SetUsers(stream TrojanServerService_SetUsersServer) error { 83 log.Debug("API: SetUsers") 84 for { 85 req, err := stream.Recv() 86 if err == io.EOF { 87 return nil 88 } 89 if err != nil { 90 return err 91 } 92 if req.Status == nil { 93 return common.NewError("status is unspecified") 94 } 95 if req.Status.User.Hash == "" { 96 req.Status.User.Hash = common.SHA224String(req.Status.User.Password) 97 } 98 switch req.Operation { 99 case SetUsersRequest_Add: 100 if err = s.auth.AddUser(req.Status.User.Hash); err != nil { 101 err = common.NewError("failed to add new user").Base(err) 102 break 103 } 104 if req.Status.SpeedLimit != nil { 105 valid, user := s.auth.AuthUser(req.Status.User.Hash) 106 if !valid { 107 err = common.NewError("failed to auth new user").Base(err) 108 continue 109 } 110 if req.Status.SpeedLimit != nil { 111 user.SetSpeedLimit(int(req.Status.SpeedLimit.DownloadSpeed), int(req.Status.SpeedLimit.UploadSpeed)) 112 } 113 if req.Status.TrafficTotal != nil { 114 user.SetTraffic(req.Status.TrafficTotal.DownloadTraffic, req.Status.TrafficTotal.UploadTraffic) 115 } 116 user.SetIPLimit(int(req.Status.IpLimit)) 117 } 118 case SetUsersRequest_Delete: 119 err = s.auth.DelUser(req.Status.User.Hash) 120 case SetUsersRequest_Modify: 121 valid, user := s.auth.AuthUser(req.Status.User.Hash) 122 if !valid { 123 err = common.NewError("invalid user " + req.Status.User.Hash) 124 } else { 125 if req.Status.SpeedLimit != nil { 126 user.SetSpeedLimit(int(req.Status.SpeedLimit.DownloadSpeed), int(req.Status.SpeedLimit.UploadSpeed)) 127 } 128 if req.Status.TrafficTotal != nil { 129 user.SetTraffic(req.Status.TrafficTotal.DownloadTraffic, req.Status.TrafficTotal.UploadTraffic) 130 } 131 user.SetIPLimit(int(req.Status.IpLimit)) 132 } 133 } 134 if err != nil { 135 stream.Send(&SetUsersResponse{ 136 Success: false, 137 Info: err.Error(), 138 }) 139 continue 140 } 141 stream.Send(&SetUsersResponse{ 142 Success: true, 143 }) 144 } 145 } 146 147 func (s *ServerAPI) ListUsers(req *ListUsersRequest, stream TrojanServerService_ListUsersServer) error { 148 log.Debug("API: ListUsers") 149 users := s.auth.ListUsers() 150 for _, user := range users { 151 downloadTraffic, uploadTraffic := user.GetTraffic() 152 downloadSpeed, uploadSpeed := user.GetSpeed() 153 downloadSpeedLimit, uploadSpeedLimit := user.GetSpeedLimit() 154 ipLimit := user.GetIPLimit() 155 ipCurrent := user.GetIP() 156 err := stream.Send(&ListUsersResponse{ 157 Status: &UserStatus{ 158 User: &User{ 159 Hash: user.Hash(), 160 }, 161 TrafficTotal: &Traffic{ 162 DownloadTraffic: downloadTraffic, 163 UploadTraffic: uploadTraffic, 164 }, 165 SpeedCurrent: &Speed{ 166 DownloadSpeed: downloadSpeed, 167 UploadSpeed: uploadSpeed, 168 }, 169 SpeedLimit: &Speed{ 170 DownloadSpeed: uint64(downloadSpeedLimit), 171 UploadSpeed: uint64(uploadSpeedLimit), 172 }, 173 IpLimit: int32(ipLimit), 174 IpCurrent: int32(ipCurrent), 175 }, 176 }) 177 if err != nil { 178 return err 179 } 180 } 181 return nil 182 } 183 184 func newAPIServer(cfg *Config) (*grpc.Server, error) { 185 var server *grpc.Server 186 if cfg.API.SSL.Enabled { 187 log.Info("api tls enabled") 188 keyPair, err := tls.LoadX509KeyPair(cfg.API.SSL.CertPath, cfg.API.SSL.KeyPath) 189 if err != nil { 190 return nil, common.NewError("failed to load key pair").Base(err) 191 } 192 tlsConfig := &tls.Config{ 193 Certificates: []tls.Certificate{keyPair}, 194 } 195 if cfg.API.SSL.VerifyClient { 196 tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert 197 tlsConfig.ClientCAs = x509.NewCertPool() 198 for _, path := range cfg.API.SSL.ClientCertPath { 199 log.Debug("loading client cert: " + path) 200 certBytes, err := ioutil.ReadFile(path) 201 if err != nil { 202 return nil, common.NewError("failed to load cert file").Base(err) 203 } 204 ok := tlsConfig.ClientCAs.AppendCertsFromPEM(certBytes) 205 if !ok { 206 return nil, common.NewError("invalid client cert") 207 } 208 } 209 } 210 creds := credentials.NewTLS(tlsConfig) 211 server = grpc.NewServer(grpc.Creds(creds)) 212 } else { 213 server = grpc.NewServer() 214 } 215 return server, nil 216 } 217 218 func RunServerAPI(ctx context.Context, auth statistic.Authenticator) error { 219 cfg := config.FromContext(ctx, Name).(*Config) 220 if !cfg.API.Enabled { 221 return nil 222 } 223 service := &ServerAPI{ 224 auth: auth, 225 } 226 server, err := newAPIServer(cfg) 227 if err != nil { 228 return err 229 } 230 defer server.Stop() 231 RegisterTrojanServerServiceServer(server, service) 232 addr, err := net.ResolveIPAddr("ip", cfg.API.APIHost) 233 if err != nil { 234 return common.NewError("api found invalid addr").Base(err) 235 } 236 listener, err := net.Listen("tcp", (&net.TCPAddr{ 237 IP: addr.IP, 238 Port: cfg.API.APIPort, 239 Zone: addr.Zone, 240 }).String()) 241 if err != nil { 242 return common.NewError("server api failed to listen").Base(err) 243 } 244 defer listener.Close() 245 log.Info("server-side api service is listening on", listener.Addr().String()) 246 errChan := make(chan error, 1) 247 go func() { 248 errChan <- server.Serve(listener) 249 }() 250 select { 251 case err := <-errChan: 252 return err 253 case <-ctx.Done(): 254 log.Debug("closed") 255 return nil 256 } 257 } 258 259 func init() { 260 api.RegisterHandler(trojan.Name+"_SERVER", RunServerAPI) 261 }