github.com/haraldrudell/parl@v0.4.176/pio/context-copier.go (about)

     1  /*
     2  © 2023–present Harald Rudell <harald.rudell@gmail.com> (https://haraldrudell.github.io/haraldrudell/)
     3  ISC License
     4  */
     5  
     6  package pio
     7  
     8  import (
     9  	"context"
    10  	"errors"
    11  	"io"
    12  	"sync/atomic"
    13  
    14  	"github.com/haraldrudell/parl"
    15  	"github.com/haraldrudell/parl/perrors"
    16  )
    17  
    18  const (
    19  	// buffer size if no buffer provided is 1 MiB
    20  	copyContextBufferSize = 1024 * 1024 // 1 MiB
    21  )
    22  
    23  // errInvalidWrite means that a write returned an impossible count
    24  //   - cause is buggy [io.Writer] implementation
    25  var ErrInvalidWrite = errors.New("invalid write result")
    26  
    27  var ErrCopyShutdown = errors.New("Copy received Shutdown")
    28  
    29  var _ = io.Copy
    30  
    31  // ContextCopier is an io.Copy cancelable via context
    32  type ContextCopier struct {
    33  	buf []byte
    34  
    35  	isShutdown atomic.Bool
    36  	cancelFunc atomic.Pointer[context.CancelFunc]
    37  
    38  	// Copy fields
    39  
    40  	readCloser  *ContextReader
    41  	writeCloser *ContextWriter
    42  	// g is error channel receiving result from the copying thread
    43  	g parl.GoResult
    44  }
    45  
    46  // NewContextCopier copies src to dst aborting if context is canceled
    47  //   - buf is buffer that can be used
    48  //   - if reader implements WriteTo or writer implements ReadFrom,
    49  //     no buffer is required
    50  //   - if a buffer is reqiired ans missing, 1 MiB is allocated
    51  //   - Copy methods does copying
    52  //   - Shutdown method or context cancel aborts Copy in progress
    53  //   - if the runtime type of reader or writer is [io.Closable],
    54  //     a thread is active during copying
    55  func NewContextCopier(buf ...[]byte) (copier *ContextCopier) {
    56  	var c = ContextCopier{}
    57  	if len(buf) > 0 {
    58  		c.buf = buf[0]
    59  	}
    60  	return &c
    61  }
    62  
    63  // Copy copies from src to dst until end of data, error, Shutdown or context cancel
    64  //   - Shutdown method or context cancel aborts copy in progress
    65  //   - on context cancel, error returned is [context.Canceled]
    66  //   - on Shutdown, error returned has [ErrCopyShutdown]
    67  //   - if the runtime type of dst or src is [io.Closable],
    68  //     a thread is active during copying
    69  //   - such reader or writer will be closed
    70  func (c *ContextCopier) Copy(
    71  	dst io.Writer,
    72  	src io.Reader,
    73  	ctx context.Context,
    74  ) (n int64, err error) {
    75  	if c.readCloser != nil {
    76  		panic(perrors.NewPF("second invocation"))
    77  	} else if dst == nil {
    78  		panic(parl.NilError("dst"))
    79  	} else if src == nil {
    80  		panic(parl.NilError("src"))
    81  	} else if ctx == nil {
    82  		panic(parl.NilError("ctx"))
    83  	}
    84  	// check for shutdown prior to Copy
    85  	if c.isShutdown.Load() {
    86  		err = perrors.ErrorfPF("%w", ErrCopyShutdown)
    87  		return
    88  	}
    89  	defer c.copyEnd(&err) // ensures context to be canceled
    90  
    91  	// store context reader writer
    92  	var cancelFunc context.CancelFunc
    93  	ctx, cancelFunc = context.WithCancel(ctx)
    94  	c.cancelFunc.Store(&cancelFunc)
    95  	c.readCloser = NewContextReader(src, ctx)
    96  	c.writeCloser = NewContextWriter(dst, ctx)
    97  
    98  	// if either the reader or the writer can be closed,
    99  	// a separate thread is used
   100  	//	- the thread closes in parallel on context cancel forcing an
   101  	//		immediate abort to copying
   102  	if c.readCloser.IsCloseable() || c.writeCloser.IsCloseable() {
   103  		c.g = parl.NewGoResult()
   104  		go c.contextCopierCloserThread(ctx.Done())
   105  	}
   106  
   107  	// If the reader has a WriteTo method, use it to do the copy.
   108  	//   - buffer-less one-go copy
   109  	//   - on end of file, err is nil
   110  	//   - err may be read or write errors
   111  	//   - on context cancel, error is [context.Canceled]
   112  	if writerTo, ok := src.(io.WriterTo); ok {
   113  		return writerTo.WriteTo(c.writeCloser) // reader’s WriteTo gets the writer
   114  	}
   115  
   116  	// Similarly, if the writer has a ReadFrom method, use it to do the copy
   117  	//   - buffer-less one-go copy
   118  	//   - on end of file, err is nil
   119  	//   - err may be read or write errors
   120  	//   - on context cancel, error is [context.Canceled]
   121  	if readerFrom, ok := dst.(io.ReaderFrom); ok {
   122  		return readerFrom.ReadFrom(c.readCloser) // writer’s ReadFrom gets the reader
   123  	}
   124  
   125  	// copy using an intermediate buffer
   126  	//   - on end of file, err is nil
   127  	//   - err may be read or write errors
   128  	//   - on context cancel, error is [context.Canceled]
   129  
   130  	// ensure buffer
   131  	var buf = c.buf
   132  	if buf == nil {
   133  		buf = make([]byte, copyContextBufferSize)
   134  	}
   135  
   136  	for {
   137  
   138  		// read bytes
   139  		var nRead, errReading = c.readCloser.Read(buf)
   140  
   141  		// write any read bytes
   142  		if nRead > 0 {
   143  			var nWritten, errWriting = c.writeCloser.Write(buf[:nRead])
   144  			if nWritten < 0 || nRead < nWritten {
   145  				nWritten = 0
   146  				if errWriting == nil {
   147  					errWriting = ErrInvalidWrite
   148  				}
   149  			}
   150  
   151  			// handle write outcome
   152  			n += int64(nWritten)
   153  			if errWriting != nil {
   154  				err = errWriting
   155  				return // write error return
   156  			}
   157  			if nRead != nWritten {
   158  				err = io.ErrShortWrite
   159  				return // short write error return
   160  			}
   161  		}
   162  
   163  		// handle read outcome
   164  		if errReading == io.EOF {
   165  			return // end of data return
   166  		} else if errReading != nil {
   167  			err = errReading
   168  			return // read error return
   169  		}
   170  	}
   171  }
   172  
   173  // Shutdown order the thread to exit and
   174  // wait for its result
   175  //   - every Copy invocation will have a Shutdown
   176  //     either by consumer or the deferred copyEnd method
   177  func (c *ContextCopier) Shutdown() {
   178  	if c.isShutdown.Load() {
   179  		return // already shutdown
   180  	} else if !c.isShutdown.CompareAndSwap(false, true) {
   181  		return // another thread shut down
   182  	}
   183  
   184  	// cancel the child context
   185  	//	- any copy in progress is aborted
   186  	//	- if a thread is running, this orders it to exit
   187  	if cfp := c.cancelFunc.Load(); cfp != nil {
   188  		c.cancelFunc.Store(nil)
   189  		(*cfp)() // invoke cancelFunc
   190  	}
   191  }
   192  
   193  // ContextCopierCloseThread is used when either the
   194  // reader or the writer is [io.Closable]
   195  //   - on context cancel, the thread closing reader or writer will
   196  //     immediately cancel copying
   197  func (c *ContextCopier) contextCopierCloserThread(done <-chan struct{}) {
   198  	var err error
   199  	defer c.g.SendError(&err)
   200  	defer parl.PanicToErr(&err)
   201  
   202  	// wait for thread exit order
   203  	//	- app cancel or ordered to exit
   204  	<-done
   205  
   206  	// close reader and writer
   207  	c.close(&err)
   208  }
   209  
   210  // copyEnd:
   211  //   - cancels the context,
   212  //   - if thread, awaits the thread to close reader or writer and collects the result
   213  //   - otherwise closes reader and writer
   214  func (c *ContextCopier) copyEnd(errp *error) {
   215  
   216  	// ensure Shutdown has been invoked
   217  	if c.isShutdown.Load() {
   218  		*errp = perrors.AppendError(*errp, perrors.ErrorfPF("%w", ErrCopyShutdown))
   219  	} else {
   220  		// cancel the context
   221  		// order any thread to exit
   222  		c.Shutdown()
   223  	}
   224  
   225  	// await thread doing close, or do close
   226  	if g := c.g; g.IsValid() {
   227  		// wait for result from thread
   228  		g.ReceiveError(errp)
   229  	} else {
   230  		c.close(errp)
   231  	}
   232  }
   233  
   234  // close closes both reader and writer if their runtime type
   235  // implements [io.Closer]
   236  func (c *ContextCopier) close(errp *error) {
   237  	parl.Close(c.readCloser, errp)
   238  	parl.Close(c.writeCloser, errp)
   239  }