github.com/royge/pop@v4.13.1+incompatible/associations/many_to_many_association.go (about)

     1  package associations
     2  
     3  import (
     4  	"fmt"
     5  	"reflect"
     6  	"time"
     7  
     8  	"github.com/gobuffalo/flect"
     9  	"github.com/gobuffalo/pop/internal/defaults"
    10  	"github.com/gofrs/uuid"
    11  )
    12  
    13  type manyToManyAssociation struct {
    14  	fieldType           reflect.Type
    15  	fieldValue          reflect.Value
    16  	model               reflect.Value
    17  	manyToManyTableName string
    18  	owner               interface{}
    19  	fkID                string
    20  	orderBy             string
    21  	primaryID           string
    22  	*associationSkipable
    23  	*associationComposite
    24  }
    25  
    26  func init() {
    27  	associationBuilders["many_to_many"] = func(p associationParams) (Association, error) {
    28  		// Validates if model.ID is nil, this association will be skipped.
    29  		var skipped bool
    30  		model := p.modelValue
    31  		if fieldIsNil(model.FieldByName("ID")) {
    32  			skipped = true
    33  		}
    34  
    35  		return &manyToManyAssociation{
    36  			fieldType:           p.modelValue.FieldByName(p.field.Name).Type(),
    37  			fieldValue:          p.modelValue.FieldByName(p.field.Name),
    38  			owner:               p.model,
    39  			model:               model,
    40  			manyToManyTableName: p.popTags.Find("many_to_many").Value,
    41  			fkID:                p.popTags.Find("fk_id").Value,
    42  			orderBy:             p.popTags.Find("order_by").Value,
    43  			primaryID:           p.popTags.Find("primary_id").Value,
    44  			associationSkipable: &associationSkipable{
    45  				skipped: skipped,
    46  			},
    47  			associationComposite: &associationComposite{innerAssociations: p.innerAssociations},
    48  		}, nil
    49  	}
    50  }
    51  
    52  func (m *manyToManyAssociation) Kind() reflect.Kind {
    53  	return m.fieldType.Kind()
    54  }
    55  
    56  func (m *manyToManyAssociation) Interface() interface{} {
    57  	val := reflect.New(m.fieldType.Elem())
    58  	if m.fieldValue.Kind() == reflect.Ptr {
    59  		m.fieldValue.Set(val)
    60  		return m.fieldValue.Interface()
    61  	}
    62  
    63  	// This piece of code clears a slice in case it is filled with elements.
    64  	if m.fieldValue.Kind() == reflect.Slice || m.fieldValue.Kind() == reflect.Array {
    65  		valPointer := m.fieldValue.Addr()
    66  		valPointer.Elem().Set(reflect.MakeSlice(valPointer.Type().Elem(), 0, valPointer.Elem().Cap()))
    67  		return valPointer.Interface()
    68  	}
    69  
    70  	return m.fieldValue.Addr().Interface()
    71  }
    72  
    73  // Constraint returns the content for a where clause, and the args
    74  // needed to execute it.
    75  func (m *manyToManyAssociation) Constraint() (string, []interface{}) {
    76  	modelColumnID := defaults.String(m.primaryID, fmt.Sprintf("%s%s", flect.Underscore(m.model.Type().Name()), "_id"))
    77  
    78  	var columnFieldID string
    79  	i := reflect.Indirect(m.fieldValue)
    80  	t := i.Type()
    81  	if i.Kind() == reflect.Slice || i.Kind() == reflect.Array {
    82  		t = t.Elem()
    83  	}
    84  	if t.Kind() == reflect.Ptr {
    85  		t = t.Elem()
    86  	}
    87  	columnFieldID = defaults.String(m.fkID, fmt.Sprintf("%s%s", flect.Underscore(t.Name()), "_id"))
    88  
    89  	subQuery := fmt.Sprintf("select %s from %s where %s = ?", columnFieldID, m.manyToManyTableName, modelColumnID)
    90  	modelIDValue := m.model.FieldByName("ID").Interface()
    91  
    92  	return fmt.Sprintf("id in (%s)", subQuery), []interface{}{modelIDValue}
    93  }
    94  
    95  func (m *manyToManyAssociation) OrderBy() string {
    96  	return m.orderBy
    97  }
    98  
    99  func (m *manyToManyAssociation) BeforeInterface() interface{} {
   100  	if m.fieldValue.Kind() == reflect.Ptr {
   101  		return m.fieldValue.Interface()
   102  	}
   103  	return m.fieldValue.Addr().Interface()
   104  }
   105  
   106  func (m *manyToManyAssociation) BeforeSetup() error {
   107  	return nil
   108  }
   109  
   110  func (m *manyToManyAssociation) Statements() []AssociationStatement {
   111  	var statements []AssociationStatement
   112  
   113  	modelColumnID := fmt.Sprintf("%s%s", flect.Underscore(m.model.Type().Name()), "_id")
   114  	var columnFieldID string
   115  	i := reflect.Indirect(m.fieldValue)
   116  	if i.Kind() == reflect.Slice || i.Kind() == reflect.Array {
   117  		t := i.Type().Elem()
   118  		columnFieldID = fmt.Sprintf("%s%s", flect.Underscore(t.Name()), "_id")
   119  	} else {
   120  		columnFieldID = fmt.Sprintf("%s%s", flect.Underscore(i.Type().Name()), "_id")
   121  	}
   122  
   123  	for i := 0; i < m.fieldValue.Len(); i++ {
   124  		v := m.fieldValue.Index(i)
   125  		manyIDValue := v.FieldByName("ID").Interface()
   126  		modelIDValue := m.model.FieldByName("ID").Interface()
   127  		stm := "INSERT INTO %s (%s,%s,%s,%s) SELECT ?,?,?,? WHERE NOT EXISTS (SELECT * FROM %s WHERE %s = ? AND %s = ?)"
   128  
   129  		if IsZeroOfUnderlyingType(manyIDValue) || IsZeroOfUnderlyingType(modelIDValue) {
   130  			continue
   131  		}
   132  
   133  		associationStm := AssociationStatement{
   134  			Statement: fmt.Sprintf(stm, m.manyToManyTableName, modelColumnID, columnFieldID, "created_at", "updated_at", m.manyToManyTableName, modelColumnID, columnFieldID),
   135  			Args:      []interface{}{modelIDValue, manyIDValue, time.Now(), time.Now(), modelIDValue, manyIDValue},
   136  		}
   137  
   138  		if m.model.FieldByName("ID").Type().Name() == "UUID" {
   139  			stm = "INSERT INTO %s (%s,%s,%s,%s,%s) SELECT ?,?,?,?,? WHERE NOT EXISTS (SELECT * FROM %s WHERE %s = ? AND %s = ?)"
   140  			id, _ := uuid.NewV4()
   141  			associationStm = AssociationStatement{
   142  				Statement: fmt.Sprintf(stm, m.manyToManyTableName, "id", modelColumnID, columnFieldID, "created_at", "updated_at", m.manyToManyTableName, modelColumnID, columnFieldID),
   143  				Args:      []interface{}{id, modelIDValue, manyIDValue, time.Now(), time.Now(), modelIDValue, manyIDValue},
   144  			}
   145  		}
   146  
   147  		statements = append(statements, associationStm)
   148  	}
   149  
   150  	return statements
   151  }