github.com/wbrown/gpt_bpe@v0.0.0-20250709161131-1571a6e8ad2d/prose.go

//go:build !wasip1 && !js

package gpt_bpe

import (
	"strings"
	"unicode"

	"github.com/jdkato/prose/v2"
)

// TrimIncompleteSentence decodes the given tokens and, if the text ends in a
// sentence fragment (one whose last non-space rune is not punctuation), cuts
// the text back to the end of the last complete sentence before re-encoding
// it. If the trim would discard more than 20% of the text, the original
// tokens are returned unchanged.
func (encoder *GPTEncoder) TrimIncompleteSentence(tokens *Tokens) (
	*Tokens,
	error,
) {
	trimmed := make(Tokens, 0)
	doc, err := prose.NewDocument(
		encoder.Decode(tokens),
		prose.WithTagging(false),
		prose.WithExtraction(false),
		prose.WithTokenization(false),
	)
	if err != nil {
		return &trimmed, err
	}
	// Split prose's sentences further on the encoder's punctuation pattern,
	// so the trailing fragment we inspect is as small as possible.
	firstSentences := doc.Sentences()
	sentences := make([]string, 0)
	for _, sentence := range firstSentences {
		newSentences := encoder.puncPat.Split(sentence.Text, -1)
		sentences = append(sentences, newSentences...)
	}
	if len(sentences) == 0 {
		// Empty or whitespace-only input; nothing to trim.
		return tokens, nil
	}
	// Find the last non-space rune of the final sentence.
	lastSentence := sentences[len(sentences)-1]
	var last rune
	for _, r := range lastSentence {
		if unicode.IsSpace(r) {
			continue
		}
		last = r
	}
	// If the text does not end in punctuation, cut it back to just before
	// the final fragment.
	var text = doc.Text
	if !unicode.IsPunct(last) {
		trimPos := strings.LastIndex(text, lastSentence)
		if trimPos >= 1 {
			text = doc.Text[:trimPos-1]
		}
	}
	text = strings.TrimSpace(text)
	// Refuse to discard more than 20% of the original text.
	if float32(len(text)) < float32(len(doc.Text))*0.8 {
		return tokens, nil
	}
	encoded := encoder.Encode(&text)
	return encoded, nil
}
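
// Usage sketch (an editorial illustration, not part of the original source).
// It assumes this package's NewGPT2Encoder constructor; the sample text and
// the decoded result are hypothetical and depend on how prose segments the
// input and on the 20% trimming threshold above:
//
//	encoder := NewGPT2Encoder()
//	text := "The quick brown fox jumps over the lazy dog. It then"
//	tokens := encoder.Encode(&text)
//	trimmed, err := encoder.TrimIncompleteSentence(tokens)
//	if err == nil {
//		// Expected to decode to "The quick brown fox jumps over the
//		// lazy dog.", since the trailing fragment lacks punctuation
//		// and is under 20% of the text.
//		fmt.Println(encoder.Decode(trimmed))
//	}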

// TrimSentences decodes the given tokens and drops whole sentences from the
// top (TrimTop) or the bottom (TrimBottom) of the text until the remainder
// encodes to fewer than limit tokens, returning the re-encoded remainder.
// Inputs already within the limit are returned unchanged; TrimNone yields an
// empty token sequence.
func (encoder *GPTEncoder) TrimSentences(
	tokens *Tokens,
	direction TrimDirection,
	limit uint,
) (*Tokens, error) {
	var err error
	trimmed := make(Tokens, 0)
	if uint(len(*tokens)) <= limit {
		return tokens, err
	} else if direction == TrimNone {
		return &trimmed, err
	}
	doc, err := prose.NewDocument(
		encoder.Decode(tokens),
		prose.WithTagging(false),
		prose.WithExtraction(false),
		prose.WithTokenization(false),
	)
	if err != nil {
		return &trimmed, err
	}
	sentences := doc.Sentences()
	var start, end, step, idx int
	var textBegin, textEnd int
	var sentenceIdx, lastSentence int
	// TrimTop walks the sentences last-to-first, growing the kept suffix;
	// TrimBottom walks first-to-last, growing the kept prefix.
	switch direction {
	case TrimTop:
		start = len(sentences) - 1
		end = -1
		step = -1
		textBegin = 0
		textEnd = len(doc.Text)
	case TrimBottom:
		start = 0
		end = len(sentences)
		step = 1
		textBegin = 0
		textEnd = len(doc.Text)
	default:
		return &trimmed, err
	}
	for idx = start; idx != end; idx += step {
		sentence := sentences[idx].Text
		switch direction {
		case TrimTop:
			sentenceIdx = strings.LastIndex(
				doc.Text[textBegin:],
				sentence,
			) + textBegin
			// Back up over a whitespace separator, if present.
			if sentenceIdx > 0 && sentenceIdx < len(doc.Text) &&
				unicode.IsSpace(rune(doc.Text[sentenceIdx])) {
				sentenceIdx -= 1
			}
			toTokenize := doc.Text[sentenceIdx:]
			tokCt := uint(len(*(encoder.Encode(&toTokenize))))
			if tokCt >= limit {
				// Adding this sentence would reach the limit; return the
				// suffix accumulated so far.
				toEncode := doc.Text[textEnd:]
				return encoder.Encode(&toEncode), err
			}
			textEnd = sentenceIdx - 1
		case TrimBottom:
			sentenceIdx = strings.Index(
				doc.Text[textBegin:textEnd],
				sentence,
			) + textBegin
			// Extend past a trailing newline, if present.
			sentenceEnd := sentenceIdx + len(sentence)
			if sentenceEnd < textEnd &&
				doc.Text[sentenceEnd:sentenceEnd+1] == "\n" {
				sentenceEnd += 1
			}
			toTokenize := doc.Text[0:sentenceEnd]
			tokCt := uint(len(*(encoder.Encode(&toTokenize))))
			if tokCt >= limit {
				// Adding this sentence would reach the limit; return the
				// prefix accumulated so far.
				toEncode := doc.Text[0:lastSentence]
				return encoder.Encode(&toEncode), err
			}
			lastSentence = sentenceEnd
			textBegin += len(sentence)
		}
	}
	// Normally unreachable: the final iteration tokenizes the full text,
	// which exceeds the limit checked above.
	return &trimmed, err
}
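
// Usage sketch (an editorial illustration, not part of the original source).
// It assumes this package's NewGPT2Encoder constructor; longText and the
// 2048-token budget are hypothetical:
//
//	encoder := NewGPT2Encoder()
//	tokens := encoder.Encode(&longText)
//	// Keep the tail of the document: drop sentences from the top until
//	// the remainder encodes to fewer than 2048 tokens.
//	kept, err := encoder.TrimSentences(tokens, TrimTop, 2048)
//	if err == nil && kept != nil {
//		fmt.Println(len(*kept))
//	}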