github.com/integration-system/go-cmp@v0.0.0-20190131081942-ac5582987a2f/cmp/cmpopts/struct_filter.go (about)

     1  // Copyright 2017, The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE.md file.
     4  
     5  package cmpopts
     6  
     7  import (
     8  	"fmt"
     9  	"reflect"
    10  	"strings"
    11  
    12  	"github.com/integration-system/go-cmp/cmp"
    13  )
    14  
    15  // filterField returns a new Option where opt is only evaluated on paths that
    16  // include a specific exported field on a single struct type.
    17  // The struct type is specified by passing in a value of that type.
    18  //
    19  // The name may be a dot-delimited string (e.g., "Foo.Bar") to select a
    20  // specific sub-field that is embedded or nested within the parent struct.
    21  func filterField(typ interface{}, name string, opt cmp.Option) cmp.Option {
    22  	// TODO: This is currently unexported over concerns of how helper filters
    23  	// can be composed together easily.
    24  	// TODO: Add tests for FilterField.
    25  
    26  	sf := newStructFilter(typ, name)
    27  	return cmp.FilterPath(sf.filter, opt)
    28  }
    29  
    30  type structFilter struct {
    31  	t  reflect.Type // The root struct type to match on
    32  	ft fieldTree    // Tree of fields to match on
    33  }
    34  
    35  func newStructFilter(typ interface{}, names ...string) structFilter {
    36  	// TODO: Perhaps allow * as a special identifier to allow ignoring any
    37  	// number of path steps until the next field match?
    38  	// This could be useful when a concrete struct gets transformed into
    39  	// an anonymous struct where it is not possible to specify that by type,
    40  	// but the transformer happens to provide guarantees about the names of
    41  	// the transformed fields.
    42  
    43  	t := reflect.TypeOf(typ)
    44  	if t == nil || t.Kind() != reflect.Struct {
    45  		panic(fmt.Sprintf("%T must be a struct", typ))
    46  	}
    47  	var ft fieldTree
    48  	for _, name := range names {
    49  		cname, err := canonicalName(t, name)
    50  		if err != nil {
    51  			panic(fmt.Sprintf("%s: %v", strings.Join(cname, "."), err))
    52  		}
    53  		ft.insert(cname)
    54  	}
    55  	return structFilter{t, ft}
    56  }
    57  
    58  func (sf structFilter) filter(p cmp.Path) bool {
    59  	for i, ps := range p {
    60  		if ps.Type().AssignableTo(sf.t) && sf.ft.matchPrefix(p[i+1:]) {
    61  			return true
    62  		}
    63  	}
    64  	return false
    65  }
    66  
    // fieldTree is a trie of dot-separated identifiers. Each edge is one field
// name, and each node records whether the path leading to it was inserted.
//
// For example, inserting the following selectors:
//	Foo
//	Foo.Bar.Baz
//	Foo.Buzz
//	Nuka.Cola.Quantum
//
// Results in a tree of the form:
//	{sub: {
//		"Foo": {ok: true, sub: {
//			"Bar": {sub: {
//				"Baz": {ok: true},
//			}},
//			"Buzz": {ok: true},
//		}},
//		"Nuka": {sub: {
//			"Cola": {sub: {
//				"Quantum": {ok: true},
//			}},
//		}},
//	}}
//
// Note that insert allocates sub on every node it visits, so leaves actually
// carry an empty (non-nil) sub map rather than a nil one.
type fieldTree struct {
	// ok is set when the path from the root to this node was itself inserted.
	ok bool

	// sub holds the children of this node, keyed by field name.
	sub map[string]fieldTree
}

// insert adds the path cname (one field name per element) to the tree,
// creating intermediate nodes as needed and marking the final node ok.
// Inserting an empty path marks the receiver itself.
func (ft *fieldTree) insert(cname []string) {
	if ft.sub == nil {
		ft.sub = map[string]fieldTree{}
	}
	switch {
	case len(cname) == 0:
		ft.ok = true
	default:
		// Map values are not addressable, so the child is copied out,
		// extended recursively, and then stored back under its key.
		head := cname[0]
		child := ft.sub[head]
		child.insert(cname[1:])
		ft.sub[head] = child
	}
}
   107  
   108  // matchPrefix reports whether any selector in the fieldTree matches
   109  // the start of path p.
   110  func (ft fieldTree) matchPrefix(p cmp.Path) bool {
   111  	for _, ps := range p {
   112  		switch ps := ps.(type) {
   113  		case cmp.StructField:
   114  			ft = ft.sub[ps.Name()]
   115  			if ft.ok {
   116  				return true
   117  			}
   118  			if len(ft.sub) == 0 {
   119  				return false
   120  			}
   121  		case cmp.Indirect:
   122  		default:
   123  			return false
   124  		}
   125  	}
   126  	return false
   127  }
   128  
   129  // canonicalName returns a list of identifiers where any struct field access
   130  // through an embedded field is expanded to include the names of the embedded
   131  // types themselves.
   132  //
   133  // For example, suppose field "Foo" is not directly in the parent struct,
   134  // but actually from an embedded struct of type "Bar". Then, the canonical name
   135  // of "Foo" is actually "Bar.Foo".
   136  //
   137  // Suppose field "Foo" is not directly in the parent struct, but actually
   138  // a field in two different embedded structs of types "Bar" and "Baz".
   139  // Then the selector "Foo" causes a panic since it is ambiguous which one it
   140  // refers to. The user must specify either "Bar.Foo" or "Baz.Foo".
   141  func canonicalName(t reflect.Type, sel string) ([]string, error) {
   142  	var name string
   143  	sel = strings.TrimPrefix(sel, ".")
   144  	if sel == "" {
   145  		return nil, fmt.Errorf("name must not be empty")
   146  	}
   147  	if i := strings.IndexByte(sel, '.'); i < 0 {
   148  		name, sel = sel, ""
   149  	} else {
   150  		name, sel = sel[:i], sel[i:]
   151  	}
   152  
   153  	// Type must be a struct or pointer to struct.
   154  	if t.Kind() == reflect.Ptr {
   155  		t = t.Elem()
   156  	}
   157  	if t.Kind() != reflect.Struct {
   158  		return nil, fmt.Errorf("%v must be a struct", t)
   159  	}
   160  
   161  	// Find the canonical name for this current field name.
   162  	// If the field exists in an embedded struct, then it will be expanded.
   163  	if !isExported(name) {
   164  		// Disallow unexported fields:
   165  		//	* To discourage people from actually touching unexported fields
   166  		//	* FieldByName is buggy (https://golang.org/issue/4876)
   167  		return []string{name}, fmt.Errorf("name must be exported")
   168  	}
   169  	sf, ok := t.FieldByName(name)
   170  	if !ok {
   171  		return []string{name}, fmt.Errorf("does not exist")
   172  	}
   173  	var ss []string
   174  	for i := range sf.Index {
   175  		ss = append(ss, t.FieldByIndex(sf.Index[:i+1]).Name)
   176  	}
   177  	if sel == "" {
   178  		return ss, nil
   179  	}
   180  	ssPost, err := canonicalName(sf.Type, sel)
   181  	return append(ss, ssPost...), err
   182  }