storj.io/uplink@v1.13.0/private/storage/streams/splitter/base_splitter.go (about) 1 // Copyright (C) 2023 Storj Labs, Inc. 2 // See LICENSE for copying information. 3 4 package splitter 5 6 import ( 7 "context" 8 "io" 9 "sync" 10 11 "github.com/zeebo/errs" 12 ) 13 14 // WriteFinisher is a Writer that can be signalled by the caller when it is 15 // done being written do. Subsequent calls to write should return an error 16 // after writing is "done". 17 type WriteFinisher interface { 18 io.Writer 19 DoneWriting(error) 20 } 21 22 type baseSplitter struct { 23 split int64 24 minimum int64 25 26 writeMu sync.Mutex // ensures only a single Write call at once 27 nextMu sync.Mutex // ensures only a single Next call at once 28 currentMu sync.Mutex // protects access to current 29 30 emitted bool // set true when the first split is emitted 31 term chan struct{} // closed when finish is called 32 err error // captures the error passed to finish 33 finishOnce sync.Once // only want to finish once 34 temp []byte // holds temporary data up to minimum 35 written int64 // how many bytes written into current 36 next chan WriteFinisher // channel for the next split to write into 37 current WriteFinisher // current split being written to 38 } 39 40 func newBaseSplitter(split, minimum int64) *baseSplitter { 41 return &baseSplitter{ 42 split: split, 43 minimum: minimum, 44 45 term: make(chan struct{}), 46 temp: make([]byte, 0, minimum), 47 next: make(chan WriteFinisher), 48 } 49 } 50 51 func (bs *baseSplitter) Finish(err error) { 52 bs.finishOnce.Do(func() { 53 bs.err = err 54 close(bs.term) 55 bs.currentMu.Lock() 56 if bs.current != nil { 57 bs.current.DoneWriting(err) 58 } 59 bs.currentMu.Unlock() 60 }) 61 } 62 63 func (bs *baseSplitter) Write(p []byte) (n int, err error) { 64 // only ever allow one Write call at a time 65 bs.writeMu.Lock() 66 defer bs.writeMu.Unlock() 67 68 select { 69 case <-bs.term: 70 if bs.err != nil { 71 return 0, bs.err 72 } 73 return 0, errs.New("already finished") 74 default: 75 } 76 77 for len(p) > 0 { 78 // if we have no remaining bytes to write, close and move on 79 rem := bs.split - bs.written 80 if rem == 0 && bs.current != nil { 81 bs.currentMu.Lock() 82 bs.current.DoneWriting(nil) 83 bs.current = nil 84 bs.currentMu.Unlock() 85 bs.written = 0 86 } 87 88 // if we have a current buffer, write up to the point of the next split 89 if bs.current != nil { 90 pp := p 91 if rem < int64(len(pp)) { 92 pp = p[:rem] 93 } 94 95 // drop the state mutex so that Finish calls can interrupt 96 nn, err := bs.current.Write(pp) 97 98 // update tracking of how many bytes have been written 99 n += nn 100 bs.written += int64(nn) 101 p = p[nn:] 102 103 if err != nil { 104 bs.Finish(err) 105 return n, err 106 } 107 108 continue 109 } 110 111 // if we can fully fit in temp, do so 112 if len(bs.temp)+len(p) <= cap(bs.temp) { 113 bs.temp = append(bs.temp, p...) 114 115 n += len(p) 116 p = p[len(p):] 117 118 continue 119 } 120 121 // fill up temp as much as possible and wait for a new buffer 122 nn := copy(bs.temp[len(bs.temp):cap(bs.temp)], p) 123 bs.temp = bs.temp[:cap(bs.temp)] 124 125 // update tracking of how many bytes have been written 126 n += nn 127 p = p[nn:] 128 129 select { 130 case wf := <-bs.next: 131 bs.currentMu.Lock() 132 bs.current = wf 133 bs.currentMu.Unlock() 134 135 n, err := wf.Write(bs.temp) 136 137 bs.temp = bs.temp[:0] 138 bs.written += int64(n) 139 140 if err != nil { 141 bs.Finish(err) 142 return n, err 143 } 144 145 case <-bs.term: 146 if bs.err != nil { 147 return n, bs.err 148 } 149 return n, errs.New("write interrupted by finish") 150 } 151 } 152 153 return n, nil 154 } 155 156 func (bs *baseSplitter) Next(ctx context.Context, wf WriteFinisher) (inline []byte, eof bool, err error) { 157 if err := ctx.Err(); err != nil { 158 return nil, false, err 159 } 160 161 bs.nextMu.Lock() 162 defer bs.nextMu.Unlock() 163 164 select { 165 case <-ctx.Done(): 166 return nil, false, ctx.Err() 167 168 case bs.next <- wf: 169 bs.emitted = true 170 return nil, false, nil 171 172 case <-bs.term: 173 if bs.err != nil { 174 return nil, false, bs.err 175 } 176 if len(bs.temp) > 0 || !bs.emitted { 177 bs.emitted = true 178 179 temp := bs.temp 180 bs.temp = bs.temp[:0] 181 182 return temp, false, nil 183 } 184 return nil, true, nil 185 } 186 }