github.com/volts-dev/volts@v0.0.0-20240120094013-5e9c65924106/client/rpc_client.go (about) 1 package client 2 3 import ( 4 "context" 5 "fmt" 6 "time" 7 8 "github.com/volts-dev/volts/codec" 9 "github.com/volts-dev/volts/internal/body" 10 "github.com/volts-dev/volts/internal/errors" 11 "github.com/volts-dev/volts/internal/metadata" 12 "github.com/volts-dev/volts/internal/net" 13 "github.com/volts-dev/volts/internal/pool" 14 "github.com/volts-dev/volts/registry" 15 "github.com/volts-dev/volts/selector" 16 "github.com/volts-dev/volts/transport" 17 ) 18 19 type ( 20 RpcClient struct { 21 config *Config 22 pool pool.Pool // connect pool 23 closing bool // user has called Close 24 shutdown bool // server has told us to stop 25 } 26 ) 27 28 func NewRpcClient(opts ...Option) *RpcClient { 29 cfg := newConfig( 30 transport.NewTCPTransport(), 31 opts..., 32 ) 33 34 // 默认编码 35 if cfg.SerializeType == "" { 36 cfg.Serialize = codec.JSON 37 } 38 39 p := pool.NewPool( 40 pool.Size(cfg.PoolSize), 41 pool.TTL(cfg.PoolTtl), 42 pool.Transport(cfg.Transport), 43 ) 44 45 return &RpcClient{ 46 config: cfg, 47 pool: p, 48 } 49 } 50 51 func (self *RpcClient) Init(opts ...Option) error { 52 self.config.Init(opts...) 53 return nil 54 } 55 56 func (self *RpcClient) Config() *Config { 57 return self.config 58 } 59 60 // 新建请求 61 func (self *RpcClient) NewRequest(service, method string, request interface{}, optinos ...RequestOption) (*rpcRequest, error) { 62 optinos = append(optinos, 63 WithCodec(self.config.Serialize), 64 ) 65 return newRpcRequest(service, method, request, optinos...) 66 } 67 68 func (self *RpcClient) call(ctx context.Context, node *registry.Node, req IRequest, opts CallOptions) (IResponse, error) { 69 // 验证解码器 70 msgCodece := codec.IdentifyCodec(self.config.Serialize) 71 if msgCodece == nil { // no codec specified 72 //call.Error = rpc.ErrUnsupportedCodec 73 //client.mutex.Unlock() 74 //call.done() 75 return nil, errors.UnsupportedCodec("volts.client", self.config.SerializeType) 76 } 77 78 // 获取空闲链接 79 dOpts := []transport.DialOption{ 80 transport.WithStream(), 81 } 82 83 if opts.DialTimeout >= 0 { 84 dOpts = append(dOpts, transport.WithTimeout(opts.DialTimeout, opts.RequestTimeout, 0)) 85 } 86 87 conn, err := self.pool.Get(node.Address, dOpts...) 88 if err != nil { 89 return nil, errors.InternalServerError("volts.client", "connection error: %v", err) 90 } 91 defer self.pool.Release(conn, nil) 92 93 // 获取消息载体 94 msg := transport.GetMessageFromPool() 95 msg.SetMessageType(transport.MT_REQUEST) 96 msg.SetSerializeType(self.config.Serialize) 97 98 // init header 99 for k, v := range req.Header() { 100 msg.Header[k] = v[0] 101 } 102 md, ok := metadata.FromContext(ctx) 103 if ok { 104 for k, v := range md { 105 msg.Header[k] = v 106 } 107 } 108 109 // set timeout in nanoseconds 110 msg.Header["Timeout"] = fmt.Sprintf("%d", opts.RequestTimeout) 111 // set the content type for the request 112 msg.Header["Content-Type"] = req.ContentType() 113 // set the accept header 114 msg.Header["Accept"] = req.ContentType() 115 116 msg.Path = req.Method() // TODO msg 添加server action 117 data := req.Body().Data.Bytes() 118 if len(data) > 1024 && self.config.CompressType == transport.Gzip { 119 data, err = transport.Zip(data) 120 if err != nil { 121 return nil, err 122 } 123 124 msg.SetCompressType(self.config.CompressType) 125 } 126 127 msg.Payload = data 128 //seq := atomic.AddUint64(&self.seq, 1) - 1 129 //codec := newRpcCodec(msg, c, cf, "") 130 131 // 开始发送消息 132 // wait for error response 133 ch := make(chan error, 1) 134 resp := &rpcResponse{} 135 go func(resp *rpcResponse) { 136 defer func() { 137 if r := recover(); r != nil { 138 ch <- errors.InternalServerError("volts.client", "panic recovered: %v", r) 139 } 140 }() 141 142 // send request 143 // 返回编译过的数据 144 err := conn.Send(msg) 145 if err != nil { 146 ch <- err 147 return 148 } 149 150 // recv request 151 msg = transport.GetMessageFromPool() 152 err = conn.Recv(msg) 153 if err != nil { 154 ch <- err 155 return 156 } 157 158 // 状态码处理 159 switch msg.MessageStatusType() { 160 case transport.StatusOK: 161 break 162 case transport.StatusError: 163 ch <- errors.New("StatusError", int32(transport.StatusError), string(msg.Payload)) 164 return 165 default: 166 ch <- errors.New("", int32(msg.MessageStatusType()), string(msg.Payload)) 167 return 168 } 169 170 bd := body.New(codec.IdentifyCodec(msg.SerializeType())) 171 bd.Data.Write(msg.Payload) 172 // 解码消息内容 173 resp.contentType = msg.SerializeType() 174 resp.body = bd // msg.Payload 175 176 // success 177 ch <- nil 178 }(resp) 179 180 err = nil 181 select { 182 case err := <-ch: 183 return resp, err 184 case <-ctx.Done(): 185 err = errors.Timeout("volts.client", fmt.Sprintf("%v", ctx.Err())) 186 break 187 } 188 189 // set the stream error 190 if err != nil { 191 //stream.Lock() 192 //stream.err = grr 193 //stream.Unlock() 194 return nil, err 195 } 196 197 return resp, nil 198 } 199 200 // 阻塞请求 201 func (self *RpcClient) Call(request IRequest, opts ...CallOption) (IResponse, error) { 202 // make a copy of call opts 203 callOpts := self.config.CallOptions 204 callOpts.SelectOptions = append(callOpts.SelectOptions, selector.WithFilter(selector.FilterTrasport(self.config.Transport))) 205 for _, opt := range opts { 206 opt(&callOpts) 207 } 208 209 next, err := self.next(request, callOpts) 210 if err != nil { 211 return nil, err 212 } 213 214 ctx := callOpts.Context 215 if ctx == nil { 216 ctx = context.Background() 217 } 218 // check if we already have a deadline 219 d, ok := ctx.Deadline() 220 if !ok { 221 // no deadline so we create a new one 222 var cancel context.CancelFunc 223 ctx, cancel = context.WithTimeout(ctx, callOpts.RequestTimeout) 224 defer cancel() 225 } else { 226 // got a deadline so no need to setup context 227 // but we need to set the timeout we pass along 228 opt := WithRequestTimeout(time.Until(d)) 229 opt(&callOpts) 230 } 231 232 // should we noop right here? 233 select { 234 case <-ctx.Done(): 235 return nil, errors.Timeout("volts.client", fmt.Sprintf("%v", ctx.Err())) 236 default: 237 } 238 239 // return errors.New("volts.client", "request timeout", 408) 240 call := func(i int, response *IResponse) error { 241 // select next node 242 // selector 可能因为过滤后得不到合适服务器 243 node, err := next() 244 if err != nil { 245 return err 246 } 247 248 // make the call 249 *response, err = self.call(ctx, node, request, callOpts) 250 //r.opts.Selector.Mark(service, node, err) 251 return err 252 } 253 var response IResponse 254 // get the retries 255 retries := callOpts.Retries 256 ch := make(chan error, retries+1) 257 var gerr error 258 for i := 0; i <= retries; i++ { 259 go func(i int, response *IResponse) { 260 ch <- call(i, response) 261 }(i, &response) 262 263 select { 264 case <-ctx.Done(): 265 return nil, errors.Timeout("volts.client", fmt.Sprintf("call timeout: %v", ctx.Err())) 266 case err := <-ch: 267 // if the call succeeded lets bail early 268 if err == nil { 269 return response, nil 270 } 271 272 retry, rerr := callOpts.Retry(ctx, request, i, err) 273 if rerr != nil { 274 return nil, rerr 275 } 276 277 if !retry { 278 return nil, err 279 } 280 281 gerr = err 282 } 283 } 284 285 return response, gerr 286 } 287 288 // next returns an iterator for the next nodes to call 289 func (r *RpcClient) next(request IRequest, opts CallOptions) (selector.Next, error) { 290 // try get the proxy 291 service, address, _ := net.Proxy(request.Service(), opts.Address) 292 293 // return remote address 294 if len(address) > 0 { 295 nodes := make([]*registry.Node, len(address)) 296 297 for i, addr := range address { 298 nodes[i] = ®istry.Node{ 299 Address: addr, 300 // Set the protocol 301 Metadata: map[string]string{ 302 "protocol": "mucp", 303 }, 304 } 305 } 306 307 // crude return method 308 return func() (*registry.Node, error) { 309 return nodes[time.Now().Unix()%int64(len(nodes))], nil 310 }, nil 311 } 312 // only get the things that are of http protocol 313 selectOptions := append(opts.SelectOptions, selector.WithFilter( 314 selector.FilterLabel("protocol", r.config.Transport.Protocol()), 315 )) 316 317 // get next nodes from the selector 318 next, err := r.config.Selector.Select(service, selectOptions...) 319 if err != nil { 320 if err == selector.ErrNotFound { 321 return nil, errors.InternalServerError("volts.client", "service %s: %s", service, err.Error()) 322 } 323 return nil, errors.InternalServerError("volts.client", "error selecting %s node: %s", service, err.Error()) 324 } 325 326 return next, nil 327 } 328 329 func (self *RpcClient) String() string { 330 return "RpcClient" 331 }