1 #ifndef CURFIL_RANDOMTREE_H
2 #define CURFIL_RANDOMTREE_H
4 #include <boost/format.hpp>
5 #include <boost/lexical_cast.hpp>
6 #include <boost/random.hpp>
7 #include <boost/shared_ptr.hpp>
8 #include <boost/weak_ptr.hpp>
11 #include <cuv/ndarray.hpp>
17 #include <tbb/concurrent_vector.h>
18 #include <tbb/parallel_for.h>
// Scalar type aliases used throughout this header.
// LabelType: class label stored in an unsigned 8-bit integer, so at most 256 distinct values fit.
27 typedef uint8_t LabelType;
// WeightType: unsigned integer type; the class histograms in this file hold values of this type.
28 typedef unsigned int WeightType;
// FeatureResponseType: double-precision value — presumably the result type of a feature
// evaluation; confirm against the feature function declarations.
29 typedef double FeatureResponseType;
31 enum HorizontalFlipSetting {
32 NoFlip = 0, Flip = 1, Both = 2
// Declaration only; the definition is not part of the lines shown here.
// Takes the best score so far, a candidate score and a feature number, and returns a bool.
// NOTE(review): presumably true when `score` should replace `bestScore`; how `featureNr`
// enters the decision cannot be seen from this declaration — confirm in the definition.
37 bool isScoreBetter(
const ScoreType bestScore,
const ScoreType score,
const int featureNr);
42 const double histogramBias);
51 enum AccelerationMode {
70 SplitFunction(
size_t featureId,
const FeatureFunction& feature,
float threshold, ScoreType score) :
71 featureId(featureId), feature(feature), threshold(threshold), score(score) {
75 featureId(0), feature(), threshold(std::numeric_limits<float>::quiet_NaN()), score(
76 std::numeric_limits<float>::quiet_NaN()) {
83 assert(!isnan(other.threshold));
84 assert(!isnan(other.score));
85 featureId = other.featureId;
86 feature = other.feature;
87 threshold = other.threshold;
// Decides which branch an instance is sent to by comparing the feature response
// with the stored threshold.
//
// instance          the sample to evaluate; its horizontal flip setting selects the switch branch.
// flippedSameSplit  out-parameter: set to true on entry, and set to false only in the branch
//                   that evaluates both orientations when the two results disagree.
// Returns LEFT when the response is <= threshold (splitValue is true), otherwise RIGHT.
//
// NOTE(review): the case labels of the switch are not visible in these lines; which
// HorizontalFlipSetting value maps to which branch is assumed from the branch bodies — confirm.
94 SplitBranch
split(
const Instance& instance,
bool& flippedSameSplit)
const {
95 HorizontalFlipSetting horFlipSetting = instance.getHorFlipSetting();
96 flippedSameSplit =
true;
97 bool splitValue, value2;
98 switch (horFlipSetting) {
// Branch: response of the instance as-is.
100 splitValue = feature.calculateFeatureResponse(instance) <=
getThreshold();
// Branch: response computed with the second argument set to true
// (presumably the horizontally flipped variant — confirm in calculateFeatureResponse).
103 splitValue = feature.calculateFeatureResponse(instance,
true) <=
getThreshold();
// Branch: both variants are evaluated; the returned branch follows the first one (splitValue).
106 splitValue = feature.calculateFeatureResponse(instance) <=
getThreshold();
107 value2 = feature.calculateFeatureResponse(instance,
true) <=
getThreshold();
// Report to the caller that the two variants fall on different sides of the threshold.
108 if (splitValue != value2)
109 flippedSameSplit =
false;
// Remaining branch: same computation as the first one.
112 splitValue = feature.calculateFeatureResponse(instance) <=
getThreshold();
115 return ((splitValue) ? LEFT : RIGHT);
129 assert(!isnan(threshold));
137 assert(!isnan(score));
150 FeatureFunction feature;
175 maxSamplesPerBatch(0),
176 accelerationMode(GPU_ONLY),
183 horizontalFlipping(0) {
195 unsigned int samplesPerImage,
196 unsigned int featureCount,
197 unsigned int minSampleCount,
205 unsigned int maxSamplesPerBatch,
206 AccelerationMode accelerationMode,
207 bool useCIELab =
true,
208 bool useDepthFilling =
false,
209 const std::vector<int> deviceIds = std::vector<int>(1, 0),
210 const std::string subsamplingType =
"classUniform",
211 const std::vector<std::string>& ignoredColors = std::vector<std::string>(),
212 bool useDepthImages =
true,
213 bool horizontalFlipping =
false) :
214 randomSeed(randomSeed),
215 samplesPerImage(samplesPerImage),
216 featureCount(featureCount),
217 minSampleCount(minSampleCount),
219 boxRadius(boxRadius),
220 regionSize(regionSize),
221 thresholds(thresholds),
222 numThreads(numThreads),
223 maxImages(maxImages),
224 imageCacheSize(imageCacheSize),
225 maxSamplesPerBatch(maxSamplesPerBatch),
226 accelerationMode(accelerationMode),
227 useCIELab(useCIELab),
228 useDepthFilling(useDepthFilling),
229 deviceIds(deviceIds),
230 subsamplingType(subsamplingType),
231 ignoredColors(ignoredColors),
232 useDepthImages(useDepthImages),
233 horizontalFlipping(horizontalFlipping)
// Constructor-time validation of the configuration values.
// Every entry of ignoredColors must be non-empty.
235 for (
size_t c = 0; c < ignoredColors.size(); c++) {
236 if (ignoredColors[c].empty()) {
// The rejected entry is the empty string, so the message text ends up as "illegal color: ''".
237 throw std::runtime_error(std::string(
"illegal color: '") + ignoredColors[c] +
"'");
// A positive maxImages that is smaller than imageCacheSize is rejected;
// a maxImages of zero skips this check.
240 if (maxImages > 0 && maxImages < imageCacheSize) {
241 throw std::runtime_error(
242 (boost::format(
"illegal configuration: maxImages (%d) must not be lower than imageCacheSize (%d)")
243 % maxImages % imageCacheSize).str());
251 this->randomSeed = randomSeed;
265 return samplesPerImage;
279 return minSampleCount;
328 return imageCacheSize;
335 return maxSamplesPerBatch;
342 return accelerationMode;
349 this->accelerationMode = accelerationMode;
373 this->deviceIds = deviceIds;
380 return subsamplingType;
394 return useDepthFilling;
401 return useDepthImages;
408 return horizontalFlipping;
415 return ignoredColors;
432 return this->
equals(other,
true);
439 return (!(*
this == other));
446 unsigned int samplesPerImage;
447 unsigned int featureCount;
448 unsigned int minSampleCount;
456 unsigned int maxSamplesPerBatch;
457 AccelerationMode accelerationMode;
459 bool useDepthFilling;
460 std::vector<int> deviceIds;
461 std::string subsamplingType;
462 std::vector<std::string> ignoredColors;
464 bool horizontalFlipping;
471 template<
class Instance,
class FeatureFunction>
480 const std::vector<const Instance*>& samples,
size_t numClasses,
483 nodeId(nodeId), level(level), parent(parent), leaf(true), trainSamples(),
484 numClasses(numClasses), histogram(numClasses), allPixelsHistogram(numClasses), timers(),
485 split(), left(), right() {
487 assert(histogram.
ndim() == 1);
488 for (
size_t label = 0; label < numClasses; label++) {
489 histogram[label] = 0;
490 allPixelsHistogram[label] = 0;
// Build the node histogram from the training samples and keep a copy of every sample.
493 for (
size_t i = 0; i < samples.size(); i++) {
// Add the sample's weight to the bin of its label.
494 histogram[samples[i]->getLabel()] += samples[i]->getWeight();
// The sample is copied by value into trainSamples; the pointer itself is not retained here.
495 trainSamples.push_back(*samples[i]);
// A sample whose flip setting is Both contributes its weight a second time.
496 if (samples[i]->getHorFlipSetting() == Both) {
497 histogram[samples[i]->getLabel()] += samples[i]->getWeight();
507 const std::vector<WeightType>& histogram) :
508 nodeId(nodeId), level(level), parent(parent), leaf(true), trainSamples(),
509 numClasses(histogram.size()), histogram(histogram.size()), timers(),
510 split(), left(), right() {
512 for (
size_t i = 0; i < histogram.size(); i++) {
513 this->histogram[i] = histogram[i];
522 size_t nonZeroClasses = 0;
523 for (
size_t i = 0; i < histogram.
size(); i++) {
524 const WeightType classCount = histogram[i];
525 if (classCount == 0) {
529 if (nonZeroClasses > 1) {
533 assert(nonZeroClasses > 0);
534 return (nonZeroClasses == 1);
541 assert(histogram.
size() == numClasses);
549 std::set<unsigned int>& nodeSet,
bool includeRoot)
const {
551 if (!
isRoot() || includeRoot) {
552 nodeSet.insert(nodeId);
// Child to descend into; the assignments to it are outside the lines shown here.
557 boost::shared_ptr<RandomTree<Instance, FeatureFunction> > node;
// NOTE(review): split() is called with a single argument here, while the other call sites
// in this file also pass a `bool& flippedSameSplit` out-parameter — confirm that a
// one-argument overload exists, otherwise this fails to compile once instantiated.
558 if (split.split(instance) == LEFT) {
// A child must have been selected before recursing.
563 assert(node != NULL);
// Continue collecting node ids in the selected child with the same arguments.
564 node->collectNodeIndices(instance, nodeSet, includeRoot);
577 left->collectLeafNodes(leafSet);
578 right->collectLeafNodes(leafSet);
// Predicts a label for `instance`: finds a node with traverseToLeaf() and returns
// that node's getDominantClass().
585 LabelType
classify(
const Instance& instance)
const {
586 return traverseToLeaf(instance)->getDominantClass();
600 return trainSamples.size();
614 return timerAnnotations;
629 timerAnnotations[key] = boost::lexical_cast<std::string>(annotation);
643 timers[key] = timeInSeconds;
650 timers[key] += timeInSeconds;
656 boost::shared_ptr<const RandomTree<Instance, FeatureFunction> >
getLeft()
const {
663 boost::shared_ptr<const RandomTree<Instance, FeatureFunction> >
getRight()
const {
678 return (parent.lock().get() == NULL);
688 return parent.lock()->getRoot();
698 std::string featureType(split.getFeature().getTypeString());
699 featureCounts[featureType]++;
701 left->countFeatures(featureCounts);
702 right->countFeatures(featureCounts);
709 return (
isLeaf() ? 1 : (1 + left->countNodes() + right->countNodes()));
716 return (
isLeaf() ? 1 : (left->countLeafNodes() + right->countLeafNodes()));
725 return (1 + std::max(left->getTreeDepth(), right->getTreeDepth()));
743 assert(left->getNodeId() > this->
getNodeId());
744 assert(right->getNodeId() > this->
getNodeId());
746 assert(left->parent.lock().get() ==
this);
747 assert(right->parent.lock().get() ==
this);
749 assert(this->left.get() == 0);
750 assert(this->right.get() == 0);
786 size_t rootNodeId =
getRoot()->getNodeId();
795 const double histogramBias) {
797 normalizedHistogram = detail::normalizeHistogram(histogram, priorDistribution, histogramBias);
798 assert(normalizedHistogram.
shape() == histogram.
shape());
802 for (
size_t i = 0; i < normalizedHistogram.
size(); i++) {
803 sum += normalizedHistogram[i];
806 CURFIL_INFO(
"normalized histogram of node " <<
getNodeId() <<
", level " <<
getLevel() <<
" has zero sum");
809 left->normalizeHistograms(priorDistribution, histogramBias);
810 right->normalizeHistograms(priorDistribution, histogramBias);
813 if (normalizedHistogram.
shape() != histogram.
shape()) {
814 CURFIL_ERROR(
"node: " << nodeId <<
" (level " << level <<
")");
815 CURFIL_ERROR(
"histogram: " << histogram);
816 CURFIL_ERROR(
"normalized histogram: " << normalizedHistogram);
817 throw std::runtime_error(
"failed to normalize histogram");
832 if (normalizedHistogram.
shape() != histogram.
shape()) {
833 CURFIL_ERROR(
"node: " << nodeId <<
" (level " << level <<
")");
834 CURFIL_ERROR(
"histogram: " << histogram);
835 CURFIL_ERROR(
"normalized histogram: " << normalizedHistogram);
836 throw std::runtime_error(
"histogram not normalized");
838 return normalizedHistogram;
853 allPixelsHistogram[label] += value;
862 size_t label = instance.getLabel();
863 allPixelsHistogram[label] += 1;
870 bool flippedSameSplit;
871 if (split.split(instance,flippedSameSplit) == LEFT) {
872 return left->setAllPixelsHistogram(instance);
874 return right->setAllPixelsHistogram(instance);
// Replace every non-zero bin of the node histogram by the corresponding entry of
// allPixelsHistogram; bins that are zero are left at zero.
885 for (
size_t label = 0; label < numClasses; label++) {
887 if (histogram[label] != 0)
888 { histogram[label] = allPixelsHistogram[label];}
// Apply the same update to both children.
// NOTE(review): the guard that skips this for leaf nodes is not visible in these lines — confirm.
892 left->updateHistograms();
893 right->updateHistograms();
902 for (
size_t label = 0; label < numClasses; label++) {
903 histogram[label] = 0;
906 for (
size_t i = 0; i < samples.size(); i++) {
907 histogram[samples[i]->getLabel()] += samples[i]->getWeight();
922 const boost::weak_ptr<RandomTree<Instance, FeatureFunction> > parent;
927 std::vector<Instance> trainSamples;
936 std::map<std::string, double> timers;
937 std::map<std::string, std::string> timerAnnotations;
942 boost::shared_ptr<RandomTree<Instance, FeatureFunction> > left;
943 boost::shared_ptr<RandomTree<Instance, FeatureFunction> > right;
// Scans the node histogram and returns a class label; presumably the index of the largest
// bin (the body of the comparison below is outside the lines shown — confirm).
945 LabelType getDominantClass()
const {
// max starts as NaN so that the isnan() test below accepts the first bin unconditionally.
946 double max = std::numeric_limits<double>::quiet_NaN();
947 assert(histogram.
size() == numClasses);
948 LabelType maxClass = 0;
// NOTE(review): classNr has type LabelType, which is uint8_t in this file. With 256 or more
// histogram entries the counter wraps to 0 before the condition can become false, so the
// loop would not terminate — confirm that numClasses always stays below 256.
950 for (LabelType classNr = 0; classNr < histogram.
size(); classNr++) {
951 const WeightType& count = histogram[classNr];
// NOTE(review): WeightType is unsigned int in this file, so this assertion can never fail.
952 assert(count >= 0.0);
953 if (isnan(max) || count > max) {
// Follows the split decisions for `instance` downwards through the tree and returns the
// node reached by the recursion.
// NOTE(review): the leaf check that ends the recursion is outside the lines shown — confirm.
961 const RandomTree<Instance, FeatureFunction>* traverseToLeaf(
const Instance& instance)
const {
// Required as out-parameter of split(); its value is not read in the lines shown here.
968 bool flippedSameSplit;
969 if (split.split(instance, flippedSameSplit) == LEFT) {
970 return left->traverseToLeaf(instance);
972 return right->traverseToLeaf(instance);
988 seed(seed), lower(lower), upper(upper), rng(seed), distribution(lower, upper) {
989 assert(upper >= lower);
996 seed(other.seed), lower(other.lower), upper(other.upper), rng(seed), distribution(lower, upper) {
1034 boost::uniform_int<> distribution;
1046 count(0), samples(0), reservoir() {
1053 count(0), samples(samples), reservoir() {
1054 reservoir.reserve(samples);
1055 assert(reservoir.empty());
1064 throw std::runtime_error(
"no samples to sample");
1067 assert(sampler.
getUpper() >
static_cast<int>(count));
// Fill phase: while fewer than `samples` elements are stored, the new element is appended.
1069 if (reservoir.size() < samples) {
1070 reservoir.push_back(sample);
// Replace phase: at least `samples` elements have been seen.
1072 assert(count >= samples);
// Draw an index in [0, count]; the new element overwrites a stored one only when the
// index falls inside the reservoir.
// NOTE(review): where `count` is incremented is outside the lines shown — confirm that it
// is updated once per call.
1074 size_t rand = sampler.
getNext() % (count + 1);
1075 if (rand < samples) {
1076 reservoir[rand] =
sample;
1093 std::vector<T> reservoir;
1126 return Sampler(seed++, lower, upper);
1134 template<
class Instance,
class FeatureEvaluation,
class FeatureFunction>
1148 id(id), numClasses(numClasses), configuration(configuration) {
1155 if (node->hasPureHistogram()) {
1165 typedef boost::shared_ptr<RandomTree<Instance, FeatureFunction> > RandomTreePointer;
1166 typedef std::vector<const Instance*> Samples;
1171 const SplitFunction<Instance, FeatureFunction>& bestSplit,
size_t histogramSize)
const {
1173 const unsigned int leftRightStride = 1;
1176 double totalLeft = sum(leftHistogram);
1177 double totalRight = sum(rightHistogram);
// Scratch copies of the three histograms, one slot per class.
// NOTE(review): histogramSize is a run-time function parameter, so these are variable-length
// arrays — a compiler extension, not standard C++ — and they live on the stack. Confirm that
// every supported compiler accepts them and that histogramSize stays small.
1179 WeightType leftHistogramArray[histogramSize];
1180 WeightType rightHistogramArray[histogramSize];
1181 WeightType allHistogramArray[histogramSize];
1183 std::stringstream strLeft;
1184 std::stringstream strRight;
1185 std::stringstream strAll;
1187 for (
size_t i=0; i<histogramSize; i++)
1189 leftHistogramArray[i] = leftHistogram[i];
1190 rightHistogramArray[i] = rightHistogram[i];
1191 allHistogramArray[i] = allHistogram[i];
1193 strLeft<<leftHistogramArray[i]<<
",";
1194 strRight<<rightHistogramArray[i]<<
",";
1195 strAll<<allHistogramArray[i]<<
",";
1199 allHistogramArray, totalLeft, totalRight);
1200 double diff = std::fabs(actualScore - bestSplit.getScore());
1202 std::ostringstream o;
1204 o <<
"actual score and best split score differ: " << diff << std::endl;
1205 o <<
"actual score: " << actualScore << std::endl;
1206 o <<
"best split score: " << bestSplit.getScore() << std::endl;
1207 o <<
"total left: " << totalLeft << std::endl;
1208 o <<
"total right: " << totalRight << std::endl;
1209 o <<
"histogram: " << strAll.str() << std::endl;
1210 o <<
"histogram left: " << strLeft.str() << std::endl;
1211 o <<
"histogram right: " << strRight.str() << std::endl;
1212 throw std::runtime_error(o.str());
1221 void train(FeatureEvaluation& featureEvaluation,
1223 const std::vector<std::pair<RandomTreePointer, Samples> >& samplesPerNode,
1224 int idNode,
int currentLevel = 1)
const {
1227 if (currentLevel == configuration.
getMaxDepth()) {
1231 CURFIL_INFO(
"training level " << currentLevel <<
". nodes: " << samplesPerNode.size());
1235 std::vector<std::pair<RandomTreePointer, Samples> > samplesPerNodeNextLevel;
1237 std::vector<SplitFunction<Instance, FeatureFunction> > bestSplits = featureEvaluation.evaluateBestSplits(
1238 randomSource, samplesPerNode);
1240 assert(bestSplits.size() == samplesPerNode.size());
1242 std::vector<boost::shared_ptr<Instance> > flipped;
1244 for (
size_t i = 0; i < samplesPerNode.size(); i++) {
1246 const std::pair<RandomTreePointer, Samples>& it = samplesPerNode[i];
1248 const std::vector<const Instance*>& samples = it.second;
1250 boost::shared_ptr<RandomTree<Instance, FeatureFunction> > currentNode = it.first;
1251 assert(currentNode);
1258 std::vector<const Instance*> samplesLeft;
1259 std::vector<const Instance*> samplesRight;
1261 size_t numClasses = currentNode->getHistogram().size();
1266 for (
size_t c=0; c<numClasses; c++)
1268 leftHistogram[c] = 0;
1269 rightHistogram[c] = 0;
1270 allHistogram[c] = 0;
1273 unsigned int totalFlipped = 0;
1274 unsigned int rightFlipped = 0;
1275 unsigned int leftFlipped = 0;
// Distribute the samples of the current node to the left and right child according to
// bestSplit, and accumulate the left/right/all histograms on the way.
1279 for (
size_t sample = 0; sample < samples.size(); sample++) {
1280 assert(samples[sample] != NULL);
1281 bool flippedSameSplit;
// NOTE(review): constness is cast away because the flip setting of the sample is changed
// in place below; this mutates an object that the caller handed in as const — confirm
// that the pointed-to instances are never actually const objects.
1282 Instance* ptr =
const_cast<Instance *
>(samples[sample]);
1284 SplitBranch splitResult = bestSplit.
split(*ptr, flippedSameSplit);
// Read the flip setting before it is possibly overwritten in the next statement.
1286 HorizontalFlipSetting flipSetting = samples[sample]->getHorFlipSetting();
// A sample marked Both whose two orientations disagree is reduced to its unflipped variant.
1288 if (flipSetting == Both && !flippedSameSplit)
1289 {ptr->setHorFlipSetting(NoFlip);
1292 allHistogram[samples[sample]->getLabel()] += samples[sample]->getWeight();
1293 if (splitResult == LEFT) {
1294 samplesLeft.push_back(ptr);
1295 leftHistogram[samples[sample]->getLabel()] += samples[sample]->getWeight();
// Samples that were marked Both are counted a second time.
1296 if (flipSetting == Both) {
1298 allHistogram[samples[sample]->getLabel()] += samples[sample]->getWeight();
// Same side for both orientations: the second count also goes to the left histogram.
1299 if (flippedSameSplit)
1300 leftHistogram[samples[sample]->getLabel()] += samples[sample]->getWeight();
// Presumably the else branch (the keyword is outside the lines shown): the second count
// goes to the right histogram, and a new instance with setting Flip is created, kept
// alive by `flipped`, and handed to the right child as a raw pointer.
1302 {rightHistogram[samples[sample]->getLabel()] += samples[sample]->getWeight();
1304 flipped.push_back(boost::shared_ptr<Instance>(
new Instance((samples[sample]->getRGBDImage()), samples[sample]->getLabel(), samples[sample]->getX(), samples[sample]->getY(), Flip)));
1305 samplesRight.push_back(flipped.back().get());
// Mirror image of the handling above for a sample that the split sends to the right.
1309 samplesRight.push_back(ptr);
1310 rightHistogram[samples[sample]->getLabel()] += samples[sample]->getWeight();
1311 if (flipSetting == Both) {
1313 allHistogram[samples[sample]->getLabel()] += samples[sample]->getWeight();
1314 if (flippedSameSplit)
1315 rightHistogram[samples[sample]->getLabel()] += samples[sample]->getWeight();
1317 {leftHistogram[samples[sample]->getLabel()] += samples[sample]->getWeight();
1319 flipped.push_back(boost::shared_ptr<Instance>(
new Instance((samples[sample]->getRGBDImage()), samples[sample]->getLabel(), samples[sample]->getX(), samples[sample]->getY(), Flip)));
1320 samplesLeft.push_back(flipped.back().get());
1326 assert(samplesLeft.size() + samplesRight.size() == samples.size() + off);
1328 boost::shared_ptr<RandomTree<Instance, FeatureFunction> > leftNode = boost::make_shared<
RandomTree<Instance,
1329 FeatureFunction> >(++idNode, currentNode->getLevel() + 1, samplesLeft, numClasses, currentNode);
1331 boost::shared_ptr<RandomTree<Instance, FeatureFunction> > rightNode = boost::make_shared<
1333 numClasses, currentNode);
1336 compareHistograms(allHistogram, leftHistogram, rightHistogram, bestSplit, numClasses);
1339 bool errorEmptyChildren =
false;
1340 if (samplesLeft.empty() && leftFlipped == 0)
1342 errorEmptyChildren =
true;
1344 if (samplesRight.empty() && rightFlipped == 0)
1346 errorEmptyChildren =
true;
1348 if (errorEmptyChildren) {
1349 CURFIL_ERROR(
"best split score: " << bestSplit.
getScore());
1350 CURFIL_ERROR(
"samples: " << samples.size());
1351 CURFIL_ERROR(
"threshold: " << bestSplit.
getThreshold());
1352 CURFIL_ERROR(
"feature: " << bestSplit.
getFeature());
1353 CURFIL_ERROR(
"histogram: " << currentNode->getHistogram());
1354 CURFIL_ERROR(
"samplesLeft: " << samplesLeft.size());
1355 CURFIL_ERROR(
"samplesRight: " << samplesRight.size());
1356 CURFIL_ERROR(
"leftFlipped "<<leftFlipped<<
" rightFlipped "<<rightFlipped<<
" totalFlipped "<<totalFlipped)
1358 compareHistograms(allHistogram, leftHistogram, rightHistogram, bestSplit, numClasses);
1360 if (samplesLeft.empty()) {
1361 throw std::runtime_error(
"no samples in left node");
1363 if (samplesRight.empty()) {
1364 throw std::runtime_error(
"no samples in right node");
1368 if (!samplesLeft.empty() && !samplesRight.empty()) {
1369 currentNode->addChildren(bestSplit, leftNode, rightNode);
1371 if (shouldContinueGrowing(leftNode)) {
1372 samplesPerNodeNextLevel.push_back(std::make_pair(leftNode, samplesLeft));
1375 if (shouldContinueGrowing(rightNode)) {
1376 samplesPerNodeNextLevel.push_back(std::make_pair(rightNode, samplesRight));
1380 idNode = idNode - 2;
1383 CURFIL_INFO(
"training level " << currentLevel <<
" took " << trainTimer.format(3));
1384 if (!samplesPerNodeNextLevel.empty()) {
1385 train(featureEvaluation, randomSource, samplesPerNodeNextLevel, idNode, currentLevel + 1);
1394 for (
size_t i = 0; i < vector.
size(); i++) {
1402 const TrainingConfiguration configuration;