github.com/ydb-platform/ydb-go-sdk/v3@v3.57.0/internal/grpcwrapper/rawtopic/rawtopicwriter/streamwriter.go (about)

     1  package rawtopicwriter
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  	"reflect"
     7  	"sync"
     8  	"sync/atomic"
     9  
    10  	"github.com/ydb-platform/ydb-go-genproto/protos/Ydb_Topic"
    11  	"google.golang.org/grpc/codes"
    12  	grpcStatus "google.golang.org/grpc/status"
    13  
    14  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/grpcwrapper/rawtopic/rawtopiccommon"
    15  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/grpcwrapper/rawydb"
    16  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors"
    17  )
    18  
// errConcurencyReadDenied is returned by StreamWriter.Recv when it is called
// while another Recv call on the same StreamWriter is still in progress.
var errConcurencyReadDenied = xerrors.Wrap(errors.New("ydb: read from rawtopicwriter in parallel"))
    20  
// GrpcStream is the set of stream operations StreamWriter relies on: sending
// Ydb_Topic.StreamWriteMessage_FromClient messages, receiving
// Ydb_Topic.StreamWriteMessage_FromServer messages and closing the send side.
type GrpcStream interface {
	// Send writes one client message to the stream.
	Send(messageNew *Ydb_Topic.StreamWriteMessage_FromClient) error
	// Recv returns the next server message from the stream.
	Recv() (*Ydb_Topic.StreamWriteMessage_FromServer, error)
	// CloseSend closes the sending direction of the stream.
	CloseSend() error
}
    26  
// StreamWriter wraps a GrpcStream and translates between the raw message types
// of this package (ClientMessage / ServerMessage) and their protobuf forms.
//
// NOTE(review): no constructor is visible in this file; Stream presumably has
// to be assigned by the caller before Recv, Send or CloseSend are used.
type StreamWriter struct {
	// readCounter is the number of Recv calls currently in progress. It is
	// only accessed through sync/atomic and is used to reject parallel reads.
	readCounter int32

	// sendCloseMtx is held by Send and CloseSend while they use Stream.
	sendCloseMtx sync.Mutex
	// Stream is the underlying stream all calls are forwarded to.
	Stream       GrpcStream
}
    33  
    34  func (w *StreamWriter) Recv() (ServerMessage, error) {
    35  	readCnt := atomic.AddInt32(&w.readCounter, 1)
    36  	defer atomic.AddInt32(&w.readCounter, -1)
    37  
    38  	if readCnt != 1 {
    39  		return nil, xerrors.WithStackTrace(errConcurencyReadDenied)
    40  	}
    41  
    42  	grpcMsg, err := w.Stream.Recv()
    43  	if err != nil {
    44  		if !xerrors.IsErrorFromServer(err) {
    45  			err = xerrors.Transport(err)
    46  		}
    47  
    48  		return nil, xerrors.WithStackTrace(xerrors.Wrap(fmt.Errorf(
    49  			"ydb: failed to read grpc message from writer stream: %w",
    50  			err,
    51  		)))
    52  	}
    53  
    54  	var meta rawtopiccommon.ServerMessageMetadata
    55  	if err = meta.MetaFromStatusAndIssues(grpcMsg); err != nil {
    56  		return nil, err
    57  	}
    58  	if !meta.Status.IsSuccess() {
    59  		return nil, xerrors.WithStackTrace(fmt.Errorf("ydb: bad status from topic server: %v", meta.Status))
    60  	}
    61  
    62  	switch v := grpcMsg.GetServerMessage().(type) {
    63  	case *Ydb_Topic.StreamWriteMessage_FromServer_InitResponse:
    64  		var res InitResult
    65  		res.ServerMessageMetadata = meta
    66  		res.mustFromProto(v.InitResponse)
    67  
    68  		return &res, nil
    69  	case *Ydb_Topic.StreamWriteMessage_FromServer_WriteResponse:
    70  		var res WriteResult
    71  		res.ServerMessageMetadata = meta
    72  		err = res.fromProto(v.WriteResponse)
    73  		if err != nil {
    74  			return nil, err
    75  		}
    76  
    77  		return &res, nil
    78  	case *Ydb_Topic.StreamWriteMessage_FromServer_UpdateTokenResponse:
    79  		var res UpdateTokenResponse
    80  		res.MustFromProto(v.UpdateTokenResponse)
    81  
    82  		return &res, nil
    83  	default:
    84  		return nil, xerrors.WithStackTrace(xerrors.Wrap(fmt.Errorf(
    85  			"ydb: unexpected message type received from raw writer stream: '%v'",
    86  			reflect.TypeOf(grpcMsg),
    87  		)))
    88  	}
    89  }
    90  
    91  func (w *StreamWriter) Send(rawMsg ClientMessage) (err error) {
    92  	w.sendCloseMtx.Lock()
    93  	defer func() {
    94  		w.sendCloseMtx.Unlock()
    95  		err = xerrors.Transport(err)
    96  	}()
    97  
    98  	var protoMsg Ydb_Topic.StreamWriteMessage_FromClient
    99  	switch v := rawMsg.(type) {
   100  	case *InitRequest:
   101  		initReqProto, initErr := v.toProto()
   102  		if initErr != nil {
   103  			return initErr
   104  		}
   105  		protoMsg.ClientMessage = &Ydb_Topic.StreamWriteMessage_FromClient_InitRequest{
   106  			InitRequest: initReqProto,
   107  		}
   108  	case *WriteRequest:
   109  		writeReqProto, writeErr := v.toProto()
   110  		if writeErr != nil {
   111  			return writeErr
   112  		}
   113  
   114  		return sendWriteRequest(w.Stream.Send, writeReqProto)
   115  	case *UpdateTokenRequest:
   116  		protoMsg.ClientMessage = &Ydb_Topic.StreamWriteMessage_FromClient_UpdateTokenRequest{
   117  			UpdateTokenRequest: v.ToProto(),
   118  		}
   119  	default:
   120  		return xerrors.WithStackTrace(xerrors.Wrap(fmt.Errorf(
   121  			"ydb: unexpected message type for send to raw writer stream: '%v'",
   122  			reflect.TypeOf(rawMsg),
   123  		)))
   124  	}
   125  
   126  	err = w.Stream.Send(&protoMsg)
   127  	if err != nil {
   128  		return xerrors.WithStackTrace(xerrors.Wrap(fmt.Errorf("ydb: failed to send grpc message to writer stream: %w", err)))
   129  	}
   130  
   131  	return nil
   132  }
   133  
