mlpack  2.2.5
ns_model.hpp
Go to the documentation of this file.
1 
15 #ifndef MLPACK_METHODS_NEIGHBOR_SEARCH_NS_MODEL_HPP
16 #define MLPACK_METHODS_NEIGHBOR_SEARCH_NS_MODEL_HPP
17 
23 #include <boost/variant.hpp>
24 #include "neighbor_search.hpp"
25 
26 namespace mlpack {
27 namespace neighbor {
28 
32 template<typename SortPolicy,
33  template<typename TreeMetricType,
34  typename TreeStatType,
35  typename TreeMatType> class TreeType>
36 using NSType = NeighborSearch<SortPolicy,
38  arma::mat,
39  TreeType,
41  NeighborSearchStat<SortPolicy>,
42  arma::mat>::template DualTreeTraverser>;
43 
44 template<typename SortPolicy>
46 {
47  static const std::string Name() { return "neighbor_search_model"; }
48 };
49 
50 template<>
52 {
53  static const std::string Name() { return "nearest_neighbor_search_model"; }
54 };
55 
56 template<>
58 {
59  static const std::string Name() { return "furthest_neighbor_search_model"; }
60 };
61 
66 class MonoSearchVisitor : public boost::static_visitor<void>
67 {
68  private:
70  const size_t k;
72  arma::Mat<size_t>& neighbors;
74  arma::mat& distances;
75 
76  public:
78  template<typename NSType>
79  void operator()(NSType* ns) const;
80 
82  MonoSearchVisitor(const size_t k,
83  arma::Mat<size_t>& neighbors,
84  arma::mat& distances) :
85  k(k),
86  neighbors(neighbors),
87  distances(distances)
88  {};
89 };
90 
97 template<typename SortPolicy>
98 class BiSearchVisitor : public boost::static_visitor<void>
99 {
100  private:
102  const arma::mat& querySet;
104  const size_t k;
106  arma::Mat<size_t>& neighbors;
108  arma::mat& distances;
110  const size_t leafSize;
112  const double tau;
114  const double rho;
115 
117  template<typename NSType>
118  void SearchLeaf(NSType* ns) const;
119 
120  public:
122  template<template<typename TreeMetricType,
123  typename TreeStatType,
124  typename TreeMatType> class TreeType>
126 
128  template<template<typename TreeMetricType,
129  typename TreeStatType,
130  typename TreeMatType> class TreeType>
131  void operator()(NSTypeT<TreeType>* ns) const;
132 
134  void operator()(NSTypeT<tree::KDTree>* ns) const;
135 
137  void operator()(NSTypeT<tree::BallTree>* ns) const;
138 
140  void operator()(SpillKNN* ns) const;
141 
143  void operator()(NSTypeT<tree::Octree>* ns) const;
144 
146  BiSearchVisitor(const arma::mat& querySet,
147  const size_t k,
148  arma::Mat<size_t>& neighbors,
149  arma::mat& distances,
150  const size_t leafSize,
151  const double tau,
152  const double rho);
153 };
154 
161 template<typename SortPolicy>
162 class TrainVisitor : public boost::static_visitor<void>
163 {
164  private:
166  arma::mat&& referenceSet;
168  size_t leafSize;
170  const double tau;
172  const double rho;
173 
175  template<typename NSType>
176  void TrainLeaf(NSType* ns) const;
177 
178  public:
180  template<template<typename TreeMetricType,
181  typename TreeStatType,
182  typename TreeMatType> class TreeType>
184 
186  template<template<typename TreeMetricType,
187  typename TreeStatType,
188  typename TreeMatType> class TreeType>
189  void operator()(NSTypeT<TreeType>* ns) const;
190 
192  void operator()(NSTypeT<tree::KDTree>* ns) const;
193 
195  void operator()(NSTypeT<tree::BallTree>* ns) const;
196 
198  void operator()(SpillKNN* ns) const;
199 
201  void operator()(NSTypeT<tree::Octree>* ns) const;
202 
205  TrainVisitor(arma::mat&& referenceSet,
206  const size_t leafSize,
207  const double tau,
208  const double rho);
209 };
210 
214 class SearchModeVisitor : public boost::static_visitor<NeighborSearchMode>
215 {
216  public:
218  template<typename NSType>
220 };
221 
225 class SetSearchModeVisitor : public boost::static_visitor<void>
226 {
227  NeighborSearchMode searchMode;
228  public:
231  searchMode(searchMode)
232  {};
233 
235  template<typename NSType>
236  void operator()(NSType* ns) const;
237 };
238 
242 class EpsilonVisitor : public boost::static_visitor<double&>
243 {
244  public:
246  template<typename NSType>
247  double& operator()(NSType *ns) const;
248 };
249 
253 class ReferenceSetVisitor : public boost::static_visitor<const arma::mat&>
254 {
255  public:
257  template<typename NSType>
258  const arma::mat& operator()(NSType *ns) const;
259 };
260 
264 class DeleteVisitor : public boost::static_visitor<void>
265 {
266  public:
268  template<typename NSType>
269  void operator()(NSType *ns) const;
270 };
271 
282 template<typename SortPolicy>
283 class NSModel
284 {
285  public:
288  {
304  };
305 
306  private:
308  TreeTypes treeType;
309 
311  size_t leafSize;
312 
314  double tau;
316  double rho;
317 
319  bool randomBasis;
321  arma::mat q;
322 
328  boost::variant<NSType<SortPolicy, tree::KDTree>*,
340  SpillKNN*,
343 
344  public:
349  NSModel(TreeTypes treeType = TreeTypes::KD_TREE, bool randomBasis = false);
350 
352  ~NSModel();
353 
355  template<typename Archive>
356  void Serialize(Archive& ar, const unsigned int /* version */);
357 
359  const arma::mat& Dataset() const;
360 
364  void SetSearchMode(const NeighborSearchMode mode);
365 
367  double Epsilon() const;
368  double& Epsilon();
369 
371  size_t LeafSize() const { return leafSize; }
372  size_t& LeafSize() { return leafSize; }
373 
375  double Tau() const { return tau; }
376  double& Tau() { return tau; }
377 
379  double Rho() const { return rho; }
380  double& Rho() { return rho; }
381 
383  TreeTypes TreeType() const { return treeType; }
384  TreeTypes& TreeType() { return treeType; }
385 
387  bool RandomBasis() const { return randomBasis; }
388  bool& RandomBasis() { return randomBasis; }
389 
391  void BuildModel(arma::mat&& referenceSet,
392  const size_t leafSize,
393  const NeighborSearchMode searchMode,
394  const double epsilon = 0);
395 
397  void Search(arma::mat&& querySet,
398  const size_t k,
399  arma::Mat<size_t>& neighbors,
400  arma::mat& distances);
401 
403  void Search(const size_t k,
404  arma::Mat<size_t>& neighbors,
405  arma::mat& distances);
406 
408  std::string TreeName() const;
409 };
410 
411 } // namespace neighbor
412 } // namespace mlpack
413 
415 BOOST_TEMPLATE_CLASS_VERSION(template<typename SortPolicy>,
417 
418 // Include implementation.
419 #include "ns_model_impl.hpp"
420 
421 #endif
double Epsilon() const
Expose Epsilon.
void operator()(NSTypeT< TreeType > *ns) const
Default Bichromatic neighbor search on the given NSType instance.
MonoSearchVisitor(const size_t k, arma::Mat< size_t > &neighbors, arma::mat &distances)
Construct the MonoSearchVisitor object with the given parameters.
Definition: ns_model.hpp:82
EpsilonVisitor exposes the Epsilon method of the given NSType.
Definition: ns_model.hpp:242
TrainVisitor(arma::mat &&referenceSet, const size_t leafSize, const double tau, const double rho)
Construct the TrainVisitor object with the given reference set, leafSize for BinarySpaceTrees, and tau and rho for spill trees.
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: binarize.hpp:18
bool RandomBasis() const
Expose randomBasis.
Definition: ns_model.hpp:387
TreeTypes
Enum type to identify each accepted tree type.
Definition: ns_model.hpp:287
const arma::mat & operator()(NSType *ns) const
Return the reference set.
SetSearchModeVisitor sets the search mode of the given NSType.
Definition: ns_model.hpp:225
BOOST_TEMPLATE_CLASS_VERSION(template< typename SortPolicy >, mlpack::neighbor::NSModel< SortPolicy >, 1)
Set the serialization version of the NSModel class.
ReferenceSetVisitor exposes the referenceSet of the given NSType.
Definition: ns_model.hpp:253
SearchModeVisitor exposes the SearchMode() method of the given NSType.
Definition: ns_model.hpp:214
void operator()(NSTypeT< TreeType > *ns) const
Default Train on the given NSType instance.
The NeighborSearch class is a template class for performing distance-based neighbor searches.
void operator()(NSType *ns) const
Set the search mode.
TreeTypes TreeType() const
Expose treeType.
Definition: ns_model.hpp:383
NeighborSearchMode operator()(NSType *ns) const
Return the search mode.
const arma::mat & Dataset() const
Expose the dataset.
NeighborSearch< SortPolicy, metric::EuclideanDistance, arma::mat, TreeType, TreeType< metric::EuclideanDistance, NeighborSearchStat< SortPolicy >, arma::mat >::template DualTreeTraverser > NSType
Alias template for euclidean neighbor search.
Definition: ns_model.hpp:42
This class implements the necessary methods for the SortPolicy template parameter of the NeighborSearch class.
~NSModel()
Clean memory, if necessary.
This class implements the necessary methods for the SortPolicy template parameter of the NeighborSearch class.
static const std::string Name()
Definition: ns_model.hpp:47
size_t LeafSize() const
Expose leafSize.
Definition: ns_model.hpp:371
BiSearchVisitor executes a bichromatic neighbor search on the given NSType.
Definition: ns_model.hpp:98
std::string TreeName() const
Return a string representation of the current tree type.
void Serialize(Archive &ar, const unsigned int)
Serialize the neighbor search model.
TreeTypes & TreeType()
Definition: ns_model.hpp:384
The NSModel class provides an easy way to serialize a model and abstracts away the different types of trees.
Definition: ns_model.hpp:283
double Tau() const
Expose tau.
Definition: ns_model.hpp:375
NSModel(TreeTypes treeType=TreeTypes::KD_TREE, bool randomBasis=false)
Initialize the NSModel with the given tree type and whether or not a random basis should be used.
SetSearchModeVisitor(const NeighborSearchMode searchMode)
Construct the SetSearchModeVisitor object with the given mode.
Definition: ns_model.hpp:230
TrainVisitor trains the given NSType on a new reference set.
void Search(arma::mat &&querySet, const size_t k, arma::Mat< size_t > &neighbors, arma::mat &distances)
Perform neighbor search. The query set will be reordered.
BiSearchVisitor(const arma::mat &querySet, const size_t k, arma::Mat< size_t > &neighbors, arma::mat &distances, const size_t leafSize, const double tau, const double rho)
Construct the BiSearchVisitor.
void operator()(NSType *ns) const
Perform monochromatic neighbor search.
MonoSearchVisitor executes a monochromatic neighbor search on the given NSType.
Definition: ns_model.hpp:66
void BuildModel(arma::mat &&referenceSet, const size_t leafSize, const NeighborSearchMode searchMode, const double epsilon=0)
Build the reference tree.
NeighborSearchMode SearchMode() const
Access the search mode.
DeleteVisitor deletes the given NSType instance.
Definition: ns_model.hpp:264
LMetric< 2, true > EuclideanDistance
The Euclidean (L2) distance.
Definition: lmetric.hpp:112
double Rho() const
Expose rho.
Definition: ns_model.hpp:379
double & operator()(NSType *ns) const
Return epsilon, the approximation parameter.
void SetSearchMode(const NeighborSearchMode mode)
Modify the search mode.
NeighborSearchMode
NeighborSearchMode represents the different neighbor search modes available.
void operator()(NSType *ns) const
Delete the NSType object.