github.com/wbrown/gpt_bpe@v0.0.0-20250709161131-1571a6e8ad2d/prose.go (about) 1 //go:build !wasip1 && !js 2 3 package gpt_bpe 4 5 import ( 6 "strings" 7 "unicode" 8 9 "github.com/jdkato/prose/v2" 10 ) 11 12 func (encoder *GPTEncoder) TrimIncompleteSentence(tokens *Tokens) ( 13 *Tokens, 14 error, 15 ) { 16 trimmed := make(Tokens, 0) 17 doc, err := prose.NewDocument( 18 encoder.Decode(tokens), 19 prose.WithTagging(false), 20 prose.WithExtraction(false), 21 prose.WithTokenization(false), 22 ) 23 if err != nil { 24 return &trimmed, err 25 } 26 firstSentences := doc.Sentences() 27 sentences := make([]string, 0) 28 for _, sentence := range firstSentences { 29 newSentences := encoder.puncPat.Split(sentence.Text, -1) 30 sentences = append(sentences, newSentences...) 31 } 32 lastSentence := sentences[len(sentences)-1] 33 var last rune 34 for _, r := range lastSentence { 35 if unicode.IsSpace(r) { 36 continue 37 } 38 last = r 39 } 40 var text = doc.Text 41 if !unicode.IsPunct(last) { 42 trimPos := strings.LastIndex(text, lastSentence) 43 if trimPos >= 1 { 44 text = doc.Text[:trimPos-1] 45 } 46 } 47 text = strings.TrimSpace(text) 48 if float32(len(text)) < float32(len(doc.Text))*0.8 { 49 return tokens, nil 50 } 51 encoded := encoder.Encode(&text) 52 return encoded, nil 53 } 54 55 func (encoder *GPTEncoder) TrimSentences( 56 tokens *Tokens, 57 direction TrimDirection, 58 limit uint, 59 ) (*Tokens, error) { 60 var err error 61 trimmed := make(Tokens, 0) 62 if uint(len(*tokens)) <= limit { 63 return tokens, err 64 } else if direction == TrimNone { 65 return &trimmed, err 66 } 67 doc, err := prose.NewDocument( 68 encoder.Decode(tokens), 69 prose.WithTagging(false), 70 prose.WithExtraction(false), 71 prose.WithTokenization(false), 72 ) 73 if err != nil { 74 return &trimmed, err 75 } 76 sentences := doc.Sentences() 77 var start, end, step, idx int 78 var textBegin, textEnd int 79 var sentenceIdx, lastSentence int 80 switch direction { 81 case TrimTop: 82 start = len(sentences) - 1 83 end = -1 84 step = -1 85 textBegin = 0 86 textEnd = len(doc.Text) 87 case TrimBottom: 88 start = 0 89 end = len(sentences) 90 step = 1 91 textBegin = 0 92 textEnd = len(doc.Text) 93 default: 94 return &trimmed, err 95 } 96 for idx = start; idx != end; idx += step { 97 sentence := sentences[idx].Text 98 switch direction { 99 case TrimTop: 100 sentenceIdx = strings.LastIndex( 101 doc.Text[textBegin:], 102 sentence, 103 ) + textBegin 104 if sentenceIdx > 0 && sentenceIdx < len(doc.Text) && 105 unicode.IsSpace(rune(doc.Text[sentenceIdx])) { 106 sentenceIdx -= 1 107 } 108 toTokenize := doc.Text[sentenceIdx:] 109 tokCt := uint(len(*(encoder.Encode(&toTokenize)))) 110 if tokCt >= limit { 111 toEncode := doc.Text[textEnd:] 112 return encoder.Encode(&toEncode), err 113 } 114 textEnd = sentenceIdx - 1 115 case TrimBottom: 116 sentenceIdx = strings.Index( 117 doc.Text[textBegin:textEnd], 118 sentence, 119 ) + textBegin 120 sentenceEnd := sentenceIdx + len(sentence) 121 if sentenceEnd < textEnd && 122 doc.Text[sentenceEnd:sentenceEnd+1] == "\n" { 123 sentenceEnd += 1 124 } 125 toTokenize := doc.Text[0:sentenceEnd] 126 tokCt := uint(len(*(encoder.Encode(&toTokenize)))) 127 if tokCt >= limit { 128 toEncode := doc.Text[0:lastSentence] 129 return encoder.Encode(&toEncode), err 130 } 131 lastSentence = sentenceEnd 132 textBegin += len(sentence) 133 } 134 } 135 return &trimmed, err 136 }