github.com/wbrown/gpt_bpe@v0.0.0-20250709161131-1571a6e8ad2d/utils.go (about) 1 package gpt_bpe 2 3 import ( 4 "strings" 5 ) 6 7 type TrimDirection uint 8 9 const ( 10 TrimTop TrimDirection = iota 11 TrimBottom TrimDirection = iota 12 TrimNone TrimDirection = iota 13 ) 14 15 func (encoder *GPTEncoder) TrimNewlines( 16 tokens *Tokens, 17 direction TrimDirection, 18 limit uint, 19 ) (*Tokens, error) { 20 var err error 21 trimmed := make(Tokens, 0) 22 if uint(len(*tokens)) <= limit { 23 return tokens, err 24 } else if direction == TrimNone { 25 return &trimmed, err 26 } 27 lines := strings.Split(encoder.Decode(tokens), "\n") 28 var start, end, step, idx int 29 switch direction { 30 case TrimTop: 31 start = len(lines) - 1 32 end = -1 33 step = -1 34 case TrimBottom: 35 start = 0 36 end = len(lines) 37 step = 1 38 } 39 accTokens := make(Tokens, 0) 40 for idx = start; idx != end; idx += step { 41 line := lines[idx] 42 switch direction { 43 case TrimTop: 44 line = "\n" + line 45 case TrimBottom: 46 line = line + "\n" 47 } 48 newTokens := encoder.Encode(&line) 49 if len(*newTokens)+len(accTokens) > int(limit) { 50 return &accTokens, err 51 } else { 52 switch direction { 53 case TrimTop: 54 accTokens = append(*newTokens, accTokens...) 55 case TrimBottom: 56 accTokens = append(accTokens, *newTokens...) 57 } 58 } 59 } 60 return &accTokens, err 61 } 62 63 func (encoder *GPTEncoder) AlignAndSizeTokens( 64 tokens *Tokens, 65 desiredLength int, 66 ) ( 67 alignedTokens Tokens, 68 endAt int, 69 ) { 70 if len(*tokens) == desiredLength { 71 return *tokens, desiredLength 72 } 73 var chunk Tokens 74 if len(*tokens) < desiredLength { 75 chunk = *tokens 76 desiredLength = len(*tokens) 77 } else { 78 chunk = (*tokens)[0:desiredLength] 79 } 80 81 // We trim to valid tokens, as we don't want partials 82 // that are truncated multi-tokens. 83 trimmed := encoder.TrimTokens(&chunk) 84 trimmedLength := len(*trimmed) 85 isTrimmed := len(*trimmed) != len(chunk) 86 chunk = *trimmed 87 idx := trimmedLength 88 89 // We do a decode and reencode pass, as this can affect 90 // the size after a trim. 91 if isTrimmed { 92 decodedChunk := encoder.Decode(&chunk) 93 reencodedChunk := encoder.Encode(&decodedChunk) 94 chunk = *reencodedChunk 95 // See if there's any change in size that causes it to 96 // be smaller than the `desiredLength`. 97 roundtripRemainder := desiredLength - len(chunk) 98 if roundtripRemainder > 0 { 99 addlEnd := idx + roundtripRemainder 100 addlTokens := (*tokens)[idx:addlEnd] 101 trimmedAddl := encoder.TrimTokens(&addlTokens) 102 chunk = append(chunk, *trimmedAddl...) 103 idx += len(*trimmedAddl) 104 // Another decode/re-encode pass. 105 decodedChunk = encoder.Decode(&chunk) 106 reencodedChunk = encoder.Encode(&decodedChunk) 107 // Loop, dropping tokens one by one until we have 108 // valid tokens and we fit within `contextSize`. 109 for { 110 chunk = *reencodedChunk 111 if len(chunk) <= desiredLength && 112 encoder.TokensReady(&chunk) { 113 break 114 } 115 chunk = chunk[:len(chunk)-1] 116 idx -= 1 117 decodedChunk = encoder.Decode(&chunk) 118 reencodedChunk = encoder.Encode(&decodedChunk) 119 } 120 } 121 } 122 123 return chunk, idx 124 }