github.com/ydb-platform/ydb-go-sdk/v3@v3.57.0/internal/topic/topicwriterinternal/encoders.go (about)

     1  package topicwriterinternal
     2  
     3  import (
     4  	"compress/gzip"
     5  	"fmt"
     6  	"io"
     7  	"sync"
     8  
     9  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/grpcwrapper/rawtopic/rawtopiccommon"
    10  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors"
    11  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/xsync"
    12  	"github.com/ydb-platform/ydb-go-sdk/v3/trace"
    13  )
    14  
    15  const (
    16  	codecMeasureIntervalBatches = 100
    17  	codecUnknown                = rawtopiccommon.CodecUNSPECIFIED
    18  )
    19  
    20  type EncoderMap struct {
    21  	m map[rawtopiccommon.Codec]PublicCreateEncoderFunc
    22  }
    23  
    24  func NewEncoderMap() *EncoderMap {
    25  	return &EncoderMap{
    26  		m: map[rawtopiccommon.Codec]PublicCreateEncoderFunc{
    27  			rawtopiccommon.CodecRaw: func(writer io.Writer) (io.WriteCloser, error) {
    28  				return nopWriteCloser{writer}, nil
    29  			},
    30  			rawtopiccommon.CodecGzip: func(writer io.Writer) (io.WriteCloser, error) {
    31  				return gzip.NewWriter(writer), nil
    32  			},
    33  		},
    34  	}
    35  }
    36  
    37  func (e *EncoderMap) AddEncoder(codec rawtopiccommon.Codec, creator PublicCreateEncoderFunc) {
    38  	e.m[codec] = creator
    39  }
    40  
    41  func (e *EncoderMap) CreateLazyEncodeWriter(codec rawtopiccommon.Codec, target io.Writer) (io.WriteCloser, error) {
    42  	if encoderCreator, ok := e.m[codec]; ok {
    43  		return encoderCreator(target)
    44  	}
    45  
    46  	return nil, xerrors.WithStackTrace(xerrors.Wrap(fmt.Errorf("ydb: unexpected codec '%v' for encode message", codec)))
    47  }
    48  
    49  func (e *EncoderMap) GetSupportedCodecs() rawtopiccommon.SupportedCodecs {
    50  	res := make(rawtopiccommon.SupportedCodecs, 0, len(e.m))
    51  	for codec := range e.m {
    52  		res = append(res, codec)
    53  	}
    54  
    55  	return res
    56  }
    57  
    58  func (e *EncoderMap) IsSupported(codec rawtopiccommon.Codec) bool {
    59  	_, ok := e.m[codec]
    60  
    61  	return ok
    62  }
    63  
    64  type PublicCreateEncoderFunc func(writer io.Writer) (io.WriteCloser, error)
    65  
    66  type nopWriteCloser struct {
    67  	io.Writer
    68  }
    69  
    70  func (nopWriteCloser) Close() error {
    71  	return nil
    72  }
    73  
    74  // EncoderSelector not thread safe
    75  type EncoderSelector struct {
    76  	m *EncoderMap
    77  
    78  	tracer              *trace.Topic
    79  	writerReconnectorID string
    80  	sessionID           string
    81  
    82  	allowedCodecs          rawtopiccommon.SupportedCodecs
    83  	lastSelectedCodec      rawtopiccommon.Codec
    84  	parallelCompressors    int
    85  	batchCounter           int
    86  	measureIntervalBatches int
    87  }
    88  
    89  func NewEncoderSelector(
    90  	m *EncoderMap,
    91  	allowedCodecs rawtopiccommon.SupportedCodecs,
    92  	parallelCompressors int,
    93  	tracer *trace.Topic,
    94  	writerReconnectorID, sessionID string,
    95  ) EncoderSelector {
    96  	if parallelCompressors <= 0 {
    97  		panic("ydb: need leas one allowed compressor")
    98  	}
    99  
   100  	res := EncoderSelector{
   101  		m:                      m,
   102  		parallelCompressors:    parallelCompressors,
   103  		measureIntervalBatches: codecMeasureIntervalBatches,
   104  		tracer:                 tracer,
   105  		writerReconnectorID:    writerReconnectorID,
   106  		sessionID:              sessionID,
   107  	}
   108  	res.ResetAllowedCodecs(allowedCodecs)
   109  
   110  	return res
   111  }
   112  
   113  func (s *EncoderSelector) CompressMessages(messages []messageWithDataContent) (rawtopiccommon.Codec, error) {
   114  	codec, err := s.selectCodec(messages)
   115  	if err == nil {
   116  		onCompressDone := trace.TopicOnWriterCompressMessages(
   117  			s.tracer,
   118  			s.writerReconnectorID,
   119  			s.sessionID,
   120  			codec.ToInt32(),
   121  			messages[0].SeqNo,
   122  			len(messages),
   123  			trace.TopicWriterCompressMessagesReasonCompressData,
   124  		)
   125  		err = cacheMessages(messages, codec, s.parallelCompressors)
   126  		onCompressDone(err)
   127  	}
   128  
   129  	return codec, err
   130  }
   131  
   132  func (s *EncoderSelector) ResetAllowedCodecs(allowedCodecs rawtopiccommon.SupportedCodecs) {
   133  	if s.allowedCodecs.IsEqualsTo(allowedCodecs) {
   134  		return
   135  	}
   136  
   137  	s.allowedCodecs = allowedCodecs.Clone()
   138  	s.lastSelectedCodec = codecUnknown
   139  	s.batchCounter = 0
   140  }
   141  
   142  func (s *EncoderSelector) selectCodec(messages []messageWithDataContent) (rawtopiccommon.Codec, error) {
   143  	if len(s.allowedCodecs) == 0 {
   144  		return codecUnknown, errNoAllowedCodecs
   145  	}
   146  	if len(s.allowedCodecs) == 1 {
   147  		return s.allowedCodecs[0], nil
   148  	}
   149  
   150  	defer func() {
   151  		s.batchCounter++
   152  	}()
   153  
   154  	if s.batchCounter < 0 {
   155  		s.batchCounter = 0
   156  	}
   157  
   158  	// Try every codec at start - for fast reader fail if unexpected include codec, incompatible with readers
   159  	if s.batchCounter < len(s.allowedCodecs) {
   160  		return s.allowedCodecs[s.batchCounter], nil
   161  	}
   162  
   163  	if s.lastSelectedCodec == codecUnknown || s.batchCounter%s.measureIntervalBatches == 0 {
   164  		if codec, err := s.measureCodecs(messages); err == nil {
   165  			s.lastSelectedCodec = codec
   166  		} else {
   167  			return codecUnknown, err
   168  		}
   169  	}
   170  
   171  	return s.lastSelectedCodec, nil
   172  }
   173  
   174  func (s *EncoderSelector) measureCodecs(messages []messageWithDataContent) (rawtopiccommon.Codec, error) {
   175  	if len(s.allowedCodecs) == 0 {
   176  		return codecUnknown, errNoAllowedCodecs
   177  	}
   178  
   179  	sizes := make([]int, len(s.allowedCodecs))
   180  
   181  	for codecIndex, codec := range s.allowedCodecs {
   182  		firstSeqNo := int64(-1)
   183  		if len(messages) > 0 {
   184  			firstSeqNo = messages[0].SeqNo
   185  		}
   186  		onCompressDone := trace.TopicOnWriterCompressMessages(
   187  			s.tracer,
   188  			s.writerReconnectorID,
   189  			s.sessionID,
   190  			codec.ToInt32(),
   191  			firstSeqNo,
   192  			len(messages),
   193  			trace.TopicWriterCompressMessagesReasonCodecsMeasure,
   194  		)
   195  		err := cacheMessages(messages, codec, s.parallelCompressors)
   196  		onCompressDone(err)
   197  		if err != nil {
   198  			return codecUnknown, err
   199  		}
   200  
   201  		size := 0
   202  		for messIndex := range messages {
   203  			content, err := messages[messIndex].GetEncodedBytes(codec)
   204  			if err != nil {
   205  				return codecUnknown, err
   206  			}
   207  			size += len(content)
   208  		}
   209  		sizes[codecIndex] = size
   210  	}
   211  
   212  	minSizeIndex := 0
   213  	for i := range sizes {
   214  		if sizes[i] < sizes[minSizeIndex] {
   215  			minSizeIndex = i
   216  		}
   217  	}
   218  
   219  	return s.allowedCodecs[minSizeIndex], nil
   220  }
   221  
   222  func cacheMessages(messages []messageWithDataContent, codec rawtopiccommon.Codec, workerCount int) error {
   223  	if len(messages) < workerCount {
   224  		workerCount = len(messages)
   225  	}
   226  
   227  	// no need goroutines and synchronization for zero or one worker
   228  	if workerCount < 2 {
   229  		for i := range messages {
   230  			if _, err := messages[i].GetEncodedBytes(codec); err != nil {
   231  				return err
   232  			}
   233  		}
   234  	}
   235  
   236  	tasks := make(chan *messageWithDataContent, len(messages))
   237  
   238  	for i := range messages {
   239  		tasks <- &messages[i]
   240  	}
   241  	close(tasks)
   242  
   243  	var resErrMutex xsync.Mutex
   244  	var resErr error
   245  
   246  	var wg sync.WaitGroup
   247  	worker := func() {
   248  		defer wg.Done()
   249  
   250  		for task := range tasks {
   251  			var localErr error
   252  			resErrMutex.WithLock(func() {
   253  				localErr = resErr
   254  			})
   255  
   256  			if localErr != nil {
   257  				return
   258  			}
   259  			localErr = task.CacheMessageData(codec)
   260  			if localErr != nil {
   261  				resErrMutex.WithLock(func() {
   262  					resErr = localErr
   263  				})
   264  
   265  				return
   266  			}
   267  		}
   268  	}
   269  
   270  	wg.Add(workerCount)
   271  	for i := 0; i < workerCount; i++ {
   272  		go worker()
   273  	}
   274  
   275  	wg.Wait()
   276  
   277  	return resErr
   278  }