// github.com/lingyao2333/mo-zero@v1.4.1/core/mr/mapreduce.go

     1  package mr
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"sync"
     7  	"sync/atomic"
     8  
     9  	"github.com/lingyao2333/mo-zero/core/errorx"
    10  	"github.com/lingyao2333/mo-zero/core/lang"
    11  )
    12  
    13  const (
    14  	defaultWorkers = 16
    15  	minWorkers     = 1
    16  )
    17  
    18  var (
    19  	// ErrCancelWithNil is an error that mapreduce was cancelled with nil.
    20  	ErrCancelWithNil = errors.New("mapreduce cancelled with nil")
    21  	// ErrReduceNoOutput is an error that reduce did not output a value.
    22  	ErrReduceNoOutput = errors.New("reduce not writing value")
    23  )
    24  
    25  type (
    26  	// ForEachFunc is used to do element processing, but no output.
    27  	ForEachFunc func(item interface{})
    28  	// GenerateFunc is used to let callers send elements into source.
    29  	GenerateFunc func(source chan<- interface{})
    30  	// MapFunc is used to do element processing and write the output to writer.
    31  	MapFunc func(item interface{}, writer Writer)
    32  	// MapperFunc is used to do element processing and write the output to writer,
    33  	// use cancel func to cancel the processing.
    34  	MapperFunc func(item interface{}, writer Writer, cancel func(error))
    35  	// ReducerFunc is used to reduce all the mapping output and write to writer,
    36  	// use cancel func to cancel the processing.
    37  	ReducerFunc func(pipe <-chan interface{}, writer Writer, cancel func(error))
    38  	// VoidReducerFunc is used to reduce all the mapping output, but no output.
    39  	// Use cancel func to cancel the processing.
    40  	VoidReducerFunc func(pipe <-chan interface{}, cancel func(error))
    41  	// Option defines the method to customize the mapreduce.
    42  	Option func(opts *mapReduceOptions)
    43  
    44  	mapperContext struct {
    45  		ctx       context.Context
    46  		mapper    MapFunc
    47  		source    <-chan interface{}
    48  		panicChan *onceChan
    49  		collector chan<- interface{}
    50  		doneChan  <-chan lang.PlaceholderType
    51  		workers   int
    52  	}
    53  
    54  	mapReduceOptions struct {
    55  		ctx     context.Context
    56  		workers int
    57  	}
    58  
    59  	// Writer interface wraps Write method.
    60  	Writer interface {
    61  		Write(v interface{})
    62  	}
    63  )
    64  
    65  // Finish runs fns parallelly, cancelled on any error.
    66  func Finish(fns ...func() error) error {
    67  	if len(fns) == 0 {
    68  		return nil
    69  	}
    70  
    71  	return MapReduceVoid(func(source chan<- interface{}) {
    72  		for _, fn := range fns {
    73  			source <- fn
    74  		}
    75  	}, func(item interface{}, writer Writer, cancel func(error)) {
    76  		fn := item.(func() error)
    77  		if err := fn(); err != nil {
    78  			cancel(err)
    79  		}
    80  	}, func(pipe <-chan interface{}, cancel func(error)) {
    81  	}, WithWorkers(len(fns)))
    82  }
    83  
    84  // FinishVoid runs fns parallelly.
    85  func FinishVoid(fns ...func()) {
    86  	if len(fns) == 0 {
    87  		return
    88  	}
    89  
    90  	ForEach(func(source chan<- interface{}) {
    91  		for _, fn := range fns {
    92  			source <- fn
    93  		}
    94  	}, func(item interface{}) {
    95  		fn := item.(func())
    96  		fn()
    97  	}, WithWorkers(len(fns)))
    98  }
    99  
   100  // ForEach maps all elements from given generate but no output.
   101  func ForEach(generate GenerateFunc, mapper ForEachFunc, opts ...Option) {
   102  	options := buildOptions(opts...)
   103  	panicChan := &onceChan{channel: make(chan interface{})}
   104  	source := buildSource(generate, panicChan)
   105  	collector := make(chan interface{})
   106  	done := make(chan lang.PlaceholderType)
   107  
   108  	go executeMappers(mapperContext{
   109  		ctx: options.ctx,
   110  		mapper: func(item interface{}, _ Writer) {
   111  			mapper(item)
   112  		},
   113  		source:    source,
   114  		panicChan: panicChan,
   115  		collector: collector,
   116  		doneChan:  done,
   117  		workers:   options.workers,
   118  	})
   119  
   120  	for {
   121  		select {
   122  		case v := <-panicChan.channel:
   123  			panic(v)
   124  		case _, ok := <-collector:
   125  			if !ok {
   126  				return
   127  			}
   128  		}
   129  	}
   130  }
   131  
   132  // MapReduce maps all elements generated from given generate func,
   133  // and reduces the output elements with given reducer.
   134  func MapReduce(generate GenerateFunc, mapper MapperFunc, reducer ReducerFunc,
   135  	opts ...Option) (interface{}, error) {
   136  	panicChan := &onceChan{channel: make(chan interface{})}
   137  	source := buildSource(generate, panicChan)
   138  	return mapReduceWithPanicChan(source, panicChan, mapper, reducer, opts...)
   139  }
   140  
   141  // MapReduceChan maps all elements from source, and reduce the output elements with given reducer.
   142  func MapReduceChan(source <-chan interface{}, mapper MapperFunc, reducer ReducerFunc,
   143  	opts ...Option) (interface{}, error) {
   144  	panicChan := &onceChan{channel: make(chan interface{})}
   145  	return mapReduceWithPanicChan(source, panicChan, mapper, reducer, opts...)
   146  }
   147  
   148  // mapReduceWithPanicChan maps all elements from source, and reduce the output elements with given reducer.
   149  func mapReduceWithPanicChan(source <-chan interface{}, panicChan *onceChan, mapper MapperFunc,
   150  	reducer ReducerFunc, opts ...Option) (interface{}, error) {
   151  	options := buildOptions(opts...)
   152  	// output is used to write the final result
   153  	output := make(chan interface{})
   154  	defer func() {
   155  		// reducer can only write once, if more, panic
   156  		for range output {
   157  			panic("more than one element written in reducer")
   158  		}
   159  	}()
   160  
   161  	// collector is used to collect data from mapper, and consume in reducer
   162  	collector := make(chan interface{}, options.workers)
   163  	// if done is closed, all mappers and reducer should stop processing
   164  	done := make(chan lang.PlaceholderType)
   165  	writer := newGuardedWriter(options.ctx, output, done)
   166  	var closeOnce sync.Once
   167  	// use atomic.Value to avoid data race
   168  	var retErr errorx.AtomicError
   169  	finish := func() {
   170  		closeOnce.Do(func() {
   171  			close(done)
   172  			close(output)
   173  		})
   174  	}
   175  	cancel := once(func(err error) {
   176  		if err != nil {
   177  			retErr.Set(err)
   178  		} else {
   179  			retErr.Set(ErrCancelWithNil)
   180  		}
   181  
   182  		drain(source)
   183  		finish()
   184  	})
   185  
   186  	go func() {
   187  		defer func() {
   188  			drain(collector)
   189  			if r := recover(); r != nil {
   190  				panicChan.write(r)
   191  			}
   192  			finish()
   193  		}()
   194  
   195  		reducer(collector, writer, cancel)
   196  	}()
   197  
   198  	go executeMappers(mapperContext{
   199  		ctx: options.ctx,
   200  		mapper: func(item interface{}, w Writer) {
   201  			mapper(item, w, cancel)
   202  		},
   203  		source:    source,
   204  		panicChan: panicChan,
   205  		collector: collector,
   206  		doneChan:  done,
   207  		workers:   options.workers,
   208  	})
   209  
   210  	select {
   211  	case <-options.ctx.Done():
   212  		cancel(context.DeadlineExceeded)
   213  		return nil, context.DeadlineExceeded
   214  	case v := <-panicChan.channel:
   215  		// drain output here, otherwise for loop panic in defer
   216  		drain(output)
   217  		panic(v)
   218  	case v, ok := <-output:
   219  		if err := retErr.Load(); err != nil {
   220  			return nil, err
   221  		} else if ok {
   222  			return v, nil
   223  		} else {
   224  			return nil, ErrReduceNoOutput
   225  		}
   226  	}
   227  }
   228  
   229  // MapReduceVoid maps all elements generated from given generate,
   230  // and reduce the output elements with given reducer.
   231  func MapReduceVoid(generate GenerateFunc, mapper MapperFunc, reducer VoidReducerFunc, opts ...Option) error {
   232  	_, err := MapReduce(generate, mapper, func(input <-chan interface{}, writer Writer, cancel func(error)) {
   233  		reducer(input, cancel)
   234  	}, opts...)
   235  	if errors.Is(err, ErrReduceNoOutput) {
   236  		return nil
   237  	}
   238  
   239  	return err
   240  }
   241  
   242  // WithContext customizes a mapreduce processing accepts a given ctx.
   243  func WithContext(ctx context.Context) Option {
   244  	return func(opts *mapReduceOptions) {
   245  		opts.ctx = ctx
   246  	}
   247  }
   248  
   249  // WithWorkers customizes a mapreduce processing with given workers.
   250  func WithWorkers(workers int) Option {
   251  	return func(opts *mapReduceOptions) {
   252  		if workers < minWorkers {
   253  			opts.workers = minWorkers
   254  		} else {
   255  			opts.workers = workers
   256  		}
   257  	}
   258  }
   259  
   260  func buildOptions(opts ...Option) *mapReduceOptions {
   261  	options := newOptions()
   262  	for _, opt := range opts {
   263  		opt(options)
   264  	}
   265  
   266  	return options
   267  }
   268  
   269  func buildSource(generate GenerateFunc, panicChan *onceChan) chan interface{} {
   270  	source := make(chan interface{})
   271  	go func() {
   272  		defer func() {
   273  			if r := recover(); r != nil {
   274  				panicChan.write(r)
   275  			}
   276  			close(source)
   277  		}()
   278  
   279  		generate(source)
   280  	}()
   281  
   282  	return source
   283  }
   284  
   285  // drain drains the channel.
   286  func drain(channel <-chan interface{}) {
   287  	// drain the channel
   288  	for range channel {
   289  	}
   290  }
   291  
   292  func executeMappers(mCtx mapperContext) {
   293  	var wg sync.WaitGroup
   294  	defer func() {
   295  		wg.Wait()
   296  		close(mCtx.collector)
   297  		drain(mCtx.source)
   298  	}()
   299  
   300  	var failed int32
   301  	pool := make(chan lang.PlaceholderType, mCtx.workers)
   302  	writer := newGuardedWriter(mCtx.ctx, mCtx.collector, mCtx.doneChan)
   303  	for atomic.LoadInt32(&failed) == 0 {
   304  		select {
   305  		case <-mCtx.ctx.Done():
   306  			return
   307  		case <-mCtx.doneChan:
   308  			return
   309  		case pool <- lang.Placeholder:
   310  			item, ok := <-mCtx.source
   311  			if !ok {
   312  				<-pool
   313  				return
   314  			}
   315  
   316  			wg.Add(1)
   317  			go func() {
   318  				defer func() {
   319  					if r := recover(); r != nil {
   320  						atomic.AddInt32(&failed, 1)
   321  						mCtx.panicChan.write(r)
   322  					}
   323  					wg.Done()
   324  					<-pool
   325  				}()
   326  
   327  				mCtx.mapper(item, writer)
   328  			}()
   329  		}
   330  	}
   331  }
   332  
   333  func newOptions() *mapReduceOptions {
   334  	return &mapReduceOptions{
   335  		ctx:     context.Background(),
   336  		workers: defaultWorkers,
   337  	}
   338  }
   339  
   340  func once(fn func(error)) func(error) {
   341  	once := new(sync.Once)
   342  	return func(err error) {
   343  		once.Do(func() {
   344  			fn(err)
   345  		})
   346  	}
   347  }
   348  
   349  type guardedWriter struct {
   350  	ctx     context.Context
   351  	channel chan<- interface{}
   352  	done    <-chan lang.PlaceholderType
   353  }
   354  
   355  func newGuardedWriter(ctx context.Context, channel chan<- interface{},
   356  	done <-chan lang.PlaceholderType) guardedWriter {
   357  	return guardedWriter{
   358  		ctx:     ctx,
   359  		channel: channel,
   360  		done:    done,
   361  	}
   362  }
   363  
   364  func (gw guardedWriter) Write(v interface{}) {
   365  	select {
   366  	case <-gw.ctx.Done():
   367  		return
   368  	case <-gw.done:
   369  		return
   370  	default:
   371  		gw.channel <- v
   372  	}
   373  }
   374  
   375  type onceChan struct {
   376  	channel chan interface{}
   377  	wrote   int32
   378  }
   379  
   380  func (oc *onceChan) write(val interface{}) {
   381  	if atomic.CompareAndSwapInt32(&oc.wrote, 0, 1) {
   382  		oc.channel <- val
   383  	}
   384  }