github.com/cloudwego/kitex@v0.9.0/pkg/remote/trans/netpollmux/client_handler.go (about) 1 /* 2 * Copyright 2021 CloudWeGo Authors 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package netpollmux 18 19 import ( 20 "context" 21 "errors" 22 "fmt" 23 "net" 24 "sync/atomic" 25 26 "github.com/cloudwego/netpoll" 27 28 "github.com/cloudwego/kitex/pkg/kerrors" 29 "github.com/cloudwego/kitex/pkg/klog" 30 "github.com/cloudwego/kitex/pkg/remote" 31 "github.com/cloudwego/kitex/pkg/remote/codec" 32 "github.com/cloudwego/kitex/pkg/remote/trans" 33 np "github.com/cloudwego/kitex/pkg/remote/trans/netpoll" 34 "github.com/cloudwego/kitex/pkg/rpcinfo" 35 "github.com/cloudwego/kitex/pkg/serviceinfo" 36 "github.com/cloudwego/kitex/pkg/stats" 37 ) 38 39 type cliTransHandlerFactory struct{} 40 41 // NewCliTransHandlerFactory creates a new netpollmux client transport handler factory. 42 func NewCliTransHandlerFactory() remote.ClientTransHandlerFactory { 43 return &cliTransHandlerFactory{} 44 } 45 46 // NewTransHandler implements the remote.ClientTransHandlerFactory interface. 47 func (f *cliTransHandlerFactory) NewTransHandler(opt *remote.ClientOption) (remote.ClientTransHandler, error) { 48 if _, ok := opt.ConnPool.(*MuxPool); !ok { 49 return nil, fmt.Errorf("ConnPool[%T] invalid, netpoll mux just support MuxPool", opt.ConnPool) 50 } 51 return newCliTransHandler(opt) 52 } 53 54 func newCliTransHandler(opt *remote.ClientOption) (*cliTransHandler, error) { 55 return &cliTransHandler{ 56 opt: opt, 57 codec: opt.Codec, 58 }, nil 59 } 60 61 var _ remote.ClientTransHandler = &cliTransHandler{} 62 63 type cliTransHandler struct { 64 opt *remote.ClientOption 65 codec remote.Codec 66 transPipe *remote.TransPipeline 67 } 68 69 // Write implements the remote.ClientTransHandler interface. 70 func (t *cliTransHandler) Write(ctx context.Context, conn net.Conn, sendMsg remote.Message) (nctx context.Context, err error) { 71 ri := sendMsg.RPCInfo() 72 rpcinfo.Record(ctx, ri, stats.WriteStart, nil) 73 buf := netpoll.NewLinkBuffer() 74 bufWriter := np.NewWriterByteBuffer(buf) 75 defer func() { 76 if err != nil { 77 buf.Close() 78 bufWriter.Release(err) 79 } 80 rpcinfo.Record(ctx, ri, stats.WriteFinish, nil) 81 }() 82 83 // Set header flag = 1 84 tags := sendMsg.Tags() 85 if tags == nil { 86 tags = make(map[string]interface{}) 87 } 88 tags[codec.HeaderFlagsKey] = codec.HeaderFlagSupportOutOfOrder 89 90 // encode 91 sendMsg.SetPayloadCodec(t.opt.PayloadCodec) 92 err = t.codec.Encode(ctx, sendMsg, bufWriter) 93 if err != nil { 94 return ctx, err 95 } 96 97 mc, _ := conn.(*muxCliConn) 98 99 // if oneway 100 var methodInfo serviceinfo.MethodInfo 101 if methodInfo, err = trans.GetMethodInfo(ri, sendMsg.ServiceInfo()); err != nil { 102 return ctx, err 103 } 104 if methodInfo.OneWay() { 105 mc.Put(func() (_ netpoll.Writer, isNil bool) { 106 return buf, false 107 }) 108 return ctx, nil 109 } 110 111 // add notify 112 seqID := ri.Invocation().SeqID() 113 callback := newAsyncCallback(buf, bufWriter) 114 mc.seqIDMap.store(seqID, callback) 115 mc.Put(callback.getter) 116 return ctx, err 117 } 118 119 // Read implements the remote.ClientTransHandler interface. 120 func (t *cliTransHandler) Read(ctx context.Context, conn net.Conn, msg remote.Message) (nctx context.Context, err error) { 121 ri := msg.RPCInfo() 122 mc, _ := conn.(*muxCliConn) 123 seqID := ri.Invocation().SeqID() 124 // load & delete before return 125 event, _ := mc.seqIDMap.load(seqID) 126 defer mc.seqIDMap.delete(seqID) 127 128 callback, _ := event.(*asyncCallback) 129 defer callback.Close() 130 131 readTimeout := trans.GetReadTimeout(ri.Config()) 132 if readTimeout > 0 { 133 var cancel context.CancelFunc 134 ctx, cancel = context.WithTimeout(ctx, readTimeout) 135 defer cancel() 136 } 137 select { 138 case <-ctx.Done(): 139 // timeout 140 return ctx, fmt.Errorf("recv wait timeout %s, seqID=%d", readTimeout, seqID) 141 case bufReader := <-callback.notifyChan: 142 // recv 143 if bufReader == nil { 144 return ctx, ErrConnClosed 145 } 146 rpcinfo.Record(ctx, ri, stats.ReadStart, nil) 147 msg.SetPayloadCodec(t.opt.PayloadCodec) 148 err := t.codec.Decode(ctx, msg, bufReader) 149 if err != nil && errors.Is(err, netpoll.ErrReadTimeout) { 150 err = kerrors.ErrRPCTimeout.WithCause(err) 151 } 152 if l := bufReader.ReadableLen(); l > 0 { 153 bufReader.Skip(l) 154 } 155 bufReader.Release(nil) 156 rpcinfo.Record(ctx, ri, stats.ReadFinish, err) 157 return ctx, err 158 } 159 } 160 161 // OnMessage implements the remote.ClientTransHandler interface. 162 func (t *cliTransHandler) OnMessage(ctx context.Context, args, result remote.Message) (context.Context, error) { 163 // do nothing 164 return ctx, nil 165 } 166 167 // OnInactive implements the remote.ClientTransHandler interface. 168 func (t *cliTransHandler) OnInactive(ctx context.Context, conn net.Conn) { 169 // ineffective now and do nothing 170 } 171 172 // OnError implements the remote.ClientTransHandler interface. 173 func (t *cliTransHandler) OnError(ctx context.Context, err error, conn net.Conn) { 174 if pe, ok := err.(*kerrors.DetailedError); ok { 175 klog.CtxErrorf(ctx, "KITEX: send request error, remote=%s, error=%s\nstack=%s", conn.RemoteAddr(), err.Error(), pe.Stack()) 176 } else { 177 klog.CtxErrorf(ctx, "KITEX: send request error, remote=%s, error=%s", conn.RemoteAddr(), err.Error()) 178 } 179 } 180 181 // SetPipeline implements the remote.ClientTransHandler interface. 182 func (t *cliTransHandler) SetPipeline(p *remote.TransPipeline) { 183 t.transPipe = p 184 } 185 186 type asyncCallback struct { 187 wbuf *netpoll.LinkBuffer 188 bufWriter remote.ByteBuffer 189 notifyChan chan remote.ByteBuffer // notify recv reader 190 closed int32 // 1 is closed, 2 means wbuf has been flush 191 } 192 193 func newAsyncCallback(wbuf *netpoll.LinkBuffer, bufWriter remote.ByteBuffer) *asyncCallback { 194 return &asyncCallback{ 195 wbuf: wbuf, 196 bufWriter: bufWriter, 197 notifyChan: make(chan remote.ByteBuffer, 1), 198 } 199 } 200 201 // Recv is called when receive a message. 202 func (c *asyncCallback) Recv(bufReader remote.ByteBuffer, err error) error { 203 c.notify(bufReader) 204 return nil 205 } 206 207 // Close is used to close the mux connection. 208 func (c *asyncCallback) Close() error { 209 if atomic.CompareAndSwapInt32(&c.closed, 0, 1) { 210 c.wbuf.Close() 211 c.bufWriter.Release(nil) 212 } 213 return nil 214 } 215 216 func (c *asyncCallback) notify(bufReader remote.ByteBuffer) { 217 select { 218 case c.notifyChan <- bufReader: 219 default: 220 if bufReader != nil { 221 bufReader.Release(nil) 222 } 223 } 224 } 225 226 func (c *asyncCallback) getter() (w netpoll.Writer, isNil bool) { 227 if atomic.CompareAndSwapInt32(&c.closed, 0, 2) { 228 return c.wbuf, false 229 } 230 return nil, true 231 }