github.com/shuguocloud/go-zero@v1.3.0/core/mr/mapreduce.go (about)

     1  package mr
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"sync"
     8  
     9  	"github.com/shuguocloud/go-zero/core/errorx"
    10  	"github.com/shuguocloud/go-zero/core/lang"
    11  	"github.com/shuguocloud/go-zero/core/threading"
    12  )
    13  
// Worker-count bounds: defaultWorkers is the count newOptions installs when
// no WithWorkers option overrides it, and minWorkers is the floor that
// WithWorkers raises smaller values to.
const (
	defaultWorkers = 16
	minWorkers     = 1
)
    18  
var (
	// ErrCancelWithNil is the error recorded by MapReduceChan when the cancel
	// func is invoked with a nil error.
	ErrCancelWithNil = errors.New("mapreduce cancelled with nil")
	// ErrReduceNoOutput is returned by MapReduceChan when the output channel is
	// closed without the reducer having written a value and no error was recorded.
	ErrReduceNoOutput = errors.New("reduce not writing value")
)
    25  
type (
	// ForEachFunc processes a single element and produces no output.
	ForEachFunc func(item interface{})
	// GenerateFunc lets callers send elements into source.
	// When run through buildSource (as Map and MapReduce do), source is closed
	// after the function returns, so implementations only need to send.
	GenerateFunc func(source chan<- interface{})
	// MapFunc processes a single element and writes any output to writer.
	MapFunc func(item interface{}, writer Writer)
	// MapperFunc processes a single element and writes any output to writer.
	// Calling cancel aborts the whole run; in MapReduceChan only the first call
	// has an effect, and a nil error is recorded as ErrCancelWithNil.
	MapperFunc func(item interface{}, writer Writer, cancel func(error))
	// ReducerFunc consumes all the mapping output from pipe and writes the
	// result to writer; calling cancel aborts the run.
	// MapReduceChan panics with "more than one element written in reducer" if
	// a further value arrives after the first result has been received.
	ReducerFunc func(pipe <-chan interface{}, writer Writer, cancel func(error))
	// VoidReducerFunc consumes all the mapping output from pipe but produces
	// no output. Calling cancel aborts the run.
	VoidReducerFunc func(pipe <-chan interface{}, cancel func(error))
	// Option customizes a mapreduce run by mutating its options.
	Option func(opts *mapReduceOptions)

	// mapReduceOptions holds the settings of one run: ctx is watched for
	// cancellation and workers bounds the number of concurrently running
	// mappers (it is also used as the collector channel's buffer size).
	mapReduceOptions struct {
		ctx     context.Context
		workers int
	}

	// Writer is the sink that mappers and reducers emit values into.
	Writer interface {
		Write(v interface{})
	}
)
    55  
    56  // Finish runs fns parallelly, cancelled on any error.
    57  func Finish(fns ...func() error) error {
    58  	if len(fns) == 0 {
    59  		return nil
    60  	}
    61  
    62  	return MapReduceVoid(func(source chan<- interface{}) {
    63  		for _, fn := range fns {
    64  			source <- fn
    65  		}
    66  	}, func(item interface{}, writer Writer, cancel func(error)) {
    67  		fn := item.(func() error)
    68  		if err := fn(); err != nil {
    69  			cancel(err)
    70  		}
    71  	}, func(pipe <-chan interface{}, cancel func(error)) {
    72  	}, WithWorkers(len(fns)))
    73  }
    74  
    75  // FinishVoid runs fns parallelly.
    76  func FinishVoid(fns ...func()) {
    77  	if len(fns) == 0 {
    78  		return
    79  	}
    80  
    81  	ForEach(func(source chan<- interface{}) {
    82  		for _, fn := range fns {
    83  			source <- fn
    84  		}
    85  	}, func(item interface{}) {
    86  		fn := item.(func())
    87  		fn()
    88  	}, WithWorkers(len(fns)))
    89  }
    90  
    91  // ForEach maps all elements from given generate but no output.
    92  func ForEach(generate GenerateFunc, mapper ForEachFunc, opts ...Option) {
    93  	drain(Map(generate, func(item interface{}, writer Writer) {
    94  		mapper(item)
    95  	}, opts...))
    96  }
    97  
    98  // Map maps all elements generated from given generate func, and returns an output channel.
    99  func Map(generate GenerateFunc, mapper MapFunc, opts ...Option) chan interface{} {
   100  	options := buildOptions(opts...)
   101  	source := buildSource(generate)
   102  	collector := make(chan interface{}, options.workers)
   103  	done := make(chan lang.PlaceholderType)
   104  
   105  	go executeMappers(options.ctx, mapper, source, collector, done, options.workers)
   106  
   107  	return collector
   108  }
   109  
   110  // MapReduce maps all elements generated from given generate func,
   111  // and reduces the output elements with given reducer.
   112  func MapReduce(generate GenerateFunc, mapper MapperFunc, reducer ReducerFunc,
   113  	opts ...Option) (interface{}, error) {
   114  	source := buildSource(generate)
   115  	return MapReduceChan(source, mapper, reducer, opts...)
   116  }
   117  
   118  // MapReduceChan maps all elements from source, and reduce the output elements with given reducer.
   119  func MapReduceChan(source <-chan interface{}, mapper MapperFunc, reducer ReducerFunc,
   120  	opts ...Option) (interface{}, error) {
   121  	options := buildOptions(opts...)
   122  	output := make(chan interface{})
   123  	defer func() {
   124  		for range output {
   125  			panic("more than one element written in reducer")
   126  		}
   127  	}()
   128  
   129  	collector := make(chan interface{}, options.workers)
   130  	done := make(chan lang.PlaceholderType)
   131  	writer := newGuardedWriter(options.ctx, output, done)
   132  	var closeOnce sync.Once
   133  	var retErr errorx.AtomicError
   134  	finish := func() {
   135  		closeOnce.Do(func() {
   136  			close(done)
   137  			close(output)
   138  		})
   139  	}
   140  	cancel := once(func(err error) {
   141  		if err != nil {
   142  			retErr.Set(err)
   143  		} else {
   144  			retErr.Set(ErrCancelWithNil)
   145  		}
   146  
   147  		drain(source)
   148  		finish()
   149  	})
   150  
   151  	go func() {
   152  		defer func() {
   153  			drain(collector)
   154  
   155  			if r := recover(); r != nil {
   156  				cancel(fmt.Errorf("%v", r))
   157  			} else {
   158  				finish()
   159  			}
   160  		}()
   161  
   162  		reducer(collector, writer, cancel)
   163  	}()
   164  
   165  	go executeMappers(options.ctx, func(item interface{}, w Writer) {
   166  		mapper(item, w, cancel)
   167  	}, source, collector, done, options.workers)
   168  
   169  	select {
   170  	case <-options.ctx.Done():
   171  		cancel(context.DeadlineExceeded)
   172  		return nil, context.DeadlineExceeded
   173  	case value, ok := <-output:
   174  		if err := retErr.Load(); err != nil {
   175  			return nil, err
   176  		} else if ok {
   177  			return value, nil
   178  		} else {
   179  			return nil, ErrReduceNoOutput
   180  		}
   181  	}
   182  }
   183  
   184  // MapReduceVoid maps all elements generated from given generate,
   185  // and reduce the output elements with given reducer.
   186  func MapReduceVoid(generate GenerateFunc, mapper MapperFunc, reducer VoidReducerFunc, opts ...Option) error {
   187  	_, err := MapReduce(generate, mapper, func(input <-chan interface{}, writer Writer, cancel func(error)) {
   188  		reducer(input, cancel)
   189  	}, opts...)
   190  	if errors.Is(err, ErrReduceNoOutput) {
   191  		return nil
   192  	}
   193  
   194  	return err
   195  }
   196  
   197  // WithContext customizes a mapreduce processing accepts a given ctx.
   198  func WithContext(ctx context.Context) Option {
   199  	return func(opts *mapReduceOptions) {
   200  		opts.ctx = ctx
   201  	}
   202  }
   203  
   204  // WithWorkers customizes a mapreduce processing with given workers.
   205  func WithWorkers(workers int) Option {
   206  	return func(opts *mapReduceOptions) {
   207  		if workers < minWorkers {
   208  			opts.workers = minWorkers
   209  		} else {
   210  			opts.workers = workers
   211  		}
   212  	}
   213  }
   214  
   215  func buildOptions(opts ...Option) *mapReduceOptions {
   216  	options := newOptions()
   217  	for _, opt := range opts {
   218  		opt(options)
   219  	}
   220  
   221  	return options
   222  }
   223  
   224  func buildSource(generate GenerateFunc) chan interface{} {
   225  	source := make(chan interface{})
   226  	threading.GoSafe(func() {
   227  		defer close(source)
   228  		generate(source)
   229  	})
   230  
   231  	return source
   232  }
   233  
   234  // drain drains the channel.
   235  func drain(channel <-chan interface{}) {
   236  	// drain the channel
   237  	for range channel {
   238  	}
   239  }
   240  
   241  func executeMappers(ctx context.Context, mapper MapFunc, input <-chan interface{},
   242  	collector chan<- interface{}, done <-chan lang.PlaceholderType, workers int) {
   243  	var wg sync.WaitGroup
   244  	defer func() {
   245  		wg.Wait()
   246  		close(collector)
   247  	}()
   248  
   249  	pool := make(chan lang.PlaceholderType, workers)
   250  	writer := newGuardedWriter(ctx, collector, done)
   251  	for {
   252  		select {
   253  		case <-ctx.Done():
   254  			return
   255  		case <-done:
   256  			return
   257  		case pool <- lang.Placeholder:
   258  			item, ok := <-input
   259  			if !ok {
   260  				<-pool
   261  				return
   262  			}
   263  
   264  			wg.Add(1)
   265  			// better to safely run caller defined method
   266  			threading.GoSafe(func() {
   267  				defer func() {
   268  					wg.Done()
   269  					<-pool
   270  				}()
   271  
   272  				mapper(item, writer)
   273  			})
   274  		}
   275  	}
   276  }
   277  
   278  func newOptions() *mapReduceOptions {
   279  	return &mapReduceOptions{
   280  		ctx:     context.Background(),
   281  		workers: defaultWorkers,
   282  	}
   283  }
   284  
   285  func once(fn func(error)) func(error) {
   286  	once := new(sync.Once)
   287  	return func(err error) {
   288  		once.Do(func() {
   289  			fn(err)
   290  		})
   291  	}
   292  }
   293  
// guardedWriter is a Writer that forwards values to channel, but drops them
// once ctx is done or done has been closed (see Write).
type guardedWriter struct {
	ctx     context.Context
	channel chan<- interface{}
	done    <-chan lang.PlaceholderType
}
   299  
   300  func newGuardedWriter(ctx context.Context, channel chan<- interface{},
   301  	done <-chan lang.PlaceholderType) guardedWriter {
   302  	return guardedWriter{
   303  		ctx:     ctx,
   304  		channel: channel,
   305  		done:    done,
   306  	}
   307  }
   308  
   309  func (gw guardedWriter) Write(v interface{}) {
   310  	select {
   311  	case <-gw.ctx.Done():
   312  		return
   313  	case <-gw.done:
   314  		return
   315  	default:
   316  		gw.channel <- v
   317  	}
   318  }