storj.io/uplink@v1.13.0/private/storage/streams/splitter/base_splitter.go (about)

     1  // Copyright (C) 2023 Storj Labs, Inc.
     2  // See LICENSE for copying information.
     3  
     4  package splitter
     5  
     6  import (
     7  	"context"
     8  	"io"
     9  	"sync"
    10  
    11  	"github.com/zeebo/errs"
    12  )
    13  
    14  // WriteFinisher is a Writer that can be signalled by the caller when it is
    15  // done being written do. Subsequent calls to write should return an error
    16  // after writing is "done".
    17  type WriteFinisher interface {
    18  	io.Writer
    19  	DoneWriting(error)
    20  }
    21  
    22  type baseSplitter struct {
    23  	split   int64
    24  	minimum int64
    25  
    26  	writeMu   sync.Mutex // ensures only a single Write call at once
    27  	nextMu    sync.Mutex // ensures only a single Next call at once
    28  	currentMu sync.Mutex // protects access to current
    29  
    30  	emitted    bool               // set true when the first split is emitted
    31  	term       chan struct{}      // closed when finish is called
    32  	err        error              // captures the error passed to finish
    33  	finishOnce sync.Once          // only want to finish once
    34  	temp       []byte             // holds temporary data up to minimum
    35  	written    int64              // how many bytes written into current
    36  	next       chan WriteFinisher // channel for the next split to write into
    37  	current    WriteFinisher      // current split being written to
    38  }
    39  
    40  func newBaseSplitter(split, minimum int64) *baseSplitter {
    41  	return &baseSplitter{
    42  		split:   split,
    43  		minimum: minimum,
    44  
    45  		term: make(chan struct{}),
    46  		temp: make([]byte, 0, minimum),
    47  		next: make(chan WriteFinisher),
    48  	}
    49  }
    50  
    51  func (bs *baseSplitter) Finish(err error) {
    52  	bs.finishOnce.Do(func() {
    53  		bs.err = err
    54  		close(bs.term)
    55  		bs.currentMu.Lock()
    56  		if bs.current != nil {
    57  			bs.current.DoneWriting(err)
    58  		}
    59  		bs.currentMu.Unlock()
    60  	})
    61  }
    62  
    63  func (bs *baseSplitter) Write(p []byte) (n int, err error) {
    64  	// only ever allow one Write call at a time
    65  	bs.writeMu.Lock()
    66  	defer bs.writeMu.Unlock()
    67  
    68  	select {
    69  	case <-bs.term:
    70  		if bs.err != nil {
    71  			return 0, bs.err
    72  		}
    73  		return 0, errs.New("already finished")
    74  	default:
    75  	}
    76  
    77  	for len(p) > 0 {
    78  		// if we have no remaining bytes to write, close and move on
    79  		rem := bs.split - bs.written
    80  		if rem == 0 && bs.current != nil {
    81  			bs.currentMu.Lock()
    82  			bs.current.DoneWriting(nil)
    83  			bs.current = nil
    84  			bs.currentMu.Unlock()
    85  			bs.written = 0
    86  		}
    87  
    88  		// if we have a current buffer, write up to the point of the next split
    89  		if bs.current != nil {
    90  			pp := p
    91  			if rem < int64(len(pp)) {
    92  				pp = p[:rem]
    93  			}
    94  
    95  			// drop the state mutex so that Finish calls can interrupt
    96  			nn, err := bs.current.Write(pp)
    97  
    98  			// update tracking of how many bytes have been written
    99  			n += nn
   100  			bs.written += int64(nn)
   101  			p = p[nn:]
   102  
   103  			if err != nil {
   104  				bs.Finish(err)
   105  				return n, err
   106  			}
   107  
   108  			continue
   109  		}
   110  
   111  		// if we can fully fit in temp, do so
   112  		if len(bs.temp)+len(p) <= cap(bs.temp) {
   113  			bs.temp = append(bs.temp, p...)
   114  
   115  			n += len(p)
   116  			p = p[len(p):]
   117  
   118  			continue
   119  		}
   120  
   121  		// fill up temp as much as possible and wait for a new buffer
   122  		nn := copy(bs.temp[len(bs.temp):cap(bs.temp)], p)
   123  		bs.temp = bs.temp[:cap(bs.temp)]
   124  
   125  		// update tracking of how many bytes have been written
   126  		n += nn
   127  		p = p[nn:]
   128  
   129  		select {
   130  		case wf := <-bs.next:
   131  			bs.currentMu.Lock()
   132  			bs.current = wf
   133  			bs.currentMu.Unlock()
   134  
   135  			n, err := wf.Write(bs.temp)
   136  
   137  			bs.temp = bs.temp[:0]
   138  			bs.written += int64(n)
   139  
   140  			if err != nil {
   141  				bs.Finish(err)
   142  				return n, err
   143  			}
   144  
   145  		case <-bs.term:
   146  			if bs.err != nil {
   147  				return n, bs.err
   148  			}
   149  			return n, errs.New("write interrupted by finish")
   150  		}
   151  	}
   152  
   153  	return n, nil
   154  }
   155  
   156  func (bs *baseSplitter) Next(ctx context.Context, wf WriteFinisher) (inline []byte, eof bool, err error) {
   157  	if err := ctx.Err(); err != nil {
   158  		return nil, false, err
   159  	}
   160  
   161  	bs.nextMu.Lock()
   162  	defer bs.nextMu.Unlock()
   163  
   164  	select {
   165  	case <-ctx.Done():
   166  		return nil, false, ctx.Err()
   167  
   168  	case bs.next <- wf:
   169  		bs.emitted = true
   170  		return nil, false, nil
   171  
   172  	case <-bs.term:
   173  		if bs.err != nil {
   174  			return nil, false, bs.err
   175  		}
   176  		if len(bs.temp) > 0 || !bs.emitted {
   177  			bs.emitted = true
   178  
   179  			temp := bs.temp
   180  			bs.temp = bs.temp[:0]
   181  
   182  			return temp, false, nil
   183  		}
   184  		return nil, true, nil
   185  	}
   186  }