github.com/google/syzkaller@v0.0.0-20240517125934-c0f1611a36d6/pkg/bisect/minimize/slice.go

// Copyright 2023 syzkaller project authors. All rights reserved.
// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.

package minimize

import (
	"errors"
	"fmt"
	"math"
	"strings"
)

type Config[T any] struct {
	// The original slice is minimized with respect to this predicate.
	// If Pred(X) returns true, X is assumed to contain all elements that must stay.
	Pred func([]T) (bool, error)
	// MaxSteps is a limit on the number of predicate calls during bisection.
	// If it's hit, the bisection continues as if Pred() begins to return false.
	// If it's 0 (the default), no limit is applied.
	MaxSteps int
	// MaxChunks sets a limit on the number of chunks pursued by the bisection algorithm.
	// If we hit the limit, bisection is stopped and Slice() returns ErrTooManyChunks
	// alongside the intermediate bisection result (a valid, but not fully minimized slice).
	MaxChunks int
	// Logf is used for sharing debugging output.
	Logf func(string, ...interface{})
}

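// A hypothetical usage sketch (not part of the original file): a Config[string]
// whose predicate reports whether an assumed helper, crashesWith(), still
// reproduces a crash when only the given subset of programs is executed.
func exampleConfig(crashesWith func([]string) (bool, error)) Config[string] {
	return Config[string]{
		// true => the subset still contains all elements that must stay.
		Pred: crashesWith,
		// Give up refining the result after 20 predicate runs.
		MaxSteps: 20,
		// Stop and return ErrTooManyChunks once more than 10 chunks are tracked.
		MaxChunks: 10,
		Logf: func(msg string, args ...interface{}) {
			fmt.Printf(msg+"\n", args...)
		},
	}
}
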
// Slice() finds a minimal subsequence of slice elements that still gives Pred() == true.
// The algorithm works by sequentially splitting the slice into smaller-size chunks and running
// Pred() without those chunks. Slice() starts by treating the original slice as a single chunk.
// The expected number of Pred() runs is O(|result|*log2(|elements|)).
func Slice[T any](config Config[T], slice []T) ([]T, error) {
	if config.Logf == nil {
		config.Logf = func(string, ...interface{}) {}
	}
	ctx := &sliceCtx[T]{
		Config: config,
		chunks: []*arrayChunk[T]{
			{
				elements: slice,
			},
		},
	}
	return ctx.bisect()
}

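// A hypothetical sketch (not part of the original file): minimize a slice of
// ints with respect to a toy predicate that only needs the elements 3 and 7
// to be present. The expected result is []int{3, 7}, in the original order.
func exampleSlice() ([]int, error) {
	pred := func(sub []int) (bool, error) {
		has3, has7 := false, false
		for _, v := range sub {
			has3 = has3 || v == 3
			has7 = has7 || v == 7
		}
		return has3 && has7, nil
	}
	return Slice(Config[int]{Pred: pred}, []int{1, 2, 3, 4, 5, 6, 7, 8})
}
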
type sliceCtx[T any] struct {
	Config[T]
	chunks   []*arrayChunk[T]
	predRuns int
}

type arrayChunk[T any] struct {
	elements []T
	final    bool // There's no way to further split this chunk.
}

// ErrTooManyChunks is returned if the number of necessary chunks exceeds MaxChunks.
var ErrTooManyChunks = errors.New("the bisection process is following too many necessary chunks")

func (ctx *sliceCtx[T]) bisect() ([]T, error) {
	// At first, we don't know if the original chunks are really necessary.
	err := ctx.splitChunks(false)
	// Then, keep on splitting the chunks layer by layer until we have identified
	// all necessary elements.
	// This way we ensure that we always go from larger to smaller chunks.
	for err == nil && !ctx.done() {
		if ctx.MaxChunks > 0 && len(ctx.chunks) > ctx.MaxChunks {
			err = ErrTooManyChunks
			break
		}
		err = ctx.splitChunks(true)
	}
	if err != nil && err != ErrTooManyChunks {
		return nil, err
	}
	return ctx.elements(), err
}

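// A hypothetical caller sketch (not part of the original file): ErrTooManyChunks
// is returned together with the best result found so far, so a caller that can
// live with a partially minimized slice may still use the returned elements.
func exampleHandleTooManyChunks(config Config[int], slice []int) []int {
	res, err := Slice(config, slice)
	if err != nil && !errors.Is(err, ErrTooManyChunks) {
		// A real predicate error: fall back to the original slice.
		return slice
	}
	// Either fully minimized or the best intermediate result.
	return res
}
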
// splitChunks() splits each non-final chunk into smaller parts (two, except for
// the very first call) and only keeps the necessary sub-parts.
func (ctx *sliceCtx[T]) splitChunks(someNeeded bool) error {
	ctx.Logf("split chunks (needed=%v): %s", someNeeded, ctx.chunkInfo())
	splitInto := 2
	if !someNeeded && len(ctx.chunks) == 1 {
		// It's our first iteration.
		splitInto = ctx.initialSplit(len(ctx.chunks[0].elements))
	}
	var newChunks []*arrayChunk[T]
	for i, chunk := range ctx.chunks {
		if chunk.final {
			newChunks = append(newChunks, chunk)
			continue
		}
		ctx.Logf("split chunk #%d of len %d into %d parts", i, len(chunk.elements), splitInto)
		chunks := splitChunk[T](chunk.elements, splitInto)
		if len(chunks) == 1 && someNeeded {
			ctx.Logf("no way to further split the chunk")
			chunk.final = true
			newChunks = append(newChunks, chunk)
			continue
		}
		foundNeeded := false
		for j := range chunks {
			ctx.Logf("testing without sub-chunk %d/%d", j+1, len(chunks))
			// The last sub-chunk does not need to be tested if all previous sub-chunks
			// of a necessary (someNeeded) chunk could be dropped: in that case the last
			// sub-chunk is the only place where the needed elements can be.
			if j < len(chunks)-1 || foundNeeded || !someNeeded {
				ret, err := ctx.predRun(
					newChunks,
					mergeRawChunks(chunks[j+1:]),
					ctx.chunks[i+1:],
				)
				if err != nil {
					return err
				}
				if ret {
					ctx.Logf("the chunk can be dropped")
					continue
				}
			} else {
				ctx.Logf("no need to test this chunk, it's definitely needed")
			}
			foundNeeded = true
			newChunks = append(newChunks, &arrayChunk[T]{
				elements: chunks[j],
			})
		}
	}
	ctx.chunks = newChunks
	return nil
}

// Since Pred() runs can be costly, the objective is to get the most out of the
// limited number of Pred() calls.
// We try to achieve that by splitting the initial array into more than 2 chunks.
func (ctx *sliceCtx[T]) initialSplit(size int) int {
	// If the number of steps is small and the number of elements is big,
	// let's just split the initial array into MaxSteps chunks.
	// There's no solid reasoning behind the condition below, so feel free to
	// change it if you have better ideas.
	if ctx.MaxSteps > 0 && math.Log2(float64(size)) > float64(ctx.MaxSteps) {
		return ctx.MaxSteps
	}
	// Otherwise, let's split it into 3.
	return 3
}

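// Worked example (illustrative, not from the original file): for a slice of
// 1024 elements and MaxSteps == 5, log2(1024) == 10 > 5, so the first pass
// splits the slice into 5 chunks; with MaxSteps == 0 or a shorter slice,
// the default split into 3 chunks is used.
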
// predRun() determines whether (before + mid + after) covers the necessary elements.
func (ctx *sliceCtx[T]) predRun(before []*arrayChunk[T], mid []T, after []*arrayChunk[T]) (bool, error) {
	if ctx.MaxSteps > 0 && ctx.predRuns >= ctx.MaxSteps {
		ctx.Logf("we have reached the limit on predicate runs (%d); pretend it returns false",
			ctx.MaxSteps)
		return false, nil
	}
	ctx.predRuns++
	return ctx.Pred(mergeChunks(before, mid, after))
}

// The bisection process is done once every chunk is marked as final.
func (ctx *sliceCtx[T]) done() bool {
	if ctx.MaxSteps > 0 && ctx.predRuns >= ctx.MaxSteps {
		// No reason to continue.
		return true
	}
	for _, chunk := range ctx.chunks {
		if !chunk.final {
			return false
		}
	}
	return true
}

func (ctx *sliceCtx[T]) elements() []T {
	return mergeChunks(ctx.chunks, nil, nil)
}

func (ctx *sliceCtx[T]) chunkInfo() string {
	var parts []string
	for _, chunk := range ctx.chunks {
		str := ""
		if chunk.final {
			str = ", final"
		}
		parts = append(parts, fmt.Sprintf("<%d%s>", len(chunk.elements), str))
	}
	return strings.Join(parts, ", ")
}

func mergeChunks[T any](before []*arrayChunk[T], mid []T, after []*arrayChunk[T]) []T {
	var ret []T
	for _, chunk := range before {
		ret = append(ret, chunk.elements...)
	}
	ret = append(ret, mid...)
	for _, chunk := range after {
		ret = append(ret, chunk.elements...)
	}
	return ret
}

func mergeRawChunks[T any](chunks [][]T) []T {
	var ret []T
	for _, chunk := range chunks {
		ret = append(ret, chunk...)
	}
	return ret
}

func splitChunk[T any](chunk []T, parts int) [][]T {
	chunkSize := (len(chunk) + parts - 1) / parts
	if chunkSize == 0 {
		chunkSize = 1
	}
	var ret [][]T
	for i := 0; i < len(chunk); i += chunkSize {
		end := i + chunkSize
		if end > len(chunk) {
			end = len(chunk)
		}
		ret = append(ret, chunk[i:end])
	}
	return ret
}
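// For example (illustrative, not in the original file): splitChunk([]int{1, 2, 3, 4, 5}, 2)
// yields [][]int{{1, 2, 3}, {4, 5}}, since the chunk size is rounded up to (5+2-1)/2 == 3;
// splitting a single-element chunk always yields exactly one part, which is what lets
// splitChunks() mark such a chunk as final.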