1 #ifndef CURFIL_HYPEROPT
2 #define CURFIL_HYPEROPT
4 #include <boost/asio/io_service.hpp>
5 #include <mdbq/client.hpp>
6 #include <mongo/bson/bson.h>
10 #include "random_forest_image.h"
11 #include "random_tree_image.h"
/**
 * Returns whether the (hyperparameter) search should keep going, given two
 * sets of accuracy values.
 *
 * NOTE(review): the definition is not visible here. Presumably it compares the
 * accuracies of the current run against the best accuracies seen so far, and
 * returns false once the current run cannot beat them. Confirm against the
 * implementation.
 *
 * @param currentBestAccuracies accuracy values of the best configuration so far.
 * @param currentRunAccuracies  accuracy values collected for the current run.
 * @return true to continue searching, false to stop (see NOTE above).
 */
15 bool continueSearching(
const std::vector<double>& currentBestAccuracies,
16 const std::vector<double>& currentRunAccuracies);
18 enum LossFunctionType {
20 CLASS_ACCURACY_WITHOUT_VOID,
22 PIXEL_ACCURACY_WITHOUT_VOID
34 double pixelAccuracyWithoutVoid;
35 LossFunctionType lossFunctionType;
43 double pixelAccuracyWithoutVoid,
const LossFunctionType lossFunctionType) :
44 confusionMatrix(confusionMatrix),
45 pixelAccuracy(pixelAccuracy),
46 pixelAccuracyWithoutVoid(pixelAccuracyWithoutVoid),
47 lossFunctionType(lossFunctionType),
/**
 * Returns a BSON (mongo::BSONObj) representation of this object.
 *
 * Declared const, so producing the document does not modify the object.
 *
 * NOTE(review): the definition is not visible here. Presumably the document
 * carries the result fields declared alongside this method (e.g.
 * confusionMatrix, pixelAccuracy, pixelAccuracyWithoutVoid,
 * lossFunctionType). Confirm the exact keys against the implementation.
 *
 * @return the BSON document describing this object.
 */
54 mongo::BSONObj
toBSON()
const;
60 return confusionMatrix;
67 this->lossFunctionType = lossFunctionType;
100 return pixelAccuracyWithoutVoid;
107 this->randomSeed = randomSeed;
127 const std::vector<LabeledRGBDImage>& allRGBDImages;
128 const std::vector<LabeledRGBDImage>& allTestImages;
131 bool useDepthFilling;
132 std::vector<int> deviceIds;
134 int imageCacheSizeMB;
137 std::string subsamplingType;
138 std::vector<std::string> ignoredColors;
140 bool horizontalFlipping;
142 LossFunctionType lossFunction;
144 boost::asio::io_service ios;
148 const std::vector<LabeledRGBDImage>& trainImages);
/**
 * Splits labeled RGB-D images into a training set and a test set.
 *
 * NOTE(review): the definition is not visible here. Both image vectors are
 * taken by non-const reference, so they are presumably filled (or
 * re-partitioned in place) by this call. Presumably randomSeed makes the
 * split reproducible and testRatio is the fraction of images placed in
 * testImages. Confirm against the implementation.
 *
 * @param randomSeed  seed value for the random split (see NOTE).
 * @param testRatio   ratio controlling the size of the test set (see NOTE).
 * @param trainImages in/out: training images.
 * @param testImages  in/out: test images.
 */
150 void randomSplit(
const int randomSeed,
const double testRatio,
151 std::vector<LabeledRGBDImage>& trainImages,
152 std::vector<LabeledRGBDImage>& testImages);
155 const double histogramBias,
double& variance);
158 const std::vector<LabeledRGBDImage>& testImages);
/**
 * Reads a numeric parameter from a task document.
 *
 * NOTE(review): the definition is not visible here. Presumably it looks up
 * `field` in `task` and returns its value as a double. Behaviour for a missing
 * or non-numeric field is not visible (throw? default?). Confirm.
 *
 * @param task  BSON task document (presumably fetched via the mdbq client).
 * @param field name of the parameter to read.
 * @return the parameter value as a double.
 */
160 double getParameterDouble(
const mongo::BSONObj& task,
const std::string& field);
/**
 * Aggregates a list of results into an average loss and its variance.
 *
 * Declared static, so it does not depend on instance state.
 *
 * NOTE(review): the definition is not visible here. Behaviour for an empty
 * `results` vector, and whether the variance is the population or the sample
 * variance, cannot be determined from here. Confirm.
 *
 * @param results  the results to aggregate.
 * @param variance out-parameter (non-const reference). Presumably it receives
 *                 the variance of the per-result losses.
 * @return presumably the average loss over `results`.
 */
162 static double getAverageLossAndVariance(
const std::vector<Result>& results,
double& variance);
/**
 * Converts a loss-function name into a LossFunctionType value.
 *
 * The constructor's member-initializer list calls this to set the
 * `lossFunction` member from its string argument.
 *
 * NOTE(review): the accepted spellings and the handling of unrecognized names
 * are not visible here. The visible enumerators include
 * CLASS_ACCURACY_WITHOUT_VOID and PIXEL_ACCURACY_WITHOUT_VOID. Confirm the
 * mapping and the error behaviour against the implementation.
 *
 * @param lossFunction textual name of the loss function.
 * @return the matching LossFunctionType enumerator.
 */
164 static LossFunctionType parseLossFunction(
const std::string& lossFunction);
172 const std::vector<LabeledRGBDImage>& allRGBDImages,
173 const std::vector<LabeledRGBDImage>& allTestImages,
175 bool useDepthFilling,
176 const std::vector<int>& deviceIds,
178 int imageCacheSizeMB,
181 const std::string& subsamplingType,
182 const std::vector<std::string>& ignoredColors,
184 bool horizontalFlipping,
186 const std::string& lossFunction,
187 const std::string& url,
const std::string& db,
const mongo::BSONObj& jobSelector) :
188 Client(url, db, jobSelector),
189 allRGBDImages(allRGBDImages),
190 allTestImages(allTestImages),
191 useCIELab(useCIELab),
192 useDepthFilling(useDepthFilling),
193 deviceIds(deviceIds),
194 maxImages(maxImages),
195 imageCacheSizeMB(imageCacheSizeMB),
196 randomSeed(randomSeed),
197 numThreads(numThreads),
198 subsamplingType(subsamplingType),
199 ignoredColors(ignoredColors),
200 useDepthImages(useDepthImages),
201 horizontalFlipping(horizontalFlipping),
202 numLabels(numLabels),
203 lossFunction(parseLossFunction(lossFunction))