github.com/Accefy/pop@v0.0.0-20230428174248-e9f677eab5b9/associations/associations_for_struct.go (about)

     1  package associations
     2  
     3  import (
     4  	"fmt"
     5  	"reflect"
     6  	"regexp"
     7  	"strings"
     8  
     9  	"github.com/gobuffalo/pop/v6/columns"
    10  )
    11  
// validAssociationExpRegexp decides whether a requested association
// expression is well formed. An expression that matches the regexp is
// considered a valid field definition. Segments consist of ASCII letters and
// digits only (underscores are not accepted), separated by single dots;
// deeper paths such as "MyField.Nested.Deeper" match as well.
// e.g: "MyField"             => valid.
// e.g: "MyField.NestedField" => valid.
// e.g: "MyField."            => not valid.
// e.g: "MyField.*"           => not valid for now.
var validAssociationExpRegexp = regexp.MustCompile(`^(([a-zA-Z0-9]*)(\.[a-zA-Z0-9]+)?)+$`)
    18  
// associationBuilders is a registry that decouples the association discovery
// process from the concrete association implementations. It is keyed by the
// struct tag name that forStruct looks up on each field. Every association
// MUST register its builder in this map using its init() method.
// See ./has_many_association.go as a guide.
var associationBuilders = map[string]associationBuilder{}
    23  
    24  // ForStruct returns all associations for
    25  // the struct specified. It takes into account tags
    26  // associations like has_many, belongs_to, has_one.
    27  // it throws an error when it finds a field that does
    28  // not exist for a model.
    29  func ForStruct(s interface{}, fields ...string) (Associations, error) {
    30  	return forStruct(s, s, fields)
    31  }
    32  
    33  // forStruct is a recursive helper that passes the root model down for embedded fields
    34  func forStruct(parent, s interface{}, fields []string) (Associations, error) {
    35  	t, v := getModelDefinition(s)
    36  	if t.Kind() != reflect.Struct {
    37  		return nil, fmt.Errorf("could not get struct associations: not a struct but %T", s)
    38  	}
    39  	fields = trimFields(fields)
    40  	associations := Associations{}
    41  	fieldsWithInnerAssociation := map[string]InnerAssociations{}
    42  
    43  	// validate if fields contains a non existing field in struct.
    44  	// and verify is it has inner associations.
    45  	for i := range fields {
    46  		var innerField string
    47  
    48  		if !validAssociationExpRegexp.MatchString(fields[i]) {
    49  			return associations, fmt.Errorf("association '%s' does not match the format %s", fields[i], "'<field>' or '<field>.<nested-field>'")
    50  		}
    51  
    52  		fields[i], innerField = extractFieldAndInnerFields(fields[i])
    53  
    54  		if _, ok := t.FieldByName(fields[i]); !ok {
    55  			return associations, fmt.Errorf("field %s does not exist in model %s", fields[i], t.Name())
    56  		}
    57  
    58  		if innerField != "" {
    59  			var found bool
    60  			innerF, _ := extractFieldAndInnerFields(innerField)
    61  
    62  			for j := range fieldsWithInnerAssociation[fields[i]] {
    63  				f, _ := extractFieldAndInnerFields(fieldsWithInnerAssociation[fields[i]][j].Fields[0])
    64  				if innerF == f {
    65  					fieldsWithInnerAssociation[fields[i]][j].Fields = append(fieldsWithInnerAssociation[fields[i]][j].Fields, innerField)
    66  					found = true
    67  					break
    68  				}
    69  			}
    70  
    71  			if !found {
    72  				fieldsWithInnerAssociation[fields[i]] = append(fieldsWithInnerAssociation[fields[i]], InnerAssociation{fields[i], []string{innerField}})
    73  			}
    74  		}
    75  	}
    76  
    77  	for i := 0; i < t.NumField(); i++ {
    78  		f := t.Field(i)
    79  
    80  		// inline embedded field
    81  		if f.Anonymous {
    82  			field := v.Field(i)
    83  			// we need field to be a pointer, so that we can later set the value
    84  			// if the embedded field is of type struct {...}, we have to take its address
    85  			if field.Kind() != reflect.Ptr {
    86  				field = field.Addr()
    87  			}
    88  			if fieldIsNil(field) {
    89  				// initialize zero value
    90  				field = reflect.New(field.Type().Elem())
    91  				// we can only get in this case if v.Field(i) is a pointer type because it could not be nil otherwise
    92  				//  => it is safe to set it here as is
    93  				v.Field(i).Set(field)
    94  			}
    95  			innerAssociations, err := forStruct(parent, field.Interface(), fields)
    96  			if err != nil {
    97  				return nil, err
    98  			}
    99  			associations = append(associations, innerAssociations...)
   100  			continue
   101  		}
   102  
   103  		// ignores those fields not included in fields list.
   104  		if len(fields) > 0 && fieldIgnoredIn(fields, f.Name) {
   105  			continue
   106  		}
   107  
   108  		tags := columns.TagsFor(f)
   109  
   110  		for name, builder := range associationBuilders {
   111  			tag := tags.Find(name)
   112  			if !tag.Empty() {
   113  				pt, pv := getModelDefinition(parent)
   114  				params := associationParams{
   115  					field:             f,
   116  					model:             parent,
   117  					modelType:         pt,
   118  					modelValue:        pv,
   119  					popTags:           tags,
   120  					innerAssociations: fieldsWithInnerAssociation[f.Name],
   121  				}
   122  
   123  				a, err := builder(params)
   124  				if err != nil {
   125  					return associations, err
   126  				}
   127  
   128  				associations = append(associations, a)
   129  				break
   130  			}
   131  		}
   132  	}
   133  
   134  	return associations, nil
   135  }
   136  
   137  func getModelDefinition(s interface{}) (reflect.Type, reflect.Value) {
   138  	v := reflect.ValueOf(s)
   139  	v = reflect.Indirect(v)
   140  	t := v.Type()
   141  	return t, v
   142  }
   143  
   144  func trimFields(fields []string) []string {
   145  	var trimFields []string
   146  	for _, f := range fields {
   147  		if strings.TrimSpace(f) != "" {
   148  			trimFields = append(trimFields, strings.TrimSpace(f))
   149  		}
   150  	}
   151  	return trimFields
   152  }
   153  
   154  func fieldIgnoredIn(fields []string, field string) bool {
   155  	for _, f := range fields {
   156  		if f == field {
   157  			return false
   158  		}
   159  	}
   160  	return true
   161  }
   162  
   163  func extractFieldAndInnerFields(field string) (string, string) {
   164  	if !strings.Contains(field, ".") {
   165  		return field, ""
   166  	}
   167  
   168  	dotIndex := strings.Index(field, ".")
   169  	return field[:dotIndex], field[dotIndex+1:]
   170  }