github.com/fluhus/gostuff@v0.4.1-0.20240331134726-be71864f2b5d/ppln/serial.go (about)

     1  package ppln
     2  
     3  import (
     4  	"fmt"
     5  	"sync"
     6  
     7  	"github.com/fluhus/gostuff/heaps"
     8  )
     9  
    10  const (
    11  	// Size of pipeline channel buffers per goroutine.
    12  	chanLenCoef = 1000
    13  )
    14  
    15  // Serial starts a multi-goroutine transformation pipeline that maintains the
    16  // order of the inputs.
    17  //
    18  // Pusher should call push on every input value. Stop indicates if an error was
    19  // returned and pushing should stop.
    20  // Mapper receives an input (a), 0-based input serial number (i), 0-based
    21  // goroutine number (g), and returns the result of processing a.
    22  // Puller acts on a single result, and will be called by the same
    23  // order of pusher's input.
    24  //
    25  // If one of the functions returns a non-nil error, the process stops and the
    26  // error is returned. Otherwise returns nil.
    27  func Serial[T1 any, T2 any](
    28  	ngoroutines int,
    29  	pusher func(push func(T1), stop func() bool) error,
    30  	mapper func(a T1, i int, g int) (T2, error),
    31  	puller func(a T2) error) error {
    32  	if ngoroutines < 0 {
    33  		panic(fmt.Sprintf("bad number of goroutines: %d", ngoroutines))
    34  	}
    35  
    36  	var err error
    37  
    38  	// An optimization for a single thread.
    39  	if ngoroutines == 0 {
    40  		i := 0
    41  		pusher(func(a T1) {
    42  			if err != nil {
    43  				return
    44  			}
    45  			t2, merr := mapper(a, i, 0)
    46  			if merr != nil {
    47  				err = merr
    48  				return
    49  			}
    50  			perr := puller(t2)
    51  			if perr != nil {
    52  				err = perr
    53  				return
    54  			}
    55  			i++
    56  		}, func() bool { return err != nil })
    57  		return err
    58  	}
    59  
    60  	push := make(chan serialItem[T1], ngoroutines*chanLenCoef)
    61  	pull := make(chan serialItem[T2], ngoroutines*chanLenCoef)
    62  	wait := &sync.WaitGroup{}
    63  	wait.Add(ngoroutines)
    64  
    65  	go func() {
    66  		i := 0
    67  		perr := pusher(func(a T1) {
    68  			if err != nil {
    69  				return
    70  			}
    71  			push <- serialItem[T1]{i, a}
    72  			i++
    73  		}, func() bool { return err != nil })
    74  		if perr != nil && err == nil {
    75  			err = perr
    76  		}
    77  		close(push)
    78  	}()
    79  	for i := 0; i < ngoroutines; i++ {
    80  		i := i
    81  		go func() {
    82  			for item := range push {
    83  				if err != nil {
    84  					continue // Drain channel.
    85  				}
    86  				t2, merr := mapper(item.data, item.i, i)
    87  				if merr != nil {
    88  					if err == nil {
    89  						err = merr
    90  					}
    91  					continue
    92  				}
    93  				pull <- serialItem[T2]{item.i, t2}
    94  			}
    95  			wait.Done()
    96  		}()
    97  	}
    98  	go func() {
    99  		items := &serialHeap[T2]{
   100  			data: heaps.New(func(a, b serialItem[T2]) bool {
   101  				return a.i < b.i
   102  			}),
   103  		}
   104  		for item := range pull {
   105  			if err != nil {
   106  				continue // Drain channel.
   107  			}
   108  			items.put(item)
   109  			for items.ok() {
   110  				perr := puller(items.pop())
   111  				if perr != nil && err == nil {
   112  					err = perr
   113  				}
   114  			}
   115  		}
   116  		wait.Done()
   117  	}()
   118  
   119  	wait.Wait() // Wait for workers.
   120  	wait.Add(1)
   121  	close(pull)
   122  	wait.Wait() // Wait for pull.
   123  
   124  	return err
   125  }
   126  
   127  // General data with a serial number.
   128  type serialItem[T any] struct {
   129  	i    int
   130  	data T
   131  }
   132  
   133  // A heap of serial items. Sorts by serial number.
   134  type serialHeap[T any] struct {
   135  	next int
   136  	data *heaps.Heap[serialItem[T]]
   137  }
   138  
   139  // Checks whether the minimal element in the heap is the next in the series.
   140  func (s *serialHeap[T]) ok() bool {
   141  	return s.data.Len() > 0 && s.data.Head().i == s.next
   142  }
   143  
   144  // Removes and returns the minimal element in the heap. Panics if the element
   145  // is not the next in the series.
   146  func (s *serialHeap[T]) pop() T {
   147  	if !s.ok() {
   148  		panic("get when not ok")
   149  	}
   150  	s.next++
   151  	a := s.data.Pop()
   152  	return a.data
   153  }
   154  
   155  // Adds an item to the heap.
   156  func (s *serialHeap[T]) put(item serialItem[T]) {
   157  	if item.i < s.next {
   158  		panic(fmt.Sprintf("put(%d) when next is %d", item.i, s.next))
   159  	}
   160  	s.data.Push(item)
   161  }