curfil  ..
 All Classes Functions Variables Typedefs Friends Groups Pages
random_forest_image.h
1 #ifndef CURFIL_RANDOM_FOREST_IMAGE_H
2 #define CURFIL_RANDOM_FOREST_IMAGE_H
3 
4 #include <boost/shared_ptr.hpp>
5 #include <vector>
6 
7 #include "random_tree_image.h"
8 
9 namespace curfil {
10 
// Forward declaration only: within this header TreeNodes appears solely as the
// pointee type of a boost::shared_ptr member (treeData), so its full definition
// is not required here.
11 class TreeNodes;
12 
18 
19 public:
20 
/**
 * Constructs a forest from a list of strings, one entry per tree.
 *
 * NOTE(review): the parameter name suggests these are paths of serialized
 * trees that get loaded here - confirm against the implementation, which is
 * not part of this header.
 *
 * @param treeFiles one string per tree (presumably file names - confirm)
 * @param deviceIds defaults to a one-element vector containing 0
 * @param accelerationMode defaults to GPU_ONLY
 * @param histogramBias defaults to 0.0
 */
29  explicit RandomForestImage(const std::vector<std::string>& treeFiles,
30  const std::vector<int>& deviceIds = std::vector<int>(1, 0),
31  const AccelerationMode accelerationMode = GPU_ONLY,
32  const double histogramBias = 0.0);
33 
/**
 * Constructs a forest from a tree count and a training configuration.
 *
 * NOTE(review): presumably prepares treeCount trees that are later filled by
 * train() - confirm against the implementation.
 *
 * @param treeCount number of trees
 * @param configuration training configuration, passed by const reference
 */
40  explicit RandomForestImage(unsigned int treeCount,
41  const TrainingConfiguration& configuration);
42 
/**
 * Constructs a forest from an existing collection of trees.
 *
 * @param ensemble shared pointers to the trees (presumably already trained - confirm)
 * @param configuration training configuration, passed by const reference
 */
49  explicit RandomForestImage(const std::vector<boost::shared_ptr<RandomTreeImage> >& ensemble,
50  const TrainingConfiguration& configuration);
51 
/**
 * Trains the forest on a set of labeled RGB-D images.
 *
 * @param trainLabelImages the labeled training images
 * @param numLabels defaults to 0 (NOTE(review): 0 presumably means "derive the
 *        label count from the data" - confirm)
 * @param trainTreesSequentially defaults to false
 */
59  void train(const std::vector<LabeledRGBDImage>& trainLabelImages, size_t numLabels = 0, bool trainTreesSequentially = false);
60 
/**
 * Predicts a LabelImage for the given RGB-D image; does not modify the forest
 * (const member function).
 *
 * NOTE(review): the embedded listing numbers jump from 68 to 70 inside this
 * parameter list, so one parameter line appears to be missing from this
 * listing - restore it from the original header before relying on this
 * declaration.
 *
 * @param image the input image
 * @param onGPU defaults to true (presumably selects GPU execution - confirm)
 * @param useDepthImages defaults to true
 * @return the resulting LabelImage
 */
68  LabelImage predict(const RGBDImage& image,
70  const bool onGPU = true, bool useDepthImages = true) const;
71 
72 
/**
 * Takes an RGB-D image together with its label image and returns a LabelImage.
 *
 * NOTE(review): judging by the name, this refines the leaf histograms using
 * the given image - but the method is const, so confirm in the implementation
 * where the updated state actually lives.
 *
 * @param trainingImage the RGB-D image
 * @param labelImage the labels belonging to trainingImage
 * @param onGPU defaults to true
 * @param useDepthImages defaults to true
 */
76  LabelImage improveHistograms(const RGBDImage& trainingImage, const LabelImage& labelImage, const bool onGPU = true, bool useDepthImages = true) const;
77 
/**
 * Non-const operation without parameters or return value.
 * NOTE(review): presumably writes updated histograms back into the trees of
 * the ensemble - confirm against the implementation.
 */
81  void updateTreesHistograms();
/**
 * Returns a map from string keys to counts.
 * NOTE(review): the keys are presumably feature-type names and the values the
 * number of times each feature is used across the forest - confirm.
 */
85  std::map<std::string, size_t> countFeatures() const;
86 
/** Returns the number of classes as a LabelType value; does not modify the forest. */
90  LabelType getNumClasses() const;
91 
/**
 * Returns the shared pointer stored at position treeNr of the ensemble.
 *
 * Unless NDEBUG is defined the index is bounds-checked through
 * std::vector::at(), which throws std::out_of_range for an invalid index;
 * with NDEBUG defined the unchecked operator[] is used instead.
 *
 * @param treeNr index into the ensemble vector
 */
95  const boost::shared_ptr<RandomTreeImage> getTree(size_t treeNr) const {
96 #ifndef NDEBUG
97  return ensemble.at(treeNr);
98 #else
99  return ensemble[treeNr];
100 #endif
101  }
102 
/** Gives read-only access to the vector of all trees; returns a reference, so no copy is made. */
106  const std::vector<boost::shared_ptr<RandomTreeImage> >& getTrees() const {
107  return this->ensemble;
108  }
109 
114  return configuration;
115  }
116 
/**
 * Tells whether the given label is to be ignored.
 * NOTE(review): the criterion presumably comes from the training
 * configuration - confirm against the implementation.
 *
 * @param label the label to test
 */
120  bool shouldIgnoreLabel(const LabelType& label) const;
121 
/** Returns, by value, a map associating each LabelType with an RGBColor. */
125  std::map<LabelType, RGBColor> getLabelColorMap() const;
126 
/**
 * Non-const operation that normalizes the forest's histograms.
 *
 * @param histogramBias bias value (NOTE(review): exact effect on the
 *        normalization is not visible in this header - confirm)
 */
130  void normalizeHistograms(const double histogramBias);
131 
132 private:
133 
// Training configuration held by value inside the forest.
134  TrainingConfiguration configuration;
135 
// The trees of the forest; this is the vector exposed through getTree()/getTrees().
136  std::vector<boost::shared_ptr<RandomTreeImage> > ensemble;
// One entry per element, pointing to const TreeNodes.
// NOTE(review): presumably a per-tree node layout used for prediction - confirm.
137  std::vector<boost::shared_ptr<const TreeNodes> > treeData;
// Shared cuv allocator; judging by the name it is used during prediction - confirm.
138  boost::shared_ptr<cuv::allocator> m_predictionAllocator;
139 };
140 
141 }
142 
// Stream-output operator for a forest. It is declared in the global namespace
// (after namespace curfil has been closed), so it is reachable only through
// ordinary unqualified lookup, not through argument-dependent lookup on
// curfil::RandomForestImage.
143 std::ostream& operator<<(std::ostream& os, const curfil::RandomForestImage& ensemble);
144 
145 #endif