curfil  ..
 All Classes Functions Variables Typedefs Friends Groups Pages
hyperopt.h
1 #ifndef CURFIL_HYPEROPT
2 #define CURFIL_HYPEROPT
3 
4 #include <boost/asio/io_service.hpp>
5 #include <mdbq/client.hpp>
6 #include <mongo/bson/bson.h>
7 
8 #include "image.h"
9 #include "predict.h"
10 #include "random_forest_image.h"
11 #include "random_tree_image.h"
12 
13 namespace curfil {
14 
/**
 * Decides whether the hyperparameter search should keep going.
 *
 * Defined out of line. Judging by its name and parameters, it presumably compares
 * the accuracies of the current run against the best accuracies seen so far —
 * NOTE(review): confirm the exact criterion in the implementation.
 *
 * @param currentBestAccuracies accuracies of the best configuration so far
 * @param currentRunAccuracies  accuracies of the run just evaluated
 * @return true if searching should continue
 */
bool continueSearching(
        const std::vector<double>& currentBestAccuracies,
        const std::vector<double>& currentRunAccuracies);
17 
/**
 * Selects which accuracy measure a Result's loss is derived from
 * (see Result::setLossFunctionType(); the mapping itself lives in the
 * out-of-line Result::getLoss()). The enumerator names mirror Result's
 * accuracy getters. Values are spelled out and equal the implicit ones.
 */
enum LossFunctionType {
    CLASS_ACCURACY = 0,              ///< average class accuracy
    CLASS_ACCURACY_WITHOUT_VOID = 1, ///< average class accuracy, void class excluded
    PIXEL_ACCURACY = 2,              ///< pixel accuracy
    PIXEL_ACCURACY_WITHOUT_VOID = 3  ///< pixel accuracy, void excluded
};
24 
29 class Result {
30 
31 private:
32  ConfusionMatrix confusionMatrix;
33  double pixelAccuracy;
34  double pixelAccuracyWithoutVoid;
35  LossFunctionType lossFunctionType;
36  int randomSeed;
37 
38 public:
42  Result(const ConfusionMatrix& confusionMatrix, double pixelAccuracy,
43  double pixelAccuracyWithoutVoid, const LossFunctionType lossFunctionType) :
44  confusionMatrix(confusionMatrix),
45  pixelAccuracy(pixelAccuracy),
46  pixelAccuracyWithoutVoid(pixelAccuracyWithoutVoid),
47  lossFunctionType(lossFunctionType),
48  randomSeed(0) {
49  }
50 
54  mongo::BSONObj toBSON() const;
55 
60  return confusionMatrix;
61  }
62 
66  void setLossFunctionType(const LossFunctionType& lossFunctionType) {
67  this->lossFunctionType = lossFunctionType;
68  }
69 
73  double getLoss() const;
74 
78  double getClassAccuracy() const {
79  return confusionMatrix.averageClassAccuracy(true);
80  }
81 
85  double getClassAccuracyWithoutVoid() const {
86  return confusionMatrix.averageClassAccuracy(false);
87  }
88 
92  double getPixelAccuracy() const {
93  return pixelAccuracy;
94  }
95 
99  double getPixelAccuracyWithoutVoid() const {
100  return pixelAccuracyWithoutVoid;
101  }
102 
106  void setRandomSeed(int randomSeed) {
107  this->randomSeed = randomSeed;
108  }
109 
113  int getRandomSeed() const {
114  return randomSeed;
115  }
116 
117 };
118 
123 class HyperoptClient: public mdbq::Client {
124 
125 private:
126 
127  const std::vector<LabeledRGBDImage>& allRGBDImages;
128  const std::vector<LabeledRGBDImage>& allTestImages;
129 
130  bool useCIELab;
131  bool useDepthFilling;
132  std::vector<int> deviceIds;
133  int maxImages;
134  int imageCacheSizeMB;
135  int randomSeed;
136  int numThreads;
137  std::string subsamplingType;
138  std::vector<std::string> ignoredColors;
139  bool useDepthImages;
140  bool horizontalFlipping;
141  size_t numLabels;
142  LossFunctionType lossFunction;
143 
144  boost::asio::io_service ios;
145 
146  RandomForestImage train(size_t trees,
147  const TrainingConfiguration& configuration,
148  const std::vector<LabeledRGBDImage>& trainImages);
149 
150  void randomSplit(const int randomSeed, const double testRatio,
151  std::vector<LabeledRGBDImage>& trainImages,
152  std::vector<LabeledRGBDImage>& testImages);
153 
154  double measureTrueLoss(unsigned int numTrees, TrainingConfiguration configuration,
155  const double histogramBias, double& variance);
156 
157  const Result test(const RandomForestImage& randomForest,
158  const std::vector<LabeledRGBDImage>& testImages);
159 
160  double getParameterDouble(const mongo::BSONObj& task, const std::string& field);
161 
162  static double getAverageLossAndVariance(const std::vector<Result>& results, double& variance);
163 
164  static LossFunctionType parseLossFunction(const std::string& lossFunction);
165 
166 public:
167 
172  const std::vector<LabeledRGBDImage>& allRGBDImages,
173  const std::vector<LabeledRGBDImage>& allTestImages,
174  bool useCIELab,
175  bool useDepthFilling,
176  const std::vector<int>& deviceIds,
177  int maxImages,
178  int imageCacheSizeMB,
179  int randomSeed,
180  int numThreads,
181  const std::string& subsamplingType,
182  const std::vector<std::string>& ignoredColors,
183  bool useDepthImages,
184  bool horizontalFlipping,
185  size_t numLabels,
186  const std::string& lossFunction,
187  const std::string& url, const std::string& db, const mongo::BSONObj& jobSelector) :
188  Client(url, db, jobSelector),
189  allRGBDImages(allRGBDImages),
190  allTestImages(allTestImages),
191  useCIELab(useCIELab),
192  useDepthFilling(useDepthFilling),
193  deviceIds(deviceIds),
194  maxImages(maxImages),
195  imageCacheSizeMB(imageCacheSizeMB),
196  randomSeed(randomSeed),
197  numThreads(numThreads),
198  subsamplingType(subsamplingType),
199  ignoredColors(ignoredColors),
200  useDepthImages(useDepthImages),
201  horizontalFlipping(horizontalFlipping),
202  numLabels(numLabels),
203  lossFunction(parseLossFunction(lossFunction))
204  {
205  }
206 
210  void handle_task(const mongo::BSONObj& task);
211 
212  void run();
213 };
214 
215 }
216 
217 #endif