github.com/wbrown/gpt_bpe@v0.0.0-20250709161131-1571a6e8ad2d/utils.go (about)

     1  package gpt_bpe
     2  
     3  import (
     4  	"strings"
     5  )
     6  
     7  type TrimDirection uint
     8  
     9  const (
    10  	TrimTop    TrimDirection = iota
    11  	TrimBottom TrimDirection = iota
    12  	TrimNone   TrimDirection = iota
    13  )
    14  
    15  func (encoder *GPTEncoder) TrimNewlines(
    16  	tokens *Tokens,
    17  	direction TrimDirection,
    18  	limit uint,
    19  ) (*Tokens, error) {
    20  	var err error
    21  	trimmed := make(Tokens, 0)
    22  	if uint(len(*tokens)) <= limit {
    23  		return tokens, err
    24  	} else if direction == TrimNone {
    25  		return &trimmed, err
    26  	}
    27  	lines := strings.Split(encoder.Decode(tokens), "\n")
    28  	var start, end, step, idx int
    29  	switch direction {
    30  	case TrimTop:
    31  		start = len(lines) - 1
    32  		end = -1
    33  		step = -1
    34  	case TrimBottom:
    35  		start = 0
    36  		end = len(lines)
    37  		step = 1
    38  	}
    39  	accTokens := make(Tokens, 0)
    40  	for idx = start; idx != end; idx += step {
    41  		line := lines[idx]
    42  		switch direction {
    43  		case TrimTop:
    44  			line = "\n" + line
    45  		case TrimBottom:
    46  			line = line + "\n"
    47  		}
    48  		newTokens := encoder.Encode(&line)
    49  		if len(*newTokens)+len(accTokens) > int(limit) {
    50  			return &accTokens, err
    51  		} else {
    52  			switch direction {
    53  			case TrimTop:
    54  				accTokens = append(*newTokens, accTokens...)
    55  			case TrimBottom:
    56  				accTokens = append(accTokens, *newTokens...)
    57  			}
    58  		}
    59  	}
    60  	return &accTokens, err
    61  }
    62  
// AlignAndSizeTokens returns a prefix of `tokens` that is at most
// `desiredLength` tokens long and does not end in the middle of a
// multi-token sequence, along with `endAt`, the index into the ORIGINAL
// token slice at which the returned chunk ends (i.e. how many input tokens
// were consumed). Callers can resume slicing from `endAt`.
func (encoder *GPTEncoder) AlignAndSizeTokens(
	tokens *Tokens,
	desiredLength int,
) (
	alignedTokens Tokens,
	endAt int,
) {
	// Fast path: already exactly the desired size.
	if len(*tokens) == desiredLength {
		return *tokens, desiredLength
	}
	var chunk Tokens
	if len(*tokens) < desiredLength {
		// Shorter than requested: take everything and shrink the target.
		chunk = *tokens
		desiredLength = len(*tokens)
	} else {
		chunk = (*tokens)[0:desiredLength]
	}

	// We trim to valid tokens, as we don't want partials
	// that are truncated multi-tokens.
	trimmed := encoder.TrimTokens(&chunk)
	trimmedLength := len(*trimmed)
	isTrimmed := len(*trimmed) != len(chunk)
	chunk = *trimmed
	idx := trimmedLength

	// We do a decode and reencode pass, as this can affect
	// the size after a trim.
	if isTrimmed {
		decodedChunk := encoder.Decode(&chunk)
		reencodedChunk := encoder.Encode(&decodedChunk)
		chunk = *reencodedChunk
		// See if there's any change in size that causes it to
		// be smaller than the `desiredLength`.
		// NOTE(review): if the re-encode GREW the chunk instead
		// (roundtripRemainder <= 0) it is returned as-is, possibly longer
		// than `desiredLength` — confirm callers tolerate that.
		roundtripRemainder := desiredLength - len(chunk)
		if roundtripRemainder > 0 {
			// Top up with the next unconsumed input tokens, again trimmed
			// to a valid multi-token boundary.
			addlEnd := idx + roundtripRemainder
			addlTokens := (*tokens)[idx:addlEnd]
			trimmedAddl := encoder.TrimTokens(&addlTokens)
			chunk = append(chunk, *trimmedAddl...)
			idx += len(*trimmedAddl)
			// Another decode/re-encode pass.
			decodedChunk = encoder.Decode(&chunk)
			reencodedChunk = encoder.Encode(&decodedChunk)
			// Loop, dropping tokens one by one until we have
			// valid tokens and we fit within `contextSize`.
			// NOTE(review): no lower bound on len(chunk) here — if
			// TokensReady never reports true this could underflow past an
			// empty chunk; presumably Encode/TokensReady guarantee
			// termination — confirm.
			for {
				chunk = *reencodedChunk
				if len(chunk) <= desiredLength &&
					encoder.TokensReady(&chunk) {
					break
				}
				// Drop the trailing token and re-roundtrip to re-validate.
				chunk = chunk[:len(chunk)-1]
				idx -= 1
				decodedChunk = encoder.Decode(&chunk)
				reencodedChunk = encoder.Encode(&decodedChunk)
			}
		}
	}

	return chunk, idx
}