github.com/weaviate/weaviate@v1.24.6/usecases/schema/read_consensus.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package schema
    13  
    14  import (
    15  	"context"
    16  	"encoding/json"
    17  	"fmt"
    18  	"reflect"
    19  	"sort"
    20  
    21  	"github.com/sirupsen/logrus"
    22  	"github.com/weaviate/weaviate/entities/models"
    23  	"github.com/weaviate/weaviate/usecases/cluster"
    24  	"github.com/weaviate/weaviate/usecases/sharding"
    25  )
    26  
    27  type parserFn func(ctx context.Context, schema *State) error
    28  
    29  func newReadConsensus(parser parserFn,
    30  	logger logrus.FieldLogger,
    31  ) cluster.ConsensusFn {
    32  	return func(ctx context.Context,
    33  		in []*cluster.Transaction,
    34  	) (*cluster.Transaction, error) {
    35  		if len(in) == 0 || in[0].Type != ReadSchema {
    36  			return nil, nil
    37  		}
    38  
    39  		var consensus *cluster.Transaction
    40  		for i, tx := range in {
    41  
    42  			typed, err := UnmarshalTransaction(tx.Type, tx.Payload.(json.RawMessage))
    43  			if err != nil {
    44  				return nil, fmt.Errorf("unmarshal tx: %w", err)
    45  			}
    46  
    47  			err = parser(ctx, typed.(ReadSchemaPayload).Schema)
    48  			if err != nil {
    49  				return nil, fmt.Errorf("parse schema %w", err)
    50  			}
    51  
    52  			if i == 0 {
    53  				consensus = tx
    54  				consensus.Payload = typed
    55  				continue
    56  			}
    57  
    58  			if consensus.ID != tx.ID {
    59  				return nil, fmt.Errorf("comparing txs with different IDs: %s vs %s",
    60  					consensus.ID, tx.ID)
    61  			}
    62  			previous := consensus.Payload.(ReadSchemaPayload).Schema
    63  			current := typed.(ReadSchemaPayload).Schema
    64  			if err := Equal(previous, current); err != nil {
    65  				diff := Diff("previous", previous, "current", current)
    66  				logger.WithFields(logrusStartupSyncFields()).WithFields(logrus.Fields{
    67  					"diff": diff,
    68  				}).Errorf("trying to reach cluster consensus on schema: %v", err)
    69  
    70  				return nil, fmt.Errorf("did not reach consensus on schema in cluster: %w", err)
    71  			}
    72  		}
    73  
    74  		return consensus, nil
    75  	}
    76  }
    77  
    78  // Equal compares two schema states for equality
    79  // First the object classes are sorted, because
    80  // they are unordered. Then we can make the comparison
    81  // using DeepEqual
    82  func Equal(lhs, rhs *State) error {
    83  	if lhs == nil && rhs == nil {
    84  		return nil
    85  	}
    86  	if lhs == nil || rhs == nil {
    87  		return fmt.Errorf("nil state %p, %p", lhs, rhs)
    88  	}
    89  	if err := equalClasses(lhs.ObjectSchema, rhs.ObjectSchema); err != nil {
    90  		return fmt.Errorf("class models mismatch: %w", err)
    91  	}
    92  	if err := equalSharding(lhs.ShardingState, rhs.ShardingState); err != nil {
    93  		return fmt.Errorf("sharding state mismatch: %w", err)
    94  	}
    95  	return nil
    96  }
    97  
    98  func equalClasses(lhs, rhs *models.Schema) error {
    99  	if lhs == nil && rhs == nil {
   100  		return nil
   101  	}
   102  	if lhs == nil || rhs == nil {
   103  		return fmt.Errorf("model mismatch: %p!=%p", lhs, rhs)
   104  	}
   105  	m, n := len(lhs.Classes), len(rhs.Classes)
   106  	if n != m {
   107  		return fmt.Errorf("class count mismatch: %d!=%d", m, n)
   108  	}
   109  	if m == 0 {
   110  		return nil
   111  	}
   112  	// sort classes so we can compare them one by one
   113  	sort.Slice(lhs.Classes, func(i, j int) bool {
   114  		return lhs.Classes[i].Class < lhs.Classes[j].Class
   115  	})
   116  
   117  	sort.Slice(rhs.Classes, func(i, j int) bool {
   118  		return rhs.Classes[i].Class < rhs.Classes[j].Class
   119  	})
   120  
   121  	for i, cls := range lhs.Classes {
   122  		x := rhs.Classes[i]
   123  		if !reflect.DeepEqual(cls, rhs.Classes[i]) {
   124  			n1, n2 := "", ""
   125  			if cls != nil {
   126  				n1 = cls.Class
   127  			}
   128  			if x != nil {
   129  				n2 = cls.Class
   130  			}
   131  			return fmt.Errorf("class mismatch at position %d: %s %s", i, n1, n2)
   132  		}
   133  	}
   134  
   135  	return nil
   136  }
   137  
   138  func equalSharding(l, r map[string]*sharding.State) error {
   139  	m, n := len(l), len(r)
   140  	if m != n {
   141  		return fmt.Errorf("class count mismatch: %d!=%d", m, n)
   142  	}
   143  	if m == 0 {
   144  		return nil
   145  	}
   146  	for cls, u := range l {
   147  		v := r[cls]
   148  		if a, b := u.PartitioningEnabled, v.PartitioningEnabled; a != b {
   149  			return fmt.Errorf("class %s: partitioning %t %t", cls, a, b)
   150  		}
   151  		if u.Config != v.Config {
   152  			return fmt.Errorf("class %s: config mismatch", cls)
   153  		}
   154  
   155  		if nl, nr := len(u.Physical), len(v.Physical); nl != nr {
   156  			return fmt.Errorf("class %s: number of physical shards: local=%d remote=%d", cls, nl, nr)
   157  		}
   158  		for k, lu := range u.Physical {
   159  			if !reflect.DeepEqual(lu, v.Physical[k]) {
   160  				return fmt.Errorf("class %q: physical shard %q", cls, k)
   161  			}
   162  		}
   163  
   164  		if nl, nr := len(u.Virtual), len(v.Virtual); nl != nr {
   165  			return fmt.Errorf("class %s: number of virtual shards: local=%d remote=%d", cls, nl, nr)
   166  		}
   167  
   168  		for i, lu := range u.Virtual {
   169  			if !reflect.DeepEqual(lu, v.Virtual[i]) {
   170  				return fmt.Errorf("class %s: virtual shard at position %d", cls, i)
   171  			}
   172  		}
   173  
   174  	}
   175  	return nil
   176  }