k8s.io/test-infra@v0.0.0-20240520184403-27c6b4c223d8/experiment/ml/analyze/predict.go (about) 1 /* 2 Copyright 2022 The Kubernetes Authors. 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 package main 18 19 import ( 20 "bufio" 21 "context" 22 "fmt" 23 "log" 24 "strings" 25 "sync" 26 27 "github.com/GoogleCloudPlatform/testgrid/util/gcs" 28 ) 29 30 func annotateBuild(ctx context.Context, gcsClient gcs.ConditionalClient, predictor *predictionClient, build gcs.Path) ([]int, string, error) { 31 32 log.Println("Analyzing:", build) 33 34 sentences, err := readLines(ctx, gcsClient, build) 35 if err != nil { 36 return nil, "", fmt.Errorf("read lines: %v", err) 37 } 38 39 lines, err := predictByPage(ctx, predictor, sentences...) 40 if err != nil { 41 return nil, "", err 42 } 43 min, max := minMax(lines) 44 const window = 5 45 min -= window 46 max += window 47 if min < 0 { 48 min = 0 49 } 50 if max >= len(sentences) { 51 max = len(sentences) - 1 52 } 53 54 return lines, strings.Join(sentences[min:max+1], "\n"), nil 55 } 56 57 func readLines(ctx context.Context, client gcs.ConditionalClient, path gcs.Path) ([]string, error) { 58 r, _, err := client.Open(ctx, path) 59 if err != nil { 60 return nil, fmt.Errorf("open: %w", err) 61 } 62 defer r.Close() 63 scanner := bufio.NewScanner(r) 64 var sentences []string 65 var lineno int 66 for scanner.Scan() { 67 lineno++ 68 txt := scanner.Text() 69 if t := truncateLine(txt, *sentenceLen); t != nil { 70 txt = *t 71 } 72 sentences = append(sentences, txt) 73 } 74 75 if err := scanner.Err(); err != nil { 76 lineno++ 77 return sentences, fmt.Errorf("%d: %w", lineno, err) 78 } 79 80 return sentences, nil 81 82 } 83 84 func truncateLine(s string, n int) *string { 85 if n <= 0 || len(s) <= n { 86 return nil 87 } 88 half := n / 2 89 s = strings.ToValidUTF8(s[:half-2]+"..."+s[len(s)-half+1:], "") 90 return &s 91 } 92 93 var ( 94 predictLock sync.Mutex 95 ) 96 97 func predictByPage(ctx context.Context, predictor *predictionClient, sentences ...string) ([]int, error) { 98 predictLock.Lock() // allocate all quota to a single request at a time 99 scores, err := predictSentencesByPage(ctx, predictor, sentences...) 100 predictLock.Unlock() 101 if err != nil { 102 return nil, err 103 } 104 105 var maxScore float32 106 var maxIdx int 107 108 var more int 109 110 const ( 111 threshold = 0.5 112 window = 5 113 ) 114 for n, score := range scores { 115 if score > maxScore { 116 maxIdx = n 117 maxScore = score 118 } 119 var notice string 120 if score > threshold { 121 notice = "+++" 122 if more == 0 && !*additional { 123 for i := n - window; i < n; i++ { 124 if i < 0 { 125 continue 126 } 127 println(i+1, "---", scores[i], sentences[i]) 128 } 129 } 130 more = window 131 } else { 132 notice = "---" 133 } 134 if more > 0 || *additional { 135 println(n+1, notice, score, sentences[n]) 136 more-- 137 } 138 } 139 140 start, end := maxIdx, maxIdx 141 for start > 0 && scores[start-1] >= threshold { 142 start-- 143 } 144 145 for end+1 < len(scores) && scores[end+1] >= threshold { 146 end++ 147 } 148 149 if !*additional { 150 for i := start - window; i <= end+window; i++ { 151 if i < 0 { 152 continue 153 } 154 if i >= len(sentences) { 155 break 156 } 157 var notice string 158 score := scores[i] 159 if score > threshold { 160 notice = "+++" 161 } else { 162 notice = "---" 163 } 164 println(i+1, notice, score, sentences[i]) 165 } 166 } 167 168 return []int{start + 1, end + 1}, nil 169 } 170 171 func predictSentencesByPage(ctx context.Context, predictor *predictionClient, sentences ...string) ([]float32, error) { 172 pages := splitPages(sentences, *sentenceLen, *documentLen) 173 if len(pages) == 0 { 174 return nil, nil 175 } 176 177 log.Printf("Found %d pages in %d lines", len(pages), len(sentences)) 178 179 const ( 180 maxRequestLen = 128000 181 maxPages = 100 182 ) 183 if bytesPerPage := len(pages) * *documentLen / maxPages; bytesPerPage > maxRequestLen { 184 return nil, fmt.Errorf("compressing %d pages to %d pages would make %d byte requests", len(pages), maxPages, bytesPerPage) 185 } 186 187 trunc := truncatePages(pages, maxPages) 188 if len(trunc) != len(pages) { 189 log.Printf("Truncated %d pages to %d", len(pages), len(trunc)) 190 pages = trunc 191 } 192 193 scores := make([]float32, len(sentences)) 194 highlights, err := predictPages(ctx, predictor, pages) 195 if err != nil { 196 return nil, fmt.Errorf("predict: %w", err) 197 } 198 199 var line int 200 for n, score := range highlights { 201 for more := len(pages[n]); more > 0; more-- { 202 scores[line] = score 203 line++ 204 } 205 } 206 207 return scores, nil 208 } 209 210 func splitPages(lines []string, lineLen, pageLen int) [][]string { 211 var pages [][]string 212 213 var working int 214 215 var page []string 216 for _, txt := range lines { 217 if t := truncateLine(txt, lineLen); t != nil { 218 txt = *t 219 } 220 n := len(txt) 221 if n+working > pageLen { 222 if len(page) > 0 { 223 pages = append(pages, page) 224 } 225 page = nil 226 working = 0 227 } 228 page = append(page, txt) 229 working += n 230 } 231 if len(page) > 0 { 232 pages = append(pages, page) 233 } 234 return pages 235 } 236 237 func truncatePages(pages [][]string, maxPages int) [][]string { 238 n := len(pages) 239 if n <= maxPages { 240 return pages 241 } 242 243 join := n / maxPages 244 245 if n%maxPages != 0 { 246 join++ 247 } 248 249 out := make([][]string, 0, maxPages) 250 251 for i := 0; i < n; i += join { 252 chapter := pages[i : i+join] 253 var total int 254 for _, pages := range chapter { 255 total += len(pages) 256 } 257 bigPage := make([]string, 0, total) 258 for _, pages := range chapter { 259 bigPage = append(bigPage, pages...) 260 } 261 262 out = append(out, bigPage) 263 } 264 265 return out 266 } 267 268 func predictPages(ctx context.Context, predictor *predictionClient, pages [][]string) ([]float32, error) { 269 highlights := make([]float32, len(pages)) 270 271 ch := make(chan int) 272 errCh := make(chan error) 273 274 ctx, cancel := context.WithCancel(ctx) 275 defer cancel() 276 277 const workers = 10 278 279 for i := 0; i < workers; i++ { 280 go func() { 281 for n := range ch { 282 page := pages[n] 283 txt := strings.Join(page, "\n") 284 results, err := predictor.predict(ctx, txt) 285 if err != nil { 286 select { 287 case <-ctx.Done(): 288 case errCh <- fmt.Errorf("%d (%s): %w", n, page, err): 289 } 290 return 291 } 292 const goal = "highlight" 293 highlights[n] = results[goal] 294 } 295 select { 296 case <-ctx.Done(): 297 case errCh <- nil: 298 } 299 }() 300 } 301 302 go func() { 303 for n := range pages { 304 select { 305 case <-ctx.Done(): 306 case ch <- n: 307 } 308 } 309 close(ch) 310 }() 311 312 for i := workers; i > 0; i-- { 313 select { 314 case <-ctx.Done(): 315 return nil, ctx.Err() 316 case err := <-errCh: 317 if err != nil { 318 return nil, err 319 } 320 } 321 } 322 323 return highlights, nil 324 } 325 326 func println(stuff ...interface{}) { 327 if !*shout { 328 return 329 } 330 fmt.Println(stuff...) 331 } 332 333 func minMax(lines []int) (int, int) { 334 var min, max int 335 for i, l := range lines { 336 if i == 0 || l < min { 337 min = l 338 } 339 if i == 0 || l > max { 340 max = l 341 } 342 } 343 return min, max 344 }