forpy  2
ileaf.h
Go to the documentation of this file.
1 /* Author: Christoph Lassner. */
2 #pragma once
3 #ifndef FORPY_LEAFS_ILEAF_H_
4 #define FORPY_LEAFS_ILEAF_H_
5 
6 #include "../global.h"
7 
8 #include "../util/serialization/basics.h"
9 
10 #include <vector>
11 
12 #include "../data_providers/idataprovider.h"
13 #include "../threshold_optimizers/ithreshopt.h"
14 #include "../types.h"
15 
16 namespace forpy {
23 class ILeaf {
24  public:
25  inline virtual ~ILeaf(){};
27  virtual std::shared_ptr<ILeaf> create_duplicate() const VIRTUAL_PTR;
28 
34  inline virtual bool is_compatible_with(
35  const IDataProvider & /*data_provider*/) {
36  return true;
37  };
38 
39  virtual bool is_compatible_with(const IThreshOpt &threshopt) VIRTUAL(bool);
40 
41  virtual void transfer_or_run_check(ILeaf *other, IThreshOpt *thresh_opt,
43 
47  virtual void make_leaf(const TodoMark &todo_info,
48  const IDataProvider &data_provider,
49  Desk *desk) const VIRTUAL_VOID;
50 
52  virtual size_t get_result_columns(const size_t &n_trees = 1,
53  const bool &predict_proba = false,
54  const bool &for_forest = false) const
55  VIRTUAL(size_t);
56 
58  virtual Data<Mat> get_result_type(const bool &predict_proba,
59  const bool &for_forest = false) const
60  VIRTUAL(Data<Mat>);
61 
68  inline virtual Data<Mat> get_result(const id_t &node_id,
69  const bool &predict_proba = false,
70  const bool &for_forest = false) const {
71  auto res_v = get_result_type(predict_proba);
72  Data<Mat> ret;
73  res_v.match(
74  [&](const auto &res_mt) {
75  typedef typename get_core<decltype(res_mt.data()[0])>::type RT;
76  ret.set<Mat<RT>>(
77  Mat<RT>::Zero(1, this->get_result_columns(1, predict_proba)));
78  Data<MatRef> dref = MatRef<RT>(ret.get_unchecked<Mat<RT>>());
79  this->get_result(node_id, dref, predict_proba, for_forest);
80  },
81  [](const Empty &) { throw EmptyException(); });
82  return ret;
83  };
84 
88  virtual void get_result(const id_t &node_id, Data<MatRef> &target,
89  const bool &predict_proba,
90  const bool &for_forest) const VIRTUAL_VOID;
91 
98  inline virtual Data<Mat> get_result(
99  const std::vector<Data<Mat>> &leaf_results,
100  const Vec<float> &weights = Vec<float>(),
101  const bool &predict_proba = false) const {
102  Data<Mat> ret;
103  leaf_results[0].match(
104  [&](const auto &lr0) {
105  typedef typename get_core<decltype(lr0.data())>::type RT;
106  ret.set<Mat<RT>>(Mat<RT>::Zero(
107  lr0.rows(),
108  this->get_result_columns(leaf_results.size(), predict_proba, false)));
109  Data<MatRef> dref = MatRef<RT>(ret.get_unchecked<Mat<RT>>());
110  this->get_result(leaf_results, dref, weights, predict_proba);
111  },
112  [&](const Empty &) { throw EmptyException(); });
113  return ret;
114  };
115 
117  virtual void get_result(
118  const std::vector<Data<Mat>> &leaf_results, Data<MatRef> &target_v,
119  const Vec<float> &weights = Vec<float>(),
120  const bool &predict_proba = false) const VIRTUAL_VOID;
121 
123  virtual void ensure_capacity(const size_t &n) VIRTUAL_VOID;
124 
126  virtual void finalize_capacity(const size_t &n) VIRTUAL_VOID;
127 
129  virtual const std::vector<Mat<float>> *get_map() const = 0;
130 
131  virtual bool operator==(const ILeaf &rhs) const VIRTUAL(bool);
132 
133  protected:
135  inline ILeaf(){};
136 
137  private:
138  friend class cereal::access;
139  template <class Archive>
140  void serialize(Archive &, const uint &){};
141 
143 };
144 }; // namespace forpy
145 #endif // FORPY_LEAFS_ILEAF_H_
Find an optimal threshold.
Definition: ithreshopt.h:23
virtual bool is_compatible_with(const IDataProvider &)
Checks compatibility with a certain IDataProvider.
Definition: ileaf.h:34
#define VIRTUAL_VOID
Definition: global.h:32
virtual void make_leaf(const TodoMark &todo_info, const IDataProvider &data_provider, Desk *desk) const VIRTUAL_VOID
Creates a leaf with the specified node_id and data.
A data provider for the training of one tree.
Definition: idataprovider.h:22
STL namespace.
size_t id_t
Element id type.
Definition: types.h:106
DISALLOW_COPY_AND_ASSIGN(ILeaf)
typename mu::variant< Empty, STOT< float >, STOT< double >, STOT< uint >, STOT< uint8_t > > Data
Storing a variant of the provided data container type.
Definition: storage.h:126
virtual ~ILeaf()
Definition: ileaf.h:25
#define VIRTUAL_PTR
Definition: global.h:33
virtual void transfer_or_run_check(ILeaf *other, IThreshOpt *thresh_opt, IDataProvider *dprov) VIRTUAL_VOID
#define VIRTUAL(type)
Definition: global.h:31
Stores the parameters for one marked tree node.
Definition: types.h:152
virtual void finalize_capacity(const size_t &n) VIRTUAL_VOID
Cut down capacity to exactly n leafs.
virtual Data< Mat > get_result_type(const bool &predict_proba, const bool &for_forest=false) const VIRTUAL(Data< Mat >)
virtual Data< Mat > get_result(const std::vector< Data< Mat >> &leaf_results, const Vec< float > &weights=Vec< float >(), const bool &predict_proba=false) const
Combine leaf results of several trees with weights.
Definition: ileaf.h:98
Eigen::Ref< Mat< DT > > MatRef
Parameterized standard non-const matrix ref type.
Definition: types.h:69
virtual const std::vector< Mat< float > > * get_map() const =0
Get all leafs.
Stores and returns leaf values, and combines them to forest results.
Definition: ileaf.h:23
A struct to represent an empty variant.
Definition: storage.h:67
virtual void ensure_capacity(const size_t &n) VIRTUAL_VOID
Ensure that storage is available for at least n leafs.
virtual Data< Mat > get_result(const id_t &node_id, const bool &predict_proba=false, const bool &for_forest=false) const
Get the leaf data for the leaf with the given id.
Definition: ileaf.h:68
Eigen::Matrix< DT, Eigen::Dynamic, 1, Eigen::ColMajor > Vec
Definition: types.h:73
Eigen::Matrix< DT, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor > Mat
Parameterized Matrix type (row major).
Definition: types.h:52
Main thread desk object.
Definition: desk.h:201
Get the core datatype with removed pointer, reference and const modifiers.
Definition: storage.h:136
virtual size_t get_result_columns(const size_t &n_trees=1, const bool &predict_proba=false, const bool &for_forest=false) const VIRTUAL(size_t)
virtual std::shared_ptr< ILeaf > create_duplicate() const VIRTUAL_PTR
unsigned int uint
Convenience typedef for unsigned int.
Definition: types.h:113
void serialize(Archive &, const uint &)
Definition: ileaf.h:140