forpy  2
forest.h
Go to the documentation of this file.
1 /* Author: Christoph Lassner. */
2 #pragma once
3 #ifndef FORPY_FOREST_H_
4 #define FORPY_FOREST_H_
5 
6 #include "./global.h"
7 
8 #include "./util/serialization/basics.h"
9 
10 #include <fstream>
11 #include <functional>
12 #include <memory>
13 #include <string>
14 #include <vector>
15 
17 #include "./deciders/fastdecider.h"
20 #include "./tree.h"
21 #include "./types.h"
22 #include "./util/threading/ctpl.h"
23 
24 namespace forpy {
28 class Forest {
29  public:
46  Forest(const uint &n_trees = 10,
47  const uint &max_depth = std::numeric_limits<uint>::max(),
48  const uint &min_samples_at_leaf = 1,
49  const uint &min_samples_at_node = 2,
50  const std::shared_ptr<IDecider> &decider_template = nullptr,
51  const std::shared_ptr<ILeaf> &leaf_manager_template = nullptr,
52  const uint &random_seed = 1);
53 
60  Forest(std::vector<std::shared_ptr<Tree>> &trees);
61 
68  Forest(std::string filename);
69 
84  Forest *fit(const Data<MatCRef> &data_v, const Data<MatCRef> &annotation_v,
85  const size_t &num_threads = 1, const bool &bootstrap = true,
86  const std::vector<float> &weights = std::vector<float>());
87 
94  std::vector<size_t> get_depths() const {
95  std::vector<size_t> result(trees.size());
96  size_t tree_id = 0;
97  for (const auto &tree_ptr : trees) {
98  result[tree_id] = tree_ptr->get_depth();
99  tree_id++;
100  }
101  return result;
102  }
103 
115  Forest *fit_dprov(const std::shared_ptr<IDataProvider> &fdata_provider,
116  const bool &bootstrap = true);
117 
140  Data<Mat> predict(const Data<MatCRef> &data_v, const int &num_threads = 1,
141  const bool &use_fast_prediction_if_available = true,
142  const bool &predict_proba = false) {
143  if (num_threads == 0)
144  throw ForpyException("The number of threads must be >0!");
145  if (num_threads != 1) throw ForpyException("Unimplemented!");
146  std::vector<Data<Mat>> results(0);
147  Vec<float> tree_weights(Vec<float>::Zero(trees.size()));
148  results.reserve(trees.size());
149  for (size_t i = 0; i < trees.size(); ++i) {
150  results.push_back(trees[i]->predict(
151  data_v, 1, use_fast_prediction_if_available, predict_proba, true));
152  tree_weights(i) = trees[i]->get_weight();
153  }
154  return trees[0]->combine_leaf_results(results, tree_weights, predict_proba);
155  };
156 
159  const int &num_threads = 1,
160  const bool &use_fast_prediction_if_available = true) {
161  return predict(data_v, num_threads, use_fast_prediction_if_available, true);
162  };
163 
165  inline size_t get_input_data_dimensions() const {
166  return trees[0]->get_input_data_dimensions();
167  }
168 
170  inline std::shared_ptr<const IDecider> get_decider() const {
171  return trees[0]->get_decider();
172  };
173 
175  inline std::vector<std::shared_ptr<Tree>> get_trees() const { return trees; }
176 
178  inline void enable_fast_prediction() {
179  for (auto &tree : trees) tree->enable_fast_prediction();
180  };
181 
183  inline void disable_fast_prediction() {
184  for (auto &tree : trees) tree->disable_fast_prediction();
185  };
186 
188  inline std::shared_ptr<const ILeaf> get_leaf_manager() const {
189  return trees[0]->get_leaf_manager();
190  }
191 
193  inline std::vector<float> get_tree_weights() const {
194  std::vector<float> result(trees.size());
195  for (std::size_t tree_idx = 0; tree_idx < trees.size(); ++tree_idx) {
196  result[tree_idx] = trees.at(tree_idx)->get_weight();
197  }
198  return result;
199  }
200 
202  inline void set_tree_weights(const std::vector<float> &weights) const {
203  if (weights.size() != trees.size())
204  throw ForpyException("Need " + std::to_string(trees.size()) +
205  " weights, received " +
206  std::to_string(weights.size()));
207  for (std::size_t tree_idx = 0; tree_idx < trees.size(); ++tree_idx)
208  trees[tree_idx]->set_weight(weights[tree_idx]);
209  }
210 
217  void save(const std::string &filename) const;
218 
219  inline bool operator==(const Forest &rhs) const {
220  if (trees.size() != rhs.trees.size()) return false;
221  for (size_t i = 0; i < trees.size(); ++i) {
222  if (!(*(trees.at(i)) == *(rhs.trees.at(i)))) return false;
223  }
224  if (random_seed != rhs.random_seed) return false;
225  return true;
226  };
227 
228  inline friend std::ostream &operator<<(std::ostream &stream,
229  const Forest &self) {
230  stream << "forpy::Forest[" << self.trees.size() << " trees]";
231  return stream;
232  };
233 
234  private:
235  friend class cereal::access;
236  template <class Archive>
237  void serialize(Archive &ar, const uint &) {
238  // To make this work, I had to take out a static assertion at cereal.hpp,
239  // at least for clang. Everything works as expected without the assertion
240  // and I wouldn't know why not.
241  ar(CEREAL_NVP(trees), CEREAL_NVP(random_seed));
242  };
243 
244  std::vector<std::shared_ptr<Tree>> trees;
247 }; // class Forest
248 
249 class ClassificationForest : public Forest {
250  public:
251  inline ClassificationForest(const std::string &filename) : Forest(filename){};
252  ClassificationForest(const size_t &n_trees = 10,
253  const uint &max_depth = std::numeric_limits<uint>::max(),
254  const uint &min_samples_at_leaf = 1,
255  const uint &min_samples_at_node = 2,
256  const uint &n_valid_features_to_use = 0,
257  const bool &autoscale_valid_features = true,
258  const uint &random_seed = 1,
259  const size_t &n_thresholds = 0,
260  const float &gain_threshold = 1E-7f);
261 
262  inline std::unordered_map<std::string, mu::variant<uint, size_t, float, bool>>
263  get_params(const bool & /*deep*/ = false) const {
264  return params;
265  }
266 
267  inline std::shared_ptr<ClassificationForest> set_params(
268  const std::unordered_map<
269  std::string, mu::variant<uint, size_t, float, bool>> &params) {
270  return std::make_shared<ClassificationForest>(
271  GetWithDefVar<size_t>(params, "n_trees", 10),
272  GetWithDefVar<uint>(params, "max_depth",
273  std::numeric_limits<uint>::max()),
274  GetWithDefVar<uint>(params, "min_samples_at_leaf", 1),
275  GetWithDefVar<uint>(params, "min_samples_at_node", 2),
276  GetWithDefVar<uint>(params, "n_valid_features_to_use", 0),
277  GetWithDefVar<bool>(params, "autoscale_valid_features", true),
278  GetWithDefVar<uint>(params, "random_seed", 1),
279  GetWithDefVar<size_t>(params, "n_thresholds", 0),
280  GetWithDefVar<float>(params, "gain_threshold", 1E-7f));
281  }
282 
283  inline friend std::ostream &operator<<(std::ostream &stream,
284  const ClassificationForest &self) {
285  stream << "forpy::ClassificationForest["
286  << mu::static_variant_cast<size_t>(self.params.at("n_trees"))
287  << " trees]";
288  return stream;
289  };
290 
291  private:
292  std::unordered_map<std::string, mu::variant<uint, size_t, float, bool>>
293  params;
294  friend class cereal::access;
295  template <class Archive>
296  void serialize(Archive &ar, const uint &) {
297  ar(cereal::make_nvp("base", cereal::base_class<Forest>(this)),
298  CEREAL_NVP(params));
299  }
301 };
302 
303 class RegressionForest : public Forest {
304  public:
305  inline RegressionForest(const std::string &filename) : Forest(filename){};
306  RegressionForest(const size_t &n_trees = 10,
307  const uint &max_depth = std::numeric_limits<uint>::max(),
308  const uint &min_samples_at_leaf = 1,
309  const uint &min_samples_at_node = 2,
310  const uint &n_valid_features_to_use = 0,
311  const bool &autoscale_valid_features = false,
312  const uint &random_seed = 1, const size_t &n_thresholds = 0,
313  const float &gain_threshold = 1E-7f,
314  const bool &store_variance = false,
315  const bool &summarize = false);
316 
317  inline std::unordered_map<std::string, mu::variant<uint, size_t, float, bool>>
318  get_params(const bool & /*deep*/ = false) const {
319  return params;
320  }
321 
322  inline std::shared_ptr<RegressionForest> set_params(
323  const std::unordered_map<
324  std::string, mu::variant<uint, size_t, float, bool>> &params) {
325  return std::make_shared<RegressionForest>(
326  GetWithDefVar<size_t>(params, "n_trees", 10),
327  GetWithDefVar<uint>(params, "max_depth",
328  std::numeric_limits<uint>::max()),
329  GetWithDefVar<uint>(params, "min_samples_at_leaf", 1),
330  GetWithDefVar<uint>(params, "min_samples_at_node", 2),
331  GetWithDefVar<uint>(params, "n_valid_features_to_use", 0),
332  GetWithDefVar<bool>(params, "autoscale_valid_features", false),
333  GetWithDefVar<uint>(params, "random_seed", 1),
334  GetWithDefVar<size_t>(params, "n_thresholds", 0),
335  GetWithDefVar<float>(params, "gain_threshold", 1E-7f),
336  GetWithDefVar<bool>(params, "store_variance", false),
337  GetWithDefVar<bool>(params, "summarize", false));
338  }
339 
340  inline friend std::ostream &operator<<(std::ostream &stream,
341  const RegressionForest &self) {
342  stream << "forpy::RegressionForest["
343  << mu::static_variant_cast<size_t>(self.params.at("n_trees"))
344  << " trees]";
345  return stream;
346  };
347 
348  private:
349  std::unordered_map<std::string, mu::variant<uint, size_t, float, bool>>
350  params;
351  friend class cereal::access;
352  template <class Archive>
353  void serialize(Archive &ar, const uint &) {
354  ar(cereal::make_nvp("base", cereal::base_class<Forest>(this)),
355  CEREAL_NVP(params));
356  }
358 };
359 
360 }; // namespace forpy
361 #endif // FORPY_FOREST_H_
void serialize(Archive &ar, const uint &)
Definition: forest.h:296
friend std::ostream & operator<<(std::ostream &stream, const Forest &self)
Definition: forest.h:228
std::unordered_map< std::string, mu::variant< uint, size_t, float, bool > > params
Definition: forest.h:346
std::shared_ptr< ClassificationForest > set_params(const std::unordered_map< std::string, mu::variant< uint, size_t, float, bool >> &params)
Definition: forest.h:267
void save(const std::string &filename) const
uint random_seed
Definition: forest.h:245
std::unordered_map< std::string, mu::variant< uint, size_t, float, bool > > params
Definition: forest.h:289
friend class cereal::access
Definition: forest.h:294
void serialize(Archive &ar, const uint &)
Definition: forest.h:353
void enable_fast_prediction()
Definition: forest.h:178
std::vector< std::shared_ptr< Tree > > trees
Definition: forest.h:242
std::shared_ptr< const IDecider > get_decider() const
Definition: forest.h:170
typename mu::variant< Empty, STOT< float >, STOT< double >, STOT< uint >, STOT< uint8_t > > Data
Storing a variant of the provided data container type.
Definition: storage.h:126
void set_tree_weights(const std::vector< float > &weights) const
Definition: forest.h:202
void serialize(Archive &ar, const uint &)
Definition: forest.h:237
Forest(const uint &n_trees=10, const uint &max_depth=std::numeric_limits< uint >::max(), const uint &min_samples_at_leaf=1, const uint &min_samples_at_node=2, const std::shared_ptr< IDecider > &decider_template=nullptr, const std::shared_ptr< ILeaf > &leaf_manager_template=nullptr, const uint &random_seed=1)
Data< Mat > predict(const Data< MatCRef > &data_v, const int &num_threads=1, const bool &use_fast_prediction_if_available=true, const bool &predict_proba=false)
Definition: forest.h:140
std::unordered_map< std::string, mu::variant< uint, size_t, float, bool > > get_params(const bool &=false) const
Definition: forest.h:263
friend std::ostream & operator<<(std::ostream &stream, const RegressionForest &self)
Definition: forest.h:340
friend class cereal::access
Definition: forest.h:351
ClassificationForest(const std::string &filename)
Definition: forest.h:251
RegressionForest(const std::string &filename)
Definition: forest.h:305
std::vector< float > get_tree_weights() const
Definition: forest.h:193
DISALLOW_COPY_AND_ASSIGN(ClassificationForest)
void disable_fast_prediction()
Definition: forest.h:183
Eigen::Matrix< DT, Eigen::Dynamic, 1, Eigen::ColMajor > Vec
Definition: types.h:73
Data< Mat > predict_proba(const Data< MatCRef > &data_v, const int &num_threads=1, const bool &use_fast_prediction_if_available=true)
Definition: forest.h:158
std::shared_ptr< RegressionForest > set_params(const std::unordered_map< std::string, mu::variant< uint, size_t, float, bool >> &params)
Definition: forest.h:322
std::vector< size_t > get_depths() const
Definition: forest.h:94
std::shared_ptr< const ILeaf > get_leaf_manager() const
Definition: forest.h:188
std::vector< std::shared_ptr< Tree > > get_trees() const
Definition: forest.h:175
DISALLOW_COPY_AND_ASSIGN(RegressionForest)
friend std::ostream & operator<<(std::ostream &stream, const ClassificationForest &self)
Definition: forest.h:283
bool operator==(const Forest &rhs) const
Definition: forest.h:219
std::unordered_map< std::string, mu::variant< uint, size_t, float, bool > > get_params(const bool &=false) const
Definition: forest.h:318
unsigned int uint
Convenience typedef for unsigned int.
Definition: types.h:113
size_t get_input_data_dimensions() const
Definition: forest.h:165
Forest * fit(const Data< MatCRef > &data_v, const Data< MatCRef > &annotation_v, const size_t &num_threads=1, const bool &bootstrap=true, const std::vector< float > &weights=std::vector< float >())
Forest * fit_dprov(const std::shared_ptr< IDataProvider > &fdata_provider, const bool &bootstrap=true)
The fitting function for a forest.
DISALLOW_COPY_AND_ASSIGN(Forest)