github.com/fluhus/gostuff@v0.4.1-0.20240331134726-be71864f2b5d/ppln/v2/serial.go (about)

     1  package ppln
     2  
     3  import (
     4  	"fmt"
     5  	"iter"
     6  	"sync"
     7  	"sync/atomic"
     8  
     9  	"github.com/fluhus/gostuff/heaps"
    10  )
    11  
    12  // Serial starts a multi-goroutine transformation pipeline that maintains the
    13  // order of the inputs.
    14  //
    15  // Input is an iterator over the input values to be transformed.
    16  // It will be called in a thread-safe manner.
    17  // Transform receives an input (a), 0-based input serial number (i), 0-based
    18  // goroutine number (g), and returns the result of processing a.
    19  // Output acts on a single result, and will be called by the same
    20  // order of the input, in a thread-safe manner.
    21  //
    22  // If one of the functions returns a non-nil error, the process stops and the
    23  // error is returned. Otherwise returns nil.
    24  func Serial[T1 any, T2 any](
    25  	ngoroutines int,
    26  	input iter.Seq2[T1, error],
    27  	transform func(a T1, i int, g int) (T2, error),
    28  	output func(a T2) error) error {
    29  	if ngoroutines < 1 {
    30  		panic(fmt.Sprintf("bad number of goroutines: %d", ngoroutines))
    31  	}
    32  	pull, pstop := iter.Pull2(input)
    33  	defer pstop()
    34  
    35  	// An optimization for a single thread.
    36  	if ngoroutines == 1 {
    37  		i := 0
    38  		for {
    39  			t1, err, ok := pull()
    40  			ii := i
    41  			i++
    42  
    43  			if !ok {
    44  				return nil
    45  			}
    46  			if err != nil {
    47  				return err
    48  			}
    49  
    50  			t2, err := transform(t1, ii, 0)
    51  			if err != nil {
    52  				return err
    53  			}
    54  			if err := output(t2); err != nil {
    55  				return err
    56  			}
    57  		}
    58  	}
    59  
    60  	ilock := &sync.Mutex{}
    61  	olock := &sync.Mutex{}
    62  	errs := make(chan error, ngoroutines)
    63  	stop := &atomic.Bool{}
    64  	items := &serialHeap[T2]{
    65  		data: heaps.New(func(a, b serialItem[T2]) bool {
    66  			return a.i < b.i
    67  		}),
    68  	}
    69  
    70  	i := 0
    71  	for g := 0; g < ngoroutines; g++ {
    72  		go func(g int) {
    73  			for {
    74  				if stop.Load() {
    75  					errs <- nil
    76  					return
    77  				}
    78  
    79  				ilock.Lock()
    80  				t1, err, ok := pull()
    81  				ii := i
    82  				i++
    83  				ilock.Unlock()
    84  
    85  				if !ok {
    86  					errs <- nil
    87  					return
    88  				}
    89  				if err != nil {
    90  					stop.Store(true)
    91  					errs <- err
    92  					return
    93  				}
    94  
    95  				t2, err := transform(t1, ii, g)
    96  				if err != nil {
    97  					stop.Store(true)
    98  					errs <- err
    99  					return
   100  				}
   101  
   102  				olock.Lock()
   103  				items.put(serialItem[T2]{ii, t2})
   104  				for items.ok() {
   105  					err = output(items.pop())
   106  					if err != nil {
   107  						olock.Unlock()
   108  						stop.Store(true)
   109  						errs <- err
   110  						return
   111  					}
   112  				}
   113  				olock.Unlock()
   114  			}
   115  		}(g)
   116  	}
   117  
   118  	for g := 0; g < ngoroutines; g++ {
   119  		if err := <-errs; err != nil {
   120  			return err
   121  		}
   122  	}
   123  	return nil
   124  }
   125  
   // serialItem attaches a 0-based serial number to a value so that values
// produced out of order can later be released in their original order.
type serialItem[T any] struct {
	i    int // Serial number of the input this value came from.
	data T   // The carried value.
}
   131  
   132  // A heap of serial items. Sorts by serial number.
   133  type serialHeap[T any] struct {
   134  	next int
   135  	data *heaps.Heap[serialItem[T]]
   136  }
   137  
   138  // Checks whether the minimal element in the heap is the next in the series.
   139  func (s *serialHeap[T]) ok() bool {
   140  	return s.data.Len() > 0 && s.data.Head().i == s.next
   141  }
   142  
   143  // Removes and returns the minimal element in the heap. Panics if the element
   144  // is not the next in the series.
   145  func (s *serialHeap[T]) pop() T {
   146  	if !s.ok() {
   147  		panic("get when not ok")
   148  	}
   149  	s.next++
   150  	a := s.data.Pop()
   151  	return a.data
   152  }
   153  
   154  // Adds an item to the heap.
   155  func (s *serialHeap[T]) put(item serialItem[T]) {
   156  	if item.i < s.next {
   157  		panic(fmt.Sprintf("put(%d) when next is %d", item.i, s.next))
   158  	}
   159  	s.data.Push(item)
   160  }