github.com/cloudwego/kitex@v0.9.0/client/stream.go (about) 1 /* 2 * Copyright 2021 CloudWeGo Authors 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package client 18 19 import ( 20 "context" 21 "io" 22 "sync/atomic" 23 24 "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/metadata" 25 "github.com/cloudwego/kitex/pkg/serviceinfo" 26 27 "github.com/cloudwego/kitex/pkg/endpoint" 28 "github.com/cloudwego/kitex/pkg/remote" 29 "github.com/cloudwego/kitex/pkg/remote/remotecli" 30 "github.com/cloudwego/kitex/pkg/rpcinfo" 31 "github.com/cloudwego/kitex/pkg/streaming" 32 ) 33 34 // Streaming client streaming interface for code generate 35 type Streaming interface { 36 Stream(ctx context.Context, method string, request, response interface{}) error 37 } 38 39 // Stream implements the Streaming interface 40 func (kc *kClient) Stream(ctx context.Context, method string, request, response interface{}) error { 41 if !kc.inited { 42 panic("client not initialized") 43 } 44 if kc.closed { 45 panic("client is already closed") 46 } 47 if ctx == nil { 48 panic("ctx is nil") 49 } 50 var ri rpcinfo.RPCInfo 51 ctx, ri, _ = kc.initRPCInfo(ctx, method, 0, nil) 52 53 rpcinfo.AsMutableRPCConfig(ri.Config()).SetInteractionMode(rpcinfo.Streaming) 54 ctx = rpcinfo.NewCtxWithRPCInfo(ctx, ri) 55 56 ctx = kc.opt.TracerCtl.DoStart(ctx, ri) 57 return kc.sEps(ctx, request, response) 58 } 59 60 func (kc *kClient) invokeSendEndpoint() endpoint.SendEndpoint { 61 return func(stream streaming.Stream, req interface{}) (err error) { 62 return stream.SendMsg(req) 63 } 64 } 65 66 func (kc *kClient) invokeRecvEndpoint() endpoint.RecvEndpoint { 67 return func(stream streaming.Stream, resp interface{}) (err error) { 68 return stream.RecvMsg(resp) 69 } 70 } 71 72 func (kc *kClient) invokeStreamingEndpoint() (endpoint.Endpoint, error) { 73 handler, err := kc.opt.RemoteOpt.CliHandlerFactory.NewTransHandler(kc.opt.RemoteOpt) 74 if err != nil { 75 return nil, err 76 } 77 for _, h := range kc.opt.MetaHandlers { 78 if shdlr, ok := h.(remote.StreamingMetaHandler); ok { 79 kc.opt.RemoteOpt.StreamingMetaHandlers = append(kc.opt.RemoteOpt.StreamingMetaHandlers, shdlr) 80 } 81 } 82 83 recvEndpoint := kc.opt.Streaming.BuildRecvInvokeChain(kc.invokeRecvEndpoint()) 84 sendEndpoint := kc.opt.Streaming.BuildSendInvokeChain(kc.invokeSendEndpoint()) 85 86 return func(ctx context.Context, req, resp interface{}) (err error) { 87 // req and resp as &streaming.Stream 88 ri := rpcinfo.GetRPCInfo(ctx) 89 st, err := remotecli.NewStream(ctx, ri, handler, kc.opt.RemoteOpt) 90 if err != nil { 91 return 92 } 93 clientStream := newStream(st, kc, ri, kc.getStreamingMode(ri), sendEndpoint, recvEndpoint) 94 resp.(*streaming.Result).Stream = clientStream 95 return 96 }, nil 97 } 98 99 func (kc *kClient) getStreamingMode(ri rpcinfo.RPCInfo) serviceinfo.StreamingMode { 100 methodInfo := kc.svcInfo.MethodInfo(ri.Invocation().MethodName()) 101 if methodInfo == nil { 102 return serviceinfo.StreamingNone 103 } 104 return methodInfo.StreamingMode() 105 } 106 107 type stream struct { 108 stream streaming.Stream 109 kc *kClient 110 ri rpcinfo.RPCInfo 111 112 streamingMode serviceinfo.StreamingMode 113 sendEndpoint endpoint.SendEndpoint 114 recvEndpoint endpoint.RecvEndpoint 115 finished uint32 116 } 117 118 var _ streaming.WithDoFinish = (*stream)(nil) 119 120 func newStream(s streaming.Stream, kc *kClient, ri rpcinfo.RPCInfo, 121 mode serviceinfo.StreamingMode, sendEP endpoint.SendEndpoint, recvEP endpoint.RecvEndpoint, 122 ) *stream { 123 return &stream{ 124 stream: s, 125 kc: kc, 126 ri: ri, 127 streamingMode: mode, 128 sendEndpoint: sendEP, 129 recvEndpoint: recvEP, 130 } 131 } 132 133 func (s *stream) SetTrailer(metadata.MD) { 134 panic("this method should only be used in server side stream!") 135 } 136 137 func (s *stream) SetHeader(metadata.MD) error { 138 panic("this method should only be used in server side stream!") 139 } 140 141 func (s *stream) SendHeader(metadata.MD) error { 142 panic("this method should only be used in server side stream!") 143 } 144 145 // Header returns the header metadata sent by the server if any. 146 // If a non-nil error is returned, stream.DoFinish() will be called to record the EndOfStream 147 func (s *stream) Header() (md metadata.MD, err error) { 148 if md, err = s.stream.Header(); err != nil { 149 s.DoFinish(err) 150 } 151 return 152 } 153 154 func (s *stream) Trailer() metadata.MD { 155 return s.stream.Trailer() 156 } 157 158 func (s *stream) Context() context.Context { 159 return s.stream.Context() 160 } 161 162 // RecvMsg receives a message from the server. 163 // If an error is returned, stream.DoFinish() will be called to record the end of stream 164 func (s *stream) RecvMsg(m interface{}) (err error) { 165 err = s.recvEndpoint(s.stream, m) 166 if err == nil { 167 err = s.ri.Invocation().BizStatusErr() 168 } 169 if err != nil || s.streamingMode == serviceinfo.StreamingClient { 170 s.DoFinish(err) 171 } 172 return 173 } 174 175 // SendMsg sends a message to the server. 176 // If an error is returned, stream.DoFinish() will be called to record the end of stream 177 func (s *stream) SendMsg(m interface{}) (err error) { 178 if err = s.sendEndpoint(s.stream, m); err != nil { 179 s.DoFinish(err) 180 } 181 return 182 } 183 184 // Close will send a frame with EndStream=true to the server. 185 // It will always return a nil 186 func (s *stream) Close() error { 187 return s.stream.Close() 188 } 189 190 // DoFinish implements the streaming.WithDoFinish interface, and it records the end of stream 191 func (s *stream) DoFinish(err error) { 192 if atomic.SwapUint32(&s.finished, 1) == 1 { 193 // already called 194 return 195 } 196 if err == io.EOF { 197 err = nil 198 } 199 ctx := s.Context() 200 ri := rpcinfo.GetRPCInfo(ctx) 201 s.kc.opt.TracerCtl.DoFinish(ctx, ri, err) 202 }