github.com/weaviate/weaviate@v1.24.6/modules/qna-transformers/additional/answer/answer_result.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package answer
    13  
    14  import (
    15  	"context"
    16  	"errors"
    17  	"sort"
    18  	"strings"
    19  
    20  	"github.com/weaviate/weaviate/entities/models"
    21  	"github.com/weaviate/weaviate/entities/search"
    22  	qnamodels "github.com/weaviate/weaviate/modules/qna-transformers/additional/models"
    23  	"github.com/weaviate/weaviate/modules/qna-transformers/ent"
    24  )
    25  
    26  func (p *AnswerProvider) findAnswer(ctx context.Context,
    27  	in []search.Result, params *Params, limit *int,
    28  	argumentModuleParams map[string]interface{},
    29  ) ([]search.Result, error) {
    30  	if len(in) > 0 {
    31  		question := p.paramsHelper.GetQuestion(argumentModuleParams["ask"])
    32  		if question == "" {
    33  			return in, errors.New("empty question")
    34  		}
    35  		properties := p.paramsHelper.GetProperties(argumentModuleParams["ask"])
    36  
    37  		for i := range in {
    38  			textProperties := map[string]string{}
    39  			schema := in[i].Object().Properties.(map[string]interface{})
    40  			for property, value := range schema {
    41  				if p.containsProperty(property, properties) {
    42  					if valueString, ok := value.(string); ok && len(valueString) > 0 {
    43  						textProperties[property] = valueString
    44  					}
    45  				}
    46  			}
    47  
    48  			texts := []string{}
    49  			for _, value := range textProperties {
    50  				texts = append(texts, value)
    51  			}
    52  			text := strings.Join(texts, " ")
    53  			if len(text) == 0 {
    54  				return in, errors.New("empty content")
    55  			}
    56  
    57  			answer, err := p.qna.Answer(ctx, text, question)
    58  			if err != nil {
    59  				return in, err
    60  			}
    61  
    62  			ap := in[i].AdditionalProperties
    63  			if ap == nil {
    64  				ap = models.AdditionalProperties{}
    65  			}
    66  
    67  			if answerMeetsSimilarityThreshold(argumentModuleParams["ask"], p.paramsHelper, answer) {
    68  				propertyName, startPos, endPos := p.findProperty(answer.Answer, textProperties)
    69  				ap["answer"] = &qnamodels.Answer{
    70  					Result:        answer.Answer,
    71  					Property:      propertyName,
    72  					StartPosition: startPos,
    73  					EndPosition:   endPos,
    74  					Certainty:     answer.Certainty,
    75  					Distance:      answer.Distance,
    76  					HasAnswer:     answer.Answer != nil,
    77  				}
    78  			} else {
    79  				ap["answer"] = &qnamodels.Answer{
    80  					HasAnswer: false,
    81  				}
    82  			}
    83  
    84  			in[i].AdditionalProperties = ap
    85  		}
    86  	}
    87  
    88  	rerank := p.paramsHelper.GetRerank(argumentModuleParams["ask"])
    89  	if rerank {
    90  		return p.rerank(in), nil
    91  	}
    92  	return in, nil
    93  }
    94  
    95  func answerMeetsSimilarityThreshold(params interface{}, helper paramsHelper, ans *ent.AnswerResult) bool {
    96  	certainty := helper.GetCertainty(params)
    97  	if certainty > 0 && ans.Certainty != nil && *ans.Certainty < certainty {
    98  		return false
    99  	}
   100  
   101  	distance := helper.GetDistance(params)
   102  	if distance > 0 && ans.Distance != nil && *ans.Distance > distance {
   103  		return false
   104  	}
   105  
   106  	return true
   107  }
   108  
   109  func (p *AnswerProvider) rerank(in []search.Result) []search.Result {
   110  	if len(in) > 0 {
   111  		sort.SliceStable(in, func(i, j int) bool {
   112  			return p.getAnswerCertainty(in[i]) > p.getAnswerCertainty(in[j])
   113  		})
   114  	}
   115  	return in
   116  }
   117  
   118  func (p *AnswerProvider) getAnswerCertainty(result search.Result) float64 {
   119  	answerObj, ok := result.AdditionalProperties["answer"]
   120  	if ok {
   121  		answer, ok := answerObj.(*qnamodels.Answer)
   122  		if ok {
   123  			if answer.HasAnswer {
   124  				return *answer.Certainty
   125  			}
   126  		}
   127  	}
   128  	return 0
   129  }
   130  
   131  func (p *AnswerProvider) containsProperty(property string, properties []string) bool {
   132  	if len(properties) == 0 {
   133  		return true
   134  	}
   135  	for i := range properties {
   136  		if properties[i] == property {
   137  			return true
   138  		}
   139  	}
   140  	return false
   141  }
   142  
   143  func (p *AnswerProvider) findProperty(answer *string, textProperties map[string]string) (*string, int, int) {
   144  	if answer == nil {
   145  		return nil, 0, 0
   146  	}
   147  	lowercaseAnswer := strings.ToLower(*answer)
   148  	if len(lowercaseAnswer) > 0 {
   149  		for property, value := range textProperties {
   150  			lowercaseValue := strings.ToLower(strings.ReplaceAll(value, "\n", " "))
   151  			if strings.Contains(lowercaseValue, lowercaseAnswer) {
   152  				startIndex := strings.Index(lowercaseValue, lowercaseAnswer)
   153  				return &property, startIndex, startIndex + len(lowercaseAnswer)
   154  			}
   155  		}
   156  	}
   157  	propertyNotFound := ""
   158  	return &propertyNotFound, 0, 0
   159  }