Alexandria  2.18
Please provide a description of the project.
SOMTrainer.h
Go to the documentation of this file.
1 /*
2  * Copyright (C) 2012-2021 Euclid Science Ground Segment
3  *
4  * This library is free software; you can redistribute it and/or modify it under
5  * the terms of the GNU Lesser General Public License as published by the Free
6  * Software Foundation; either version 3.0 of the License, or (at your option)
7  * any later version.
8  *
9  * This library is distributed in the hope that it will be useful, but WITHOUT
10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
11  * FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
12  * details.
13  *
14  * You should have received a copy of the GNU Lesser General Public License
15  * along with this library; if not, write to the Free Software Foundation, Inc.,
16  * 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17  */
18 
19 /*
20  * @file SOMTrainer.h
21  * @author nikoapos
22  */
23 
24 #ifndef SOM_SOMTRAINER_H
25 #define SOM_SOMTRAINER_H
26 
#include "SOM/LearningRestraintFunc.h"
#include "SOM/NeighborhoodFunc.h"
#include "SOM/SOM.h"
#include "SOM/SamplingPolicy.h"

#include <cstddef>
#include <tuple>
#include <utility>
31 
32 namespace Euclid {
33 namespace SOM {
34 
35 class SOMTrainer {
36 
37 public:
38  SOMTrainer(NeighborhoodFunc::Signature neighborhood_func, LearningRestraintFunc::Signature learning_restraint_func)
39  : m_neighborhood_func(neighborhood_func), m_learning_restraint_func(learning_restraint_func) {}
40 
41  template <std::size_t ND, typename DistFunc, typename InputIter, typename InputToWeightFunc>
42  void train(SOM<ND, DistFunc>& som, std::size_t iter_no, InputIter begin, InputIter end, InputToWeightFunc weight_func,
44 
45  // We repeat the training for iter_no iterations
46  for (std::size_t i = 0; i < iter_no; ++i) {
47 
48  // Compute the factor of the current iteration
49  auto learn_factor = m_learning_restraint_func(i, iter_no);
50  if (learn_factor == 0) {
51  continue;
52  }
53 
54  // Go through the training sample of the iteration
55  for (auto it = sampling_policy.start(begin, end); it != end; it = sampling_policy.next(it)) {
56 
57  // Get the weights of the input object
58  auto input_weights = weight_func(*it);
59 
60  // Find the coordinates of the BMU for the input
61  std::size_t bmu_x;
62  std::size_t bmu_y;
63  double nd_distance;
64  std::tie(bmu_x, bmu_y, nd_distance) = som.findBMU(*it, weight_func);
65 
66  // Now go through all the cells and update their values according their coordinates
67  for (auto cell_it = som.begin(); cell_it != som.end(); ++cell_it) {
68 
69  // Compute the factor based on the distance of the BMU and the cell
70  auto cell_x = cell_it.template axisValue<0>();
71  auto cell_y = cell_it.template axisValue<1>();
72  auto neighborhood_factor = m_neighborhood_func({bmu_x, bmu_y}, {cell_x, cell_y}, i, iter_no);
73 
74  // Get the weights of the cell and update them
75  if (neighborhood_factor != 0) {
76  auto& cell_weights = *cell_it;
77  for (std::size_t wi = 0; wi < ND; ++wi) {
78  cell_weights[wi] = cell_weights[wi] + neighborhood_factor * learn_factor * (input_weights[wi] - cell_weights[wi]);
79  }
80  }
81  }
82  }
83  }
84  }
85 
86 private:
89 };
90 
91 } // namespace SOM
92 } // namespace Euclid
93 
94 #endif /* SOM_SOMTRAINER_H */
Euclid::SOM::SOMTrainer::m_learning_restraint_func
LearningRestraintFunc::Signature m_learning_restraint_func
Definition: SOMTrainer.h:88
LearningRestraintFunc.h
Euclid::SOM::SOMTrainer
Definition: SOMTrainer.h:35
Euclid::SOM::SamplingPolicy::FullSet
Definition: SamplingPolicy.h:46
Euclid::SOM::SOMTrainer::SOMTrainer
SOMTrainer(NeighborhoodFunc::Signature neighborhood_func, LearningRestraintFunc::Signature learning_restraint_func)
Definition: SOMTrainer.h:38
std::function< double(std::pair< std::size_t, std::size_t > bmu, std::pair< std::size_t, std::size_t > cell, std::size_t iteration, std::size_t total_iterations)>
std::tie
T tie(T... args)
Euclid::SOM::SOM
Definition: SOM.h:46
Euclid::SOM::SOMTrainer::train
void train(SOM< ND, DistFunc > &som, std::size_t iter_no, InputIter begin, InputIter end, InputToWeightFunc weight_func, const SamplingPolicy::Interface< InputIter > &sampling_policy=SamplingPolicy::FullSet< InputIter >{})
Definition: SOMTrainer.h:42
Euclid::SOM::SOMTrainer::m_neighborhood_func
NeighborhoodFunc::Signature m_neighborhood_func
Definition: SOMTrainer.h:87
SOM.h
Euclid::SOM::SamplingPolicy::Interface
Definition: SamplingPolicy.h:37
NeighborhoodFunc.h
std::size_t
SamplingPolicy.h
Euclid
Definition: InstOrRefHolder.h:29