github.com/schollz/clusters@v0.0.0-20221201012527-c6c68863636f/dbscan.go (about)

     1  package clusters
     2  
     3  import (
     4  	"sync"
     5  )
     6  
     7  type dbscanClusterer struct {
     8  	minpts, workers int
     9  	eps             float64
    10  
    11  	distance DistanceFunc
    12  
    13  	// slices holding the cluster mapping and sizes. Access is synchronized to avoid read during computation.
    14  	mu   sync.RWMutex
    15  	a, b []int
    16  
    17  	// variables used for concurrent computation of nearest neighbours
    18  	l, s, o, f int
    19  	j          chan *rangeJob
    20  	m          *sync.Mutex
    21  	w          *sync.WaitGroup
    22  	r          *[]int
    23  	p          []float64
    24  
    25  	// visited points
    26  	v []bool
    27  
    28  	// dataset
    29  	d [][]float64
    30  }
    31  
    32  // Implementation of DBSCAN algorithm with concurrent nearest neighbour computation. The number of goroutines acting concurrently
    33  // is controlled via workers argument. Passing 0 will result in this number being chosen arbitrarily.
    34  func DBSCAN(minpts int, eps float64, workers int, distance DistanceFunc) (HardClusterer, error) {
    35  	if minpts < 1 {
    36  		return nil, errZeroMinpts
    37  	}
    38  
    39  	if workers < 0 {
    40  		return nil, errZeroWorkers
    41  	}
    42  
    43  	if eps <= 0 {
    44  		return nil, errZeroEpsilon
    45  	}
    46  
    47  	var d DistanceFunc
    48  	{
    49  		if distance != nil {
    50  			d = distance
    51  		} else {
    52  			d = EuclideanDistance
    53  		}
    54  	}
    55  
    56  	return &dbscanClusterer{
    57  		minpts:   minpts,
    58  		workers:  workers,
    59  		eps:      eps,
    60  		distance: d,
    61  	}, nil
    62  }
    63  
    64  func (c *dbscanClusterer) IsOnline() bool {
    65  	return false
    66  }
    67  
    68  func (c *dbscanClusterer) WithOnline(o Online) HardClusterer {
    69  	return c
    70  }
    71  
    72  func (c *dbscanClusterer) Learn(data [][]float64) error {
    73  	if len(data) == 0 {
    74  		return errEmptySet
    75  	}
    76  
    77  	c.mu.Lock()
    78  
    79  	c.l = len(data)
    80  	c.s = c.numWorkers()
    81  	c.o = c.s - 1
    82  	c.f = c.l / c.s
    83  
    84  	c.d = data
    85  
    86  	c.v = make([]bool, c.l)
    87  
    88  	c.a = make([]int, c.l)
    89  	c.b = make([]int, 0)
    90  
    91  	c.startNearestWorkers()
    92  
    93  	c.run()
    94  
    95  	c.endNearestWorkers()
    96  
    97  	c.v = nil
    98  	c.p = nil
    99  	c.r = nil
   100  
   101  	c.mu.Unlock()
   102  
   103  	return nil
   104  }
   105  
   106  func (c *dbscanClusterer) Sizes() []int {
   107  	c.mu.RLock()
   108  	defer c.mu.RUnlock()
   109  
   110  	return c.b
   111  }
   112  
   113  func (c *dbscanClusterer) Guesses() []int {
   114  	c.mu.RLock()
   115  	defer c.mu.RUnlock()
   116  
   117  	return c.a
   118  }
   119  
   120  func (c *dbscanClusterer) Predict(p []float64) int {
   121  	var (
   122  		l int
   123  		d float64
   124  		m float64 = c.distance(p, c.d[0])
   125  	)
   126  
   127  	for i := 1; i < len(c.d); i++ {
   128  		if d = c.distance(p, c.d[i]); d < m {
   129  			m = d
   130  			l = i
   131  		}
   132  	}
   133  
   134  	return c.a[l]
   135  }
   136  
   137  func (c *dbscanClusterer) Online(observations chan []float64, done chan struct{}) chan *HCEvent {
   138  	return nil
   139  }
   140  
   141  // private
   142  func (c *dbscanClusterer) run() {
   143  	var (
   144  		n, m, l, k = 1, 0, 0, 0
   145  		ns, nss    = make([]int, 0), make([]int, 0)
   146  	)
   147  
   148  	for i := 0; i < c.l; i++ {
   149  		if c.v[i] {
   150  			continue
   151  		}
   152  
   153  		c.v[i] = true
   154  
   155  		c.nearest(i, &l, &ns)
   156  
   157  		if l < c.minpts {
   158  			c.a[i] = -1
   159  		} else {
   160  			c.a[i] = n
   161  
   162  			c.b = append(c.b, 0)
   163  			c.b[m]++
   164  
   165  			for j := 0; j < l; j++ {
   166  				if !c.v[ns[j]] {
   167  					c.v[ns[j]] = true
   168  
   169  					c.nearest(ns[j], &k, &nss)
   170  
   171  					if k >= c.minpts {
   172  						l += k
   173  						ns = append(ns, nss...)
   174  					}
   175  				}
   176  
   177  				if c.a[ns[j]] == 0 {
   178  					c.a[ns[j]] = n
   179  					c.b[m]++
   180  				}
   181  			}
   182  
   183  			n++
   184  			m++
   185  		}
   186  	}
   187  }
   188  
   189  /* Divide work among c.s workers, where c.s is determined
   190   * by the size of the data. This is based on an assumption that neighbour points of p
   191   * are located in relatively small subsection of the input data, so the dataset can be scanned
   192   * concurrently without blocking a big number of goroutines trying to write to r */
   193  func (c *dbscanClusterer) nearest(p int, l *int, r *[]int) {
   194  	var b int
   195  
   196  	*r = (*r)[:0]
   197  
   198  	c.p = c.d[p]
   199  	c.r = r
   200  
   201  	c.w.Add(c.s)
   202  
   203  	for i := 0; i < c.l; i += c.f {
   204  		if c.l-i <= c.f {
   205  			b = c.l - 1
   206  		} else {
   207  			b = i + c.f
   208  		}
   209  
   210  		c.j <- &rangeJob{
   211  			a: i,
   212  			b: b,
   213  		}
   214  	}
   215  
   216  	c.w.Wait()
   217  
   218  	*l = len(*r)
   219  }
   220  
   221  func (c *dbscanClusterer) startNearestWorkers() {
   222  	c.j = make(chan *rangeJob, c.l)
   223  
   224  	c.m = &sync.Mutex{}
   225  	c.w = &sync.WaitGroup{}
   226  
   227  	for i := 0; i < c.s; i++ {
   228  		go c.nearestWorker()
   229  	}
   230  }
   231  
   232  func (c *dbscanClusterer) endNearestWorkers() {
   233  	close(c.j)
   234  
   235  	c.j = nil
   236  
   237  	c.m = nil
   238  	c.w = nil
   239  }
   240  
   241  func (c *dbscanClusterer) nearestWorker() {
   242  	for j := range c.j {
   243  		for i := j.a; i < j.b; i++ {
   244  			if c.distance(c.p, c.d[i]) < c.eps {
   245  				c.m.Lock()
   246  				*c.r = append(*c.r, i)
   247  				c.m.Unlock()
   248  			}
   249  		}
   250  
   251  		c.w.Done()
   252  	}
   253  }
   254  
   255  func (c *dbscanClusterer) numWorkers() int {
   256  	var b int
   257  
   258  	if c.l < 1000 {
   259  		b = 1
   260  	} else if c.l < 10000 {
   261  		b = 10
   262  	} else if c.l < 100000 {
   263  		b = 100
   264  	} else {
   265  		b = 1000
   266  	}
   267  
   268  	if c.workers == 0 {
   269  		return b
   270  	}
   271  
   272  	if c.workers < b {
   273  		return c.workers
   274  	}
   275  
   276  	return b
   277  
   278  }