// sendFunc is the signature of a function that sends a single client message.
// sendWriteRequest takes one as its first argument; StreamWriter.Send passes
// the Send method of its stream.
type sendFunc func(req *Ydb_Topic.StreamWriteMessage_FromClient) error
   135  
   136  func (w *StreamWriter) CloseSend() error {
   137  	w.sendCloseMtx.Lock()
   138  	defer w.sendCloseMtx.Unlock()
   139  
   140  	return w.Stream.CloseSend()
   141  }
   142  
// ClientMessage is the type of messages accepted by StreamWriter.Send.
// Its only method is an unexported marker.
type ClientMessage interface {
	isClientMessage()
}
   146  
// clientMessageImpl is an empty helper type that supplies the isClientMessage
// marker method.
// NOTE(review): presumably meant to be embedded into the request types that
// are defined elsewhere in the package - confirm there.
type clientMessageImpl struct{}

// isClientMessage is a no-op marker method required by ClientMessage.
func (*clientMessageImpl) isClientMessage() {}
   150  
// ServerMessage is the type of messages returned by StreamWriter.Recv.
type ServerMessage interface {
	// isServerMessage is an unexported marker method.
	isServerMessage()
	// StatusData returns the status metadata attached to the message.
	StatusData() rawtopiccommon.ServerMessageMetadata
	// SetStatus sets the status of the message.
	// NOTE(review): implementations are not in this file - verify the exact
	// semantics there.
	SetStatus(status rawydb.StatusCode)
}
   156  
// serverMessageImpl is an empty helper type that supplies the isServerMessage
// marker method.
// NOTE(review): presumably meant to be embedded into the response types that
// are defined elsewhere in the package - confirm there.
type serverMessageImpl struct{}

// isServerMessage is a no-op marker method required by ServerMessage.
func (*serverMessageImpl) isServerMessage() {}
   160  
   161  func sendWriteRequest(send sendFunc, req *Ydb_Topic.StreamWriteMessage_FromClient_WriteRequest) error {
   162  	sendErr := send(&Ydb_Topic.StreamWriteMessage_FromClient{
   163  		ClientMessage: req,
   164  	})
   165  
   166  	if sendErr == nil {
   167  		return nil
   168  	}
   169  
   170  	grpcStatus, ok := grpcStatus.FromError(sendErr)
   171  	if !ok {
   172  		return sendErr
   173  	}
   174  
   175  	grpcMessages := req.WriteRequest.GetMessages()
   176  	if grpcStatus.Code() != codes.ResourceExhausted || len(grpcMessages) < 2 {
   177  		return sendErr
   178  	}
   179  
   180  	splitIndex := len(grpcMessages) / 2
   181  	firstMessages, lastMessages := grpcMessages[:splitIndex], grpcMessages[splitIndex:]
   182  	defer func() {
   183  		req.WriteRequest.Messages = grpcMessages
   184  	}()
   185  
   186  	req.WriteRequest.Messages = firstMessages
   187  	err := sendWriteRequest(send, req)
   188  	if err != nil {
   189  		return err
   190  	}
   191  
   192  	req.WriteRequest.Messages = lastMessages
   193  
   194  	return sendWriteRequest(send, req)
   195  }