3 #ifndef FORPY_FOREST_H_ 4 #define FORPY_FOREST_H_ 8 #include "./util/serialization/basics.h" 47 const uint &max_depth = std::numeric_limits<uint>::max(),
48 const uint &min_samples_at_leaf = 1,
49 const uint &min_samples_at_node = 2,
50 const std::shared_ptr<IDecider> &decider_template =
nullptr,
51 const std::shared_ptr<ILeaf> &leaf_manager_template =
nullptr,
68 Forest(std::string filename);
85 const size_t &num_threads = 1,
const bool &bootstrap =
true,
86 const std::vector<float> &weights = std::vector<float>());
95 std::vector<size_t> result(
trees.size());
97 for (
const auto &tree_ptr :
trees) {
98 result[tree_id] = tree_ptr->get_depth();
115 Forest *
fit_dprov(
const std::shared_ptr<IDataProvider> &fdata_provider,
116 const bool &bootstrap =
true);
141 const bool &use_fast_prediction_if_available =
true,
143 if (num_threads == 0)
146 std::vector<Data<Mat>> results(0);
148 results.reserve(
trees.size());
149 for (
size_t i = 0; i <
trees.size(); ++i) {
151 data_v, 1, use_fast_prediction_if_available,
predict_proba,
true));
152 tree_weights(i) =
trees[i]->get_weight();
159 const int &num_threads = 1,
160 const bool &use_fast_prediction_if_available =
true) {
161 return predict(data_v, num_threads, use_fast_prediction_if_available,
true);
166 return trees[0]->get_input_data_dimensions();
171 return trees[0]->get_decider();
179 for (
auto &tree :
trees) tree->enable_fast_prediction();
184 for (
auto &tree :
trees) tree->disable_fast_prediction();
189 return trees[0]->get_leaf_manager();
194 std::vector<float> result(
trees.size());
195 for (std::size_t tree_idx = 0; tree_idx <
trees.size(); ++tree_idx) {
196 result[tree_idx] =
trees.at(tree_idx)->get_weight();
203 if (weights.size() !=
trees.size())
205 " weights, received " +
206 std::to_string(weights.size()));
207 for (std::size_t tree_idx = 0; tree_idx <
trees.size(); ++tree_idx)
208 trees[tree_idx]->set_weight(weights[tree_idx]);
217 void save(
const std::string &filename)
const;
220 if (
trees.size() != rhs.
trees.size())
return false;
221 for (
size_t i = 0; i <
trees.size(); ++i) {
222 if (!(*(
trees.at(i)) == *(rhs.
trees.at(i))))
return false;
230 stream <<
"forpy::Forest[" <<
self.trees.size() <<
" trees]";
235 friend class cereal::access;
236 template <
class Archive>
244 std::vector<std::shared_ptr<Tree>>
trees;
253 const uint &max_depth = std::numeric_limits<uint>::max(),
254 const uint &min_samples_at_leaf = 1,
255 const uint &min_samples_at_node = 2,
256 const uint &n_valid_features_to_use = 0,
257 const bool &autoscale_valid_features =
true,
259 const size_t &n_thresholds = 0,
260 const float &gain_threshold = 1E-7f);
262 inline std::unordered_map<std::string, mu::variant<uint, size_t, float, bool>>
268 const std::unordered_map<
269 std::string, mu::variant<uint, size_t, float, bool>> &
params) {
270 return std::make_shared<ClassificationForest>(
271 GetWithDefVar<size_t>(
params,
"n_trees", 10),
272 GetWithDefVar<uint>(
params,
"max_depth",
273 std::numeric_limits<uint>::max()),
274 GetWithDefVar<uint>(
params,
"min_samples_at_leaf", 1),
275 GetWithDefVar<uint>(
params,
"min_samples_at_node", 2),
276 GetWithDefVar<uint>(
params,
"n_valid_features_to_use", 0),
277 GetWithDefVar<bool>(
params,
"autoscale_valid_features",
true),
278 GetWithDefVar<uint>(
params,
"random_seed", 1),
279 GetWithDefVar<size_t>(
params,
"n_thresholds", 0),
280 GetWithDefVar<float>(
params,
"gain_threshold", 1E-7f));
285 stream <<
"forpy::ClassificationForest[" 286 << mu::static_variant_cast<
size_t>(
self.params.at(
"n_trees"))
292 std::unordered_map<std::string, mu::variant<uint, size_t, float, bool>>
295 template <
class Archive>
297 ar(cereal::make_nvp(
"base", cereal::base_class<Forest>(
this)),
307 const uint &max_depth = std::numeric_limits<uint>::max(),
308 const uint &min_samples_at_leaf = 1,
309 const uint &min_samples_at_node = 2,
310 const uint &n_valid_features_to_use = 0,
311 const bool &autoscale_valid_features =
false,
313 const float &gain_threshold = 1E-7f,
314 const bool &store_variance =
false,
315 const bool &summarize =
false);
317 inline std::unordered_map<std::string, mu::variant<uint, size_t, float, bool>>
323 const std::unordered_map<
324 std::string, mu::variant<uint, size_t, float, bool>> &
params) {
325 return std::make_shared<RegressionForest>(
326 GetWithDefVar<size_t>(
params,
"n_trees", 10),
327 GetWithDefVar<uint>(
params,
"max_depth",
328 std::numeric_limits<uint>::max()),
329 GetWithDefVar<uint>(
params,
"min_samples_at_leaf", 1),
330 GetWithDefVar<uint>(
params,
"min_samples_at_node", 2),
331 GetWithDefVar<uint>(
params,
"n_valid_features_to_use", 0),
332 GetWithDefVar<bool>(
params,
"autoscale_valid_features",
false),
333 GetWithDefVar<uint>(
params,
"random_seed", 1),
334 GetWithDefVar<size_t>(
params,
"n_thresholds", 0),
335 GetWithDefVar<float>(
params,
"gain_threshold", 1E-7f),
336 GetWithDefVar<bool>(
params,
"store_variance",
false),
337 GetWithDefVar<bool>(
params,
"summarize",
false));
342 stream <<
"forpy::RegressionForest[" 343 << mu::static_variant_cast<
size_t>(
self.params.at(
"n_trees"))
349 std::unordered_map<std::string, mu::variant<uint, size_t, float, bool>>
352 template <
class Archive>
354 ar(cereal::make_nvp(
"base", cereal::base_class<Forest>(
this)),
361 #endif // FORPY_FOREST_H_ void serialize(Archive &ar, const uint &)
friend std::ostream & operator<<(std::ostream &stream, const Forest &self)
std::unordered_map< std::string, mu::variant< uint, size_t, float, bool > > params
std::shared_ptr< ClassificationForest > set_params(const std::unordered_map< std::string, mu::variant< uint, size_t, float, bool >> ¶ms)
void save(const std::string &filename) const
std::unordered_map< std::string, mu::variant< uint, size_t, float, bool > > params
friend class cereal::access
void serialize(Archive &ar, const uint &)
void enable_fast_prediction()
std::vector< std::shared_ptr< Tree > > trees
std::shared_ptr< const IDecider > get_decider() const
typename mu::variant< Empty, STOT< float >, STOT< double >, STOT< uint >, STOT< uint8_t > > Data
Storing a variant of the provided data container type.
void set_tree_weights(const std::vector< float > &weights) const
void serialize(Archive &ar, const uint &)
Forest(const uint &n_trees=10, const uint &max_depth=std::numeric_limits< uint >::max(), const uint &min_samples_at_leaf=1, const uint &min_samples_at_node=2, const std::shared_ptr< IDecider > &decider_template=nullptr, const std::shared_ptr< ILeaf > &leaf_manager_template=nullptr, const uint &random_seed=1)
Data< Mat > predict(const Data< MatCRef > &data_v, const int &num_threads=1, const bool &use_fast_prediction_if_available=true, const bool &predict_proba=false)
std::unordered_map< std::string, mu::variant< uint, size_t, float, bool > > get_params(const bool &=false) const
friend std::ostream & operator<<(std::ostream &stream, const RegressionForest &self)
friend class cereal::access
ClassificationForest(const std::string &filename)
RegressionForest(const std::string &filename)
std::vector< float > get_tree_weights() const
DISALLOW_COPY_AND_ASSIGN(ClassificationForest)
void disable_fast_prediction()
Eigen::Matrix< DT, Eigen::Dynamic, 1, Eigen::ColMajor > Vec
Data< Mat > predict_proba(const Data< MatCRef > &data_v, const int &num_threads=1, const bool &use_fast_prediction_if_available=true)
std::shared_ptr< RegressionForest > set_params(const std::unordered_map< std::string, mu::variant< uint, size_t, float, bool >> ¶ms)
std::vector< size_t > get_depths() const
std::shared_ptr< const ILeaf > get_leaf_manager() const
std::vector< std::shared_ptr< Tree > > get_trees() const
DISALLOW_COPY_AND_ASSIGN(RegressionForest)
friend std::ostream & operator<<(std::ostream &stream, const ClassificationForest &self)
bool operator==(const Forest &rhs) const
std::unordered_map< std::string, mu::variant< uint, size_t, float, bool > > get_params(const bool &=false) const
unsigned int uint
Convenience typedef for unsigned int.
size_t get_input_data_dimensions() const
Forest * fit(const Data< MatCRef > &data_v, const Data< MatCRef > &annotation_v, const size_t &num_threads=1, const bool &bootstrap=true, const std::vector< float > &weights=std::vector< float >())
Forest * fit_dprov(const std::shared_ptr< IDataProvider > &fdata_provider, const bool &bootstrap=true)
The fitting function for a forest.
DISALLOW_COPY_AND_ASSIGN(Forest)