k8s.io/test-infra@v0.0.0-20240520184403-27c6b4c223d8/experiment/ml/analyze/predict.go (about)

     1  /*
     2  Copyright 2022 The Kubernetes Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package main
    18  
    19  import (
    20  	"bufio"
    21  	"context"
    22  	"fmt"
    23  	"log"
    24  	"strings"
    25  	"sync"
    26  
    27  	"github.com/GoogleCloudPlatform/testgrid/util/gcs"
    28  )
    29  
    30  func annotateBuild(ctx context.Context, gcsClient gcs.ConditionalClient, predictor *predictionClient, build gcs.Path) ([]int, string, error) {
    31  
    32  	log.Println("Analyzing:", build)
    33  
    34  	sentences, err := readLines(ctx, gcsClient, build)
    35  	if err != nil {
    36  		return nil, "", fmt.Errorf("read lines: %v", err)
    37  	}
    38  
    39  	lines, err := predictByPage(ctx, predictor, sentences...)
    40  	if err != nil {
    41  		return nil, "", err
    42  	}
    43  	min, max := minMax(lines)
    44  	const window = 5
    45  	min -= window
    46  	max += window
    47  	if min < 0 {
    48  		min = 0
    49  	}
    50  	if max >= len(sentences) {
    51  		max = len(sentences) - 1
    52  	}
    53  
    54  	return lines, strings.Join(sentences[min:max+1], "\n"), nil
    55  }
    56  
    57  func readLines(ctx context.Context, client gcs.ConditionalClient, path gcs.Path) ([]string, error) {
    58  	r, _, err := client.Open(ctx, path)
    59  	if err != nil {
    60  		return nil, fmt.Errorf("open: %w", err)
    61  	}
    62  	defer r.Close()
    63  	scanner := bufio.NewScanner(r)
    64  	var sentences []string
    65  	var lineno int
    66  	for scanner.Scan() {
    67  		lineno++
    68  		txt := scanner.Text()
    69  		if t := truncateLine(txt, *sentenceLen); t != nil {
    70  			txt = *t
    71  		}
    72  		sentences = append(sentences, txt)
    73  	}
    74  
    75  	if err := scanner.Err(); err != nil {
    76  		lineno++
    77  		return sentences, fmt.Errorf("%d: %w", lineno, err)
    78  	}
    79  
    80  	return sentences, nil
    81  
    82  }
    83  
    84  func truncateLine(s string, n int) *string {
    85  	if n <= 0 || len(s) <= n {
    86  		return nil
    87  	}
    88  	half := n / 2
    89  	s = strings.ToValidUTF8(s[:half-2]+"..."+s[len(s)-half+1:], "")
    90  	return &s
    91  }
    92  
    93  var (
    94  	predictLock sync.Mutex
    95  )
    96  
    97  func predictByPage(ctx context.Context, predictor *predictionClient, sentences ...string) ([]int, error) {
    98  	predictLock.Lock() // allocate all quota to a single request at a time
    99  	scores, err := predictSentencesByPage(ctx, predictor, sentences...)
   100  	predictLock.Unlock()
   101  	if err != nil {
   102  		return nil, err
   103  	}
   104  
   105  	var maxScore float32
   106  	var maxIdx int
   107  
   108  	var more int
   109  
   110  	const (
   111  		threshold = 0.5
   112  		window    = 5
   113  	)
   114  	for n, score := range scores {
   115  		if score > maxScore {
   116  			maxIdx = n
   117  			maxScore = score
   118  		}
   119  		var notice string
   120  		if score > threshold {
   121  			notice = "+++"
   122  			if more == 0 && !*additional {
   123  				for i := n - window; i < n; i++ {
   124  					if i < 0 {
   125  						continue
   126  					}
   127  					println(i+1, "---", scores[i], sentences[i])
   128  				}
   129  			}
   130  			more = window
   131  		} else {
   132  			notice = "---"
   133  		}
   134  		if more > 0 || *additional {
   135  			println(n+1, notice, score, sentences[n])
   136  			more--
   137  		}
   138  	}
   139  
   140  	start, end := maxIdx, maxIdx
   141  	for start > 0 && scores[start-1] >= threshold {
   142  		start--
   143  	}
   144  
   145  	for end+1 < len(scores) && scores[end+1] >= threshold {
   146  		end++
   147  	}
   148  
   149  	if !*additional {
   150  		for i := start - window; i <= end+window; i++ {
   151  			if i < 0 {
   152  				continue
   153  			}
   154  			if i >= len(sentences) {
   155  				break
   156  			}
   157  			var notice string
   158  			score := scores[i]
   159  			if score > threshold {
   160  				notice = "+++"
   161  			} else {
   162  				notice = "---"
   163  			}
   164  			println(i+1, notice, score, sentences[i])
   165  		}
   166  	}
   167  
   168  	return []int{start + 1, end + 1}, nil
   169  }
   170  
   171  func predictSentencesByPage(ctx context.Context, predictor *predictionClient, sentences ...string) ([]float32, error) {
   172  	pages := splitPages(sentences, *sentenceLen, *documentLen)
   173  	if len(pages) == 0 {
   174  		return nil, nil
   175  	}
   176  
   177  	log.Printf("Found %d pages in %d lines", len(pages), len(sentences))
   178  
   179  	const (
   180  		maxRequestLen = 128000
   181  		maxPages      = 100
   182  	)
   183  	if bytesPerPage := len(pages) * *documentLen / maxPages; bytesPerPage > maxRequestLen {
   184  		return nil, fmt.Errorf("compressing %d pages to %d pages would make %d byte requests", len(pages), maxPages, bytesPerPage)
   185  	}
   186  
   187  	trunc := truncatePages(pages, maxPages)
   188  	if len(trunc) != len(pages) {
   189  		log.Printf("Truncated %d pages to %d", len(pages), len(trunc))
   190  		pages = trunc
   191  	}
   192  
   193  	scores := make([]float32, len(sentences))
   194  	highlights, err := predictPages(ctx, predictor, pages)
   195  	if err != nil {
   196  		return nil, fmt.Errorf("predict: %w", err)
   197  	}
   198  
   199  	var line int
   200  	for n, score := range highlights {
   201  		for more := len(pages[n]); more > 0; more-- {
   202  			scores[line] = score
   203  			line++
   204  		}
   205  	}
   206  
   207  	return scores, nil
   208  }
   209  
   210  func splitPages(lines []string, lineLen, pageLen int) [][]string {
   211  	var pages [][]string
   212  
   213  	var working int
   214  
   215  	var page []string
   216  	for _, txt := range lines {
   217  		if t := truncateLine(txt, lineLen); t != nil {
   218  			txt = *t
   219  		}
   220  		n := len(txt)
   221  		if n+working > pageLen {
   222  			if len(page) > 0 {
   223  				pages = append(pages, page)
   224  			}
   225  			page = nil
   226  			working = 0
   227  		}
   228  		page = append(page, txt)
   229  		working += n
   230  	}
   231  	if len(page) > 0 {
   232  		pages = append(pages, page)
   233  	}
   234  	return pages
   235  }
   236  
   237  func truncatePages(pages [][]string, maxPages int) [][]string {
   238  	n := len(pages)
   239  	if n <= maxPages {
   240  		return pages
   241  	}
   242  
   243  	join := n / maxPages
   244  
   245  	if n%maxPages != 0 {
   246  		join++
   247  	}
   248  
   249  	out := make([][]string, 0, maxPages)
   250  
   251  	for i := 0; i < n; i += join {
   252  		chapter := pages[i : i+join]
   253  		var total int
   254  		for _, pages := range chapter {
   255  			total += len(pages)
   256  		}
   257  		bigPage := make([]string, 0, total)
   258  		for _, pages := range chapter {
   259  			bigPage = append(bigPage, pages...)
   260  		}
   261  
   262  		out = append(out, bigPage)
   263  	}
   264  
   265  	return out
   266  }
   267  
   268  func predictPages(ctx context.Context, predictor *predictionClient, pages [][]string) ([]float32, error) {
   269  	highlights := make([]float32, len(pages))
   270  
   271  	ch := make(chan int)
   272  	errCh := make(chan error)
   273  
   274  	ctx, cancel := context.WithCancel(ctx)
   275  	defer cancel()
   276  
   277  	const workers = 10
   278  
   279  	for i := 0; i < workers; i++ {
   280  		go func() {
   281  			for n := range ch {
   282  				page := pages[n]
   283  				txt := strings.Join(page, "\n")
   284  				results, err := predictor.predict(ctx, txt)
   285  				if err != nil {
   286  					select {
   287  					case <-ctx.Done():
   288  					case errCh <- fmt.Errorf("%d (%s): %w", n, page, err):
   289  					}
   290  					return
   291  				}
   292  				const goal = "highlight"
   293  				highlights[n] = results[goal]
   294  			}
   295  			select {
   296  			case <-ctx.Done():
   297  			case errCh <- nil:
   298  			}
   299  		}()
   300  	}
   301  
   302  	go func() {
   303  		for n := range pages {
   304  			select {
   305  			case <-ctx.Done():
   306  			case ch <- n:
   307  			}
   308  		}
   309  		close(ch)
   310  	}()
   311  
   312  	for i := workers; i > 0; i-- {
   313  		select {
   314  		case <-ctx.Done():
   315  			return nil, ctx.Err()
   316  		case err := <-errCh:
   317  			if err != nil {
   318  				return nil, err
   319  			}
   320  		}
   321  	}
   322  
   323  	return highlights, nil
   324  }
   325  
   326  func println(stuff ...interface{}) {
   327  	if !*shout {
   328  		return
   329  	}
   330  	fmt.Println(stuff...)
   331  }
   332  
   333  func minMax(lines []int) (int, int) {
   334  	var min, max int
   335  	for i, l := range lines {
   336  		if i == 0 || l < min {
   337  			min = l
   338  		}
   339  		if i == 0 || l > max {
   340  			max = l
   341  		}
   342  	}
   343  	return min, max
   344  }