github.com/kaydxh/golang@v0.0.131/pkg/gocv/cgo/third_path/opencv4/include/opencv2/flann/lsh_index.h (about) 1 /*********************************************************************** 2 * Software License Agreement (BSD License) 3 * 4 * Copyright 2008-2009 Marius Muja (mariusm@cs.ubc.ca). All rights reserved. 5 * Copyright 2008-2009 David G. Lowe (lowe@cs.ubc.ca). All rights reserved. 6 * 7 * THE BSD LICENSE 8 * 9 * Redistribution and use in source and binary forms, with or without 10 * modification, are permitted provided that the following conditions 11 * are met: 12 * 13 * 1. Redistributions of source code must retain the above copyright 14 * notice, this list of conditions and the following disclaimer. 15 * 2. Redistributions in binary form must reproduce the above copyright 16 * notice, this list of conditions and the following disclaimer in the 17 * documentation and/or other materials provided with the distribution. 18 * 19 * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR 20 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES 21 * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 22 * IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, 23 * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT 24 * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 25 * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 26 * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 27 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF 28 * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 *************************************************************************/ 30 31 /*********************************************************************** 32 * Author: Vincent Rabaud 33 *************************************************************************/ 34 35 #ifndef OPENCV_FLANN_LSH_INDEX_H_ 36 #define OPENCV_FLANN_LSH_INDEX_H_ 37 38 //! @cond IGNORED 39 40 #include <algorithm> 41 #include <cstring> 42 #include <map> 43 #include <vector> 44 45 #include "nn_index.h" 46 #include "matrix.h" 47 #include "result_set.h" 48 #include "heap.h" 49 #include "lsh_table.h" 50 #include "allocator.h" 51 #include "random.h" 52 #include "saving.h" 53 54 #ifdef _MSC_VER 55 #pragma warning(push) 56 #pragma warning(disable: 4702) //disable unreachable code 57 #endif 58 59 namespace cvflann 60 { 61 62 struct LshIndexParams : public IndexParams 63 { 64 LshIndexParams(int table_number = 12, int key_size = 20, int multi_probe_level = 2) 65 { 66 (*this)["algorithm"] = FLANN_INDEX_LSH; 67 // The number of hash tables to use 68 (*this)["table_number"] = table_number; 69 // The length of the key in the hash tables 70 (*this)["key_size"] = key_size; 71 // Number of levels to use in multi-probe (0 for standard LSH) 72 (*this)["multi_probe_level"] = multi_probe_level; 73 } 74 }; 75 76 /** 77 * Locality-sensitive hashing index 78 * 79 * Contains the tables and other information for indexing a set of points 80 * for nearest-neighbor matching. 81 */ 82 template<typename Distance> 83 class LshIndex : public NNIndex<Distance> 84 { 85 public: 86 typedef typename Distance::ElementType ElementType; 87 typedef typename Distance::ResultType DistanceType; 88 89 /** Constructor 90 * @param input_data dataset with the input features 91 * @param params parameters passed to the LSH algorithm 92 * @param d the distance used 93 */ 94 LshIndex(const Matrix<ElementType>& input_data, const IndexParams& params = LshIndexParams(), 95 Distance d = Distance()) : 96 dataset_(input_data), index_params_(params), distance_(d) 97 { 98 // cv::flann::IndexParams sets integer params as 'int', so it is used with get_param 99 // in place of 'unsigned int' 100 table_number_ = get_param(index_params_,"table_number",12); 101 key_size_ = get_param(index_params_,"key_size",20); 102 multi_probe_level_ = get_param(index_params_,"multi_probe_level",2); 103 104 feature_size_ = (unsigned)dataset_.cols; 105 fill_xor_mask(0, key_size_, multi_probe_level_, xor_masks_); 106 } 107 108 109 LshIndex(const LshIndex&); 110 LshIndex& operator=(const LshIndex&); 111 112 /** 113 * Builds the index 114 */ 115 void buildIndex() CV_OVERRIDE 116 { 117 tables_.resize(table_number_); 118 for (int i = 0; i < table_number_; ++i) { 119 lsh::LshTable<ElementType>& table = tables_[i]; 120 table = lsh::LshTable<ElementType>(feature_size_, key_size_); 121 122 // Add the features to the table 123 table.add(dataset_); 124 } 125 } 126 127 flann_algorithm_t getType() const CV_OVERRIDE 128 { 129 return FLANN_INDEX_LSH; 130 } 131 132 133 void saveIndex(FILE* stream) CV_OVERRIDE 134 { 135 save_value(stream,table_number_); 136 save_value(stream,key_size_); 137 save_value(stream,multi_probe_level_); 138 save_value(stream, dataset_); 139 } 140 141 void loadIndex(FILE* stream) CV_OVERRIDE 142 { 143 load_value(stream, table_number_); 144 load_value(stream, key_size_); 145 load_value(stream, multi_probe_level_); 146 load_value(stream, dataset_); 147 // Building the index is so fast we can afford not storing it 148 buildIndex(); 149 150 index_params_["algorithm"] = getType(); 151 index_params_["table_number"] = table_number_; 152 index_params_["key_size"] = key_size_; 153 index_params_["multi_probe_level"] = multi_probe_level_; 154 } 155 156 /** 157 * Returns size of index. 158 */ 159 size_t size() const CV_OVERRIDE 160 { 161 return dataset_.rows; 162 } 163 164 /** 165 * Returns the length of an index feature. 166 */ 167 size_t veclen() const CV_OVERRIDE 168 { 169 return feature_size_; 170 } 171 172 /** 173 * Computes the index memory usage 174 * Returns: memory used by the index 175 */ 176 int usedMemory() const CV_OVERRIDE 177 { 178 return (int)(dataset_.rows * sizeof(int)); 179 } 180 181 182 IndexParams getParameters() const CV_OVERRIDE 183 { 184 return index_params_; 185 } 186 187 /** 188 * \brief Perform k-nearest neighbor search 189 * \param[in] queries The query points for which to find the nearest neighbors 190 * \param[out] indices The indices of the nearest neighbors found 191 * \param[out] dists Distances to the nearest neighbors found 192 * \param[in] knn Number of nearest neighbors to return 193 * \param[in] params Search parameters 194 */ 195 virtual void knnSearch(const Matrix<ElementType>& queries, Matrix<int>& indices, Matrix<DistanceType>& dists, int knn, const SearchParams& params) CV_OVERRIDE 196 { 197 CV_Assert(queries.cols == veclen()); 198 CV_Assert(indices.rows >= queries.rows); 199 CV_Assert(dists.rows >= queries.rows); 200 CV_Assert(int(indices.cols) >= knn); 201 CV_Assert(int(dists.cols) >= knn); 202 203 204 KNNUniqueResultSet<DistanceType> resultSet(knn); 205 for (size_t i = 0; i < queries.rows; i++) { 206 resultSet.clear(); 207 std::fill_n(indices[i], knn, -1); 208 std::fill_n(dists[i], knn, std::numeric_limits<DistanceType>::max()); 209 findNeighbors(resultSet, queries[i], params); 210 if (get_param(params,"sorted",true)) resultSet.sortAndCopy(indices[i], dists[i], knn); 211 else resultSet.copy(indices[i], dists[i], knn); 212 } 213 } 214 215 216 /** 217 * Find set of nearest neighbors to vec. Their indices are stored inside 218 * the result object. 219 * 220 * Params: 221 * result = the result object in which the indices of the nearest-neighbors are stored 222 * vec = the vector for which to search the nearest neighbors 223 * maxCheck = the maximum number of restarts (in a best-bin-first manner) 224 */ 225 void findNeighbors(ResultSet<DistanceType>& result, const ElementType* vec, const SearchParams& /*searchParams*/) CV_OVERRIDE 226 { 227 getNeighbors(vec, result); 228 } 229 230 private: 231 /** Defines the comparator on score and index 232 */ 233 typedef std::pair<float, unsigned int> ScoreIndexPair; 234 struct SortScoreIndexPairOnSecond 235 { 236 bool operator()(const ScoreIndexPair& left, const ScoreIndexPair& right) const 237 { 238 return left.second < right.second; 239 } 240 }; 241 242 /** Fills the different xor masks to use when getting the neighbors in multi-probe LSH 243 * @param key the key we build neighbors from 244 * @param lowest_index the lowest index of the bit set 245 * @param level the multi-probe level we are at 246 * @param xor_masks all the xor mask 247 */ 248 void fill_xor_mask(lsh::BucketKey key, int lowest_index, unsigned int level, 249 std::vector<lsh::BucketKey>& xor_masks) 250 { 251 xor_masks.push_back(key); 252 if (level == 0) return; 253 for (int index = lowest_index - 1; index >= 0; --index) { 254 // Create a new key 255 lsh::BucketKey new_key = key | (1 << index); 256 fill_xor_mask(new_key, index, level - 1, xor_masks); 257 } 258 } 259 260 /** Performs the approximate nearest-neighbor search. 261 * @param vec the feature to analyze 262 * @param do_radius flag indicating if we check the radius too 263 * @param radius the radius if it is a radius search 264 * @param do_k flag indicating if we limit the number of nn 265 * @param k_nn the number of nearest neighbors 266 * @param checked_average used for debugging 267 */ 268 void getNeighbors(const ElementType* vec, bool /*do_radius*/, float radius, bool do_k, unsigned int k_nn, 269 float& /*checked_average*/) 270 { 271 static std::vector<ScoreIndexPair> score_index_heap; 272 273 if (do_k) { 274 unsigned int worst_score = std::numeric_limits<unsigned int>::max(); 275 typename std::vector<lsh::LshTable<ElementType> >::const_iterator table = tables_.begin(); 276 typename std::vector<lsh::LshTable<ElementType> >::const_iterator table_end = tables_.end(); 277 for (; table != table_end; ++table) { 278 size_t key = table->getKey(vec); 279 std::vector<lsh::BucketKey>::const_iterator xor_mask = xor_masks_.begin(); 280 std::vector<lsh::BucketKey>::const_iterator xor_mask_end = xor_masks_.end(); 281 for (; xor_mask != xor_mask_end; ++xor_mask) { 282 size_t sub_key = key ^ (*xor_mask); 283 const lsh::Bucket* bucket = table->getBucketFromKey(sub_key); 284 if (bucket == 0) continue; 285 286 // Go over each descriptor index 287 std::vector<lsh::FeatureIndex>::const_iterator training_index = bucket->begin(); 288 std::vector<lsh::FeatureIndex>::const_iterator last_training_index = bucket->end(); 289 DistanceType hamming_distance; 290 291 // Process the rest of the candidates 292 for (; training_index < last_training_index; ++training_index) { 293 hamming_distance = distance_(vec, dataset_[*training_index], dataset_.cols); 294 295 if (hamming_distance < worst_score) { 296 // Insert the new element 297 score_index_heap.push_back(ScoreIndexPair(hamming_distance, training_index)); 298 std::push_heap(score_index_heap.begin(), score_index_heap.end()); 299 300 if (score_index_heap.size() > (unsigned int)k_nn) { 301 // Remove the highest distance value as we have too many elements 302 std::pop_heap(score_index_heap.begin(), score_index_heap.end()); 303 score_index_heap.pop_back(); 304 // Keep track of the worst score 305 worst_score = score_index_heap.front().first; 306 } 307 } 308 } 309 } 310 } 311 } 312 else { 313 typename std::vector<lsh::LshTable<ElementType> >::const_iterator table = tables_.begin(); 314 typename std::vector<lsh::LshTable<ElementType> >::const_iterator table_end = tables_.end(); 315 for (; table != table_end; ++table) { 316 size_t key = table->getKey(vec); 317 std::vector<lsh::BucketKey>::const_iterator xor_mask = xor_masks_.begin(); 318 std::vector<lsh::BucketKey>::const_iterator xor_mask_end = xor_masks_.end(); 319 for (; xor_mask != xor_mask_end; ++xor_mask) { 320 size_t sub_key = key ^ (*xor_mask); 321 const lsh::Bucket* bucket = table->getBucketFromKey(sub_key); 322 if (bucket == 0) continue; 323 324 // Go over each descriptor index 325 std::vector<lsh::FeatureIndex>::const_iterator training_index = bucket->begin(); 326 std::vector<lsh::FeatureIndex>::const_iterator last_training_index = bucket->end(); 327 DistanceType hamming_distance; 328 329 // Process the rest of the candidates 330 for (; training_index < last_training_index; ++training_index) { 331 // Compute the Hamming distance 332 hamming_distance = distance_(vec, dataset_[*training_index], dataset_.cols); 333 if (hamming_distance < radius) score_index_heap.push_back(ScoreIndexPair(hamming_distance, training_index)); 334 } 335 } 336 } 337 } 338 } 339 340 /** Performs the approximate nearest-neighbor search. 341 * This is a slower version than the above as it uses the ResultSet 342 * @param vec the feature to analyze 343 */ 344 void getNeighbors(const ElementType* vec, ResultSet<DistanceType>& result) 345 { 346 typename std::vector<lsh::LshTable<ElementType> >::const_iterator table = tables_.begin(); 347 typename std::vector<lsh::LshTable<ElementType> >::const_iterator table_end = tables_.end(); 348 for (; table != table_end; ++table) { 349 size_t key = table->getKey(vec); 350 std::vector<lsh::BucketKey>::const_iterator xor_mask = xor_masks_.begin(); 351 std::vector<lsh::BucketKey>::const_iterator xor_mask_end = xor_masks_.end(); 352 for (; xor_mask != xor_mask_end; ++xor_mask) { 353 size_t sub_key = key ^ (*xor_mask); 354 const lsh::Bucket* bucket = table->getBucketFromKey((lsh::BucketKey)sub_key); 355 if (bucket == 0) continue; 356 357 // Go over each descriptor index 358 std::vector<lsh::FeatureIndex>::const_iterator training_index = bucket->begin(); 359 std::vector<lsh::FeatureIndex>::const_iterator last_training_index = bucket->end(); 360 DistanceType hamming_distance; 361 362 // Process the rest of the candidates 363 for (; training_index < last_training_index; ++training_index) { 364 // Compute the Hamming distance 365 hamming_distance = distance_(vec, dataset_[*training_index], (int)dataset_.cols); 366 result.addPoint(hamming_distance, *training_index); 367 } 368 } 369 } 370 } 371 372 /** The different hash tables */ 373 std::vector<lsh::LshTable<ElementType> > tables_; 374 375 /** The data the LSH tables where built from */ 376 Matrix<ElementType> dataset_; 377 378 /** The size of the features (as ElementType[]) */ 379 unsigned int feature_size_; 380 381 IndexParams index_params_; 382 383 /** table number */ 384 int table_number_; 385 /** key size */ 386 int key_size_; 387 /** How far should we look for neighbors in multi-probe LSH */ 388 int multi_probe_level_; 389 390 /** The XOR masks to apply to a key to get the neighboring buckets */ 391 std::vector<lsh::BucketKey> xor_masks_; 392 393 Distance distance_; 394 }; 395 } 396 397 #ifdef _MSC_VER 398 #pragma warning(pop) 399 #endif 400 401 //! @endcond 402 403 #endif //OPENCV_FLANN_LSH_INDEX_H_