go.chromium.org/luci@v0.0.0-20240309015107-7cdc2e660f33/common/proto/protowalk/field_processor_cache.go (about)

     1  // Copyright 2022 The LUCI Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package protowalk
    16  
    17  import (
    18  	"fmt"
    19  	"reflect"
    20  	"sync"
    21  
    22  	"google.golang.org/protobuf/reflect/protoreflect"
    23  
    24  	"go.chromium.org/luci/common/data/stringset"
    25  )
    26  
    27  // fieldProcessorSelectors maps from all the registered processors to an
    28  // id, which can be used with registeredFieldProcessorsID to find the callback
    29  // function and data type, and will be used internally in caches.
    30  var (
    31  	fieldProcessorSelectors   map[reflect.Type]FieldSelector
    32  	fieldProcessorSelectorsMu sync.RWMutex
    33  )
    34  
// fieldProcessorCacheKey is the key for globalFieldProcessorCache. There is
// one key per (message type, processor type) combination:
//   - message is the full name of the proto message type the entry describes;
//   - processorT is the processor's reflect.Type (procBundle.proc) the entry
//     was computed for.
type fieldProcessorCacheKey struct {
	message    protoreflect.FullName
	processorT reflect.Type
}
    40  
// fieldProcessorCacheValue describes how a single processor handles one field
// of a message. A fieldProcessorCacheEntry is a list of these.
type fieldProcessorCacheValue struct {
	// fieldNum is the field 'tag number' in the proto message for the
	// corresponding field. We store this instead of the FieldDescriptor because
	// it's only 4 bytes, rather than a 16+byte interface.
	fieldNum protoreflect.FieldNumber

	// All cacheable attributes:
	//   * ProcessAttr is the value returned by the processor's selector
	//     (procBundle.sel) for this field.
	//   * recurseAttr is only set (to something other than recurseNone) for
	//     fields holding message values (singular, repeated, or map values);
	//     it records the container shape and, for maps, the key kind.
	ProcessAttr
	recurseAttr
}
    53  
// fieldProcessorCacheEntry corresponds to a single Message+FieldProcessor
// combination.
//
// It lists only the fields of the message the processor cares about:
//   - fields for which the processor's selector (procBundle.sel) returned
//     something other than ProcessNever; and/or
//   - message-typed fields (single, repeated or map values) carrying a
//     recursion attribute, typically because the message they hold itself
//     contains fields to be processed by this processor.
//
// Values are appended while iterating the message descriptor's Fields(), i.e.
// in declaration order. NOTE(review): this comment used to say "ordered by
// field number", which only holds when fields are declared in numeric order.
//
// Entries are treated as immutable once computed.
type fieldProcessorCacheEntry []fieldProcessorCacheValue
    64  
    65  // globalFieldProcessorCache maps (Message+FieldProcessor) combinations to
    66  // a (possibly empty) slice which indicates which fields need to be processed
    67  // and/or recursed by the FieldProcessor.
    68  var globalFieldProcessorCache = map[fieldProcessorCacheKey]fieldProcessorCacheEntry{}
    69  var globalFieldProcessorCacheMu sync.RWMutex
    70  
    71  // resetGlobalFieldProcessorCache is only used in tests.
    72  func resetGlobalFieldProcessorCache() {
    73  	globalFieldProcessorCacheMu.Lock()
    74  	defer globalFieldProcessorCacheMu.Unlock()
    75  	for k := range globalFieldProcessorCache {
    76  		delete(globalFieldProcessorCache, k)
    77  	}
    78  }
    79  
