curfil  ..
 All Classes Functions Variables Typedefs Friends Groups Pages
predict.h
1 #ifndef CURFIL_PREDICT
2 #define CURFIL_PREDICT
3 
4 #include <cuv/ndarray.hpp>
5 #include <string>
6 
7 #include "random_forest_image.h"
8 
9 namespace curfil {
10 
21 
22 public:
23 
27  explicit ConfusionMatrix() :
28  data(), normalized(false) {
29  }
30 
34  explicit ConfusionMatrix(const ConfusionMatrix& other) :
35  data(other.data.copy()), normalized(other.normalized), ignoredLabels(other.ignoredLabels) {
36  }
37 
41  explicit ConfusionMatrix(size_t numClasses) :
42  data(numClasses, numClasses), normalized(false) {
43  assert(numClasses > 0);
44  reset();
45  }
46 
50  explicit ConfusionMatrix(size_t numClasses, std::vector<LabelType> ignoredLabels) :
51  data(numClasses, numClasses), normalized(false), ignoredLabels(ignoredLabels) {
52  assert(numClasses > 0);
53  reset();
54  }
55 
60  data = other.data.copy();
61  normalized = other.normalized;
62  ignoredLabels = other.ignoredLabels;
63  return *this;
64  }
65 
69  void reset();
70 
76  void resize(unsigned int numClasses);
77 
82  bool isNormalized() const {
83  return normalized;
84  }
85 
89  void operator+=(const ConfusionMatrix& other);
90 
95  operator()(int label, int prediction) {
96  return data(label, prediction);
97  }
98 
102  double operator()(int label, int prediction) const {
103  return data(label, prediction);
104  }
105 
110  void increment(int label, int prediction) {
111  if (normalized) {
112  throw std::runtime_error("confusion matrix is already normalized");
113  }
114  assert(label < static_cast<int>(getNumClasses()));
115  assert(prediction < static_cast<int>(getNumClasses()));
116  (data(label, prediction))++;
117  }
118 
122  unsigned int getNumClasses() const {
123  assert(data.ndim() == 2);
124  assert(data.shape(0) == data.shape(1));
125  return data.shape(0);
126  }
127 
131  void normalize();
132 
138  double averageClassAccuracy(bool includeVoid = true) const;
139 
140 private:
142  bool normalized;
143  std::vector<LabelType> ignoredLabels;
144 
145 };
146 
152 double calculatePixelAccuracy(const LabelImage& prediction, const LabelImage& groundTruth,
153  const bool includeVoid = true, const std::vector<LabelType> ignoredLabels = NULL, ConfusionMatrix* confusionMatrix = 0);
154 
164 void test(RandomForestImage& randomForest, const std::string& folderTesting,
165  const std::string& folderPrediction, const bool useDepthFilling,
166  const bool writeProbabilityImages);
167 
168 }
169 
170 std::ostream& operator<<(std::ostream& o, const curfil::ConfusionMatrix& confusionMatrix);
171 
172 #endif