// cacheEntryBuilder accumulates the cache entry of a single message type while
// it is being computed.
//
// ret holds values that are final, in the order they were appended. tmp holds
// values (keyed by field number) whose final form is not known yet. A builder
// whose tmp is non-empty is not a complete entry and must not be cached.
type cacheEntryBuilder struct {
	ret fieldProcessorCacheEntry
	tmp map[protoreflect.FieldNumber]fieldProcessorCacheValue
}
    84  
    85  func newCacheEntryBuilder() *cacheEntryBuilder {
    86  	return &cacheEntryBuilder{
    87  		ret: make(fieldProcessorCacheEntry, 0),
    88  		tmp: map[protoreflect.FieldNumber]fieldProcessorCacheValue{},
    89  	}
    90  }
    91  
    92  // generateCacheEntry returns the fieldProcessorCacheEntry for this
    93  // message/processor combination.
    94  //
    95  // This will calculate and return the cache entry.
    96  //
    97  // If msg has recursive fields, we may not be able to get the final cache values
    98  // for those fields until we get up one level, so use tmpRet to keep the temporary
    99  // cache values for them.
   100  func generateCacheEntry(msg protoreflect.MessageDescriptor, processor *procBundle, visitedSubMsgs stringset.Set, tmp map[string]*cacheEntryBuilder) *cacheEntryBuilder {
   101  	fields := msg.Fields()
   102  	msgName := string(msg.FullName())
   103  	for f := 0; f < fields.Len(); f++ {
   104  		finalRet := true
   105  		field := fields.Get(f)
   106  		value := fieldProcessorCacheValue{
   107  			fieldNum:    field.Number(),
   108  			ProcessAttr: processor.sel(field),
   109  		}
   110  
   111  		if !value.ProcessAttr.Valid() {
   112  			panic(fmt.Errorf("(%T).ShouldProcess returned invalid ProcessAttr value: %d",
   113  				processor, value.ProcessAttr))
   114  		}
   115  
   116  		if field.IsMap() {
   117  			if mapVal := field.MapValue(); mapVal.Kind() == protoreflect.MessageKind {
   118  				if len(setCacheEntry(mapVal.Message(), processor, visitedSubMsgs, tmp)) > 0 {
   119  					switch field.MapKey().Kind() {
   120  					case protoreflect.BoolKind:
   121  						value.recurseAttr = recurseMapBool
   122  					case protoreflect.Int32Kind, protoreflect.Int64Kind:
   123  						value.recurseAttr = recurseMapInt
   124  					case protoreflect.Uint32Kind, protoreflect.Uint64Kind:
   125  						value.recurseAttr = recurseMapUint
   126  					case protoreflect.StringKind:
   127  						value.recurseAttr = recurseMapString
   128  					}
   129  				}
   130  			}
   131  		} else if field.Kind() == protoreflect.MessageKind {
   132  			fldName := string(field.FullName())
   133  			subMsgName := string(field.Message().FullName())
   134  			if visitedSubMsgs.Add(fldName) {
   135  				if len(setCacheEntry(field.Message(), processor, visitedSubMsgs, tmp)) > 0 {
   136  					if field.IsList() {
   137  						value.recurseAttr = recurseRepeated
   138  					} else {
   139  						value.recurseAttr = recurseOne
   140  					}
   141  				}
   142  			} else {
   143  				// Found a recursive message, for example
   144  				// message Outer {
   145  				//	 message Inner {
   146  				//	 	 string value = 1;
   147  				//	   Inner next = 2;
   148  				//	 }
   149  				//   Inner inner = 1;
   150  				// }
   151  				// And we're processing .inner.next.
   152  				// We should not call setCacheEntry for it again because it will
   153  				// get us to an infinite loop.
   154  				// And it will reuse the cache entry for .inner, so we really don't
   155  				// need to call setCacheEntry.
   156  				finalRet = false
   157  				if field.IsList() {
   158  					value.recurseAttr = recurseRepeated
   159  				} else {
   160  					value.recurseAttr = recurseOne
   161  				}
   162  
   163  				for _, v := range tmp[subMsgName].ret {
   164  					if v.ProcessAttr != ProcessNever {
   165  						finalRet = true
   166  						break
   167  					}
   168  				}
   169  			}
   170  		}
   171  
   172  		// We want an entry in the cache if we have to process the field:
   173  		//   * directly (i.e. processor applies directly to field), OR
   174  		//   * recursively (i.e. the field is a message kind, and that
   175  		//     message contains a field (or another recursion) that we
   176  		//     must follow).
   177  		if value.ProcessAttr != ProcessNever || value.recurseAttr != recurseNone {
   178  			if bdr, ok := tmp[msgName]; ok {
   179  				bdr.tmp[value.fieldNum] = value
   180  				if finalRet {
   181  					bdr.ret = append(bdr.ret, value)
   182  					delete(bdr.tmp, value.fieldNum)
   183  				}
   184  			}
   185  		}
   186  	}
   187  	return tmp[msgName]
   188  }
   189  
   190  // setCacheEntry will ensure that globalFieldProcessorCache is populated for
   191  // `msg` for the given `processor`.
   192  //
   193  // Returns the entry for this message/processor combination.
   194  func setCacheEntry(msg protoreflect.MessageDescriptor, processor *procBundle, visitedSubMsgs stringset.Set, tmp map[string]*cacheEntryBuilder) (ret fieldProcessorCacheEntry) {
   195  	key := fieldProcessorCacheKey{
   196  		message:    msg.FullName(),
   197  		processorT: processor.proc,
   198  	}
   199  
   200  	globalFieldProcessorCacheMu.RLock()
   201  	ret, ok := globalFieldProcessorCache[key]
   202  	globalFieldProcessorCacheMu.RUnlock()
   203  	if ok {
   204  		return
   205  	}
   206  
   207  	if _, ok := tmp[string(msg.FullName())]; !ok {
   208  		tmp[string(msg.FullName())] = newCacheEntryBuilder()
   209  	}
   210  	ceb := generateCacheEntry(msg, processor, visitedSubMsgs, tmp)
   211  	if len(ceb.tmp) > 0 {
   212  		// We haven't got the final values for some fields.
   213  		// Do not set cache for now.
   214  		return ceb.ret
   215  	}
   216  
   217  	if len(ceb.ret) == 0 {
   218  		// The message doesn't need to be processed by the processor.
   219  		return ceb.ret
   220  	}
   221  
   222  	ret = ceb.ret
   223  	globalFieldProcessorCacheMu.Lock()
   224  	if ce, ok := globalFieldProcessorCache[key]; !ok {
   225  		globalFieldProcessorCache[key] = ret
   226  	} else {
   227  		ret = ce
   228  	}
   229  	globalFieldProcessorCacheMu.Unlock()
   230  
   231  	return
   232  }