curfil  ..
 All Classes Functions Variables Typedefs Friends Groups Pages
random_tree_image.h
1 #ifndef CURFIL_RANDOMTREEIMAGE_H
2 #define CURFIL_RANDOMTREEIMAGE_H
3 
4 #include <algorithm>
5 #include <assert.h>
6 #include <boost/make_shared.hpp>
#include <cmath>
7 #include <cuv/ndarray.hpp>
#include <limits>
8 #include <list>
#include <stdexcept>
9 #include <stdint.h>
10 #include <vector>
11 
12 #include "image.h"
13 #include "random_tree.h"
14 
15 namespace curfil {
16 
25 class XY {
26 
27 public:
28  XY() :
29  x(0), y(0) {
30  }
34  XY(int x, int y) :
35  x(x), y(y) {
36  }
40  XY(const XY& other) :
41  x(other.x), y(other.y) {
42  }
46  XY& operator=(const XY& other) {
47  x = other.x;
48  y = other.y;
49  return (*this);
50  }
51 
55  XY normalize(const Depth& depth) const {
56  assert(depth.isValid());
57  int newX = static_cast<int>(x / depth.getFloatValue());
58  int newY = static_cast<int>(y / depth.getFloatValue());
59  return XY(newX, newY);
60  }
61 
65  bool operator==(const XY& other) const {
66  return (x == other.x && y == other.y);
67  }
68 
72  bool operator!=(const XY& other) const {
73  return !(*this == other);
74  }
75 
79  int getX() const {
80  return x;
81  }
82 
86  int getY() const {
87  return y;
88  }
89 
90 private:
91  int x, y;
92 };
93 
// XY doubles as a region extent, an offset vector and an image point;
// the aliases make the intended role explicit at call sites.
94 typedef XY Region;
95 typedef XY Offset;
96 typedef XY Point;
97 
103 
104 public:
105 
113  PixelInstance(const RGBDImage* image, const LabelType& label, uint16_t x, uint16_t y, HorizontalFlipSetting setting = NoFlip) :
114  image(image), label(label), point(x, y), depth(Depth::INVALID), horFlipSetting(setting) {
115  assert(image != NULL);
116  assert(image->inImage(x, y));
117  if (!image->hasIntegratedDepth()) {
118  throw std::runtime_error("image is not integrated");
119  }
120 
121  int aboveValid = (y > 0) ? image->getDepthValid(x, y - 1) : 0;
122  int leftValid = (x > 0) ? image->getDepthValid(x - 1, y) : 0;
123  int aboveLeftValid = (x > 0 && y > 0) ? image->getDepthValid(x - 1, y - 1) : 0;
124 
125  int valid = image->getDepthValid(x, y) - (leftValid + aboveValid - aboveLeftValid);
126  assert(valid == 0 || valid == 1);
127 
128  if (valid == 1) {
129  Depth above = (y > 0) ? image->getDepth(x, y - 1) : Depth(0);
130  Depth left = (x > 0) ? image->getDepth(x - 1, y) : Depth(0);
131  Depth aboveLeft = (x > 0 && y > 0) ? image->getDepth(x - 1, y - 1) : Depth(0);
132 
133  depth = image->getDepth(x, y) - (left + above - aboveLeft);
134  assert(depth.isValid());
135  } else {
136  assert(!depth.isValid());
137  }
138  }
139 
148  PixelInstance(const RGBDImage* image, const LabelType& label, const Depth& depth,
149  uint16_t x, uint16_t y, HorizontalFlipSetting setting = NoFlip) :
150  image(image), label(label), point(x, y), depth(depth), horFlipSetting(setting) {
151  assert(image != NULL);
152  assert(image->inImage(x, y));
153  assert(depth.isValid());
154  }
155 
159  const RGBDImage* getRGBDImage() const {
160  return image;
161  }
162 
166  int width() const {
167  return image->getWidth();
168  }
169 
173  int height() const {
174  return image->getHeight();
175  }
176 
180  uint16_t getX() const {
181  return static_cast<uint16_t>(point.getX());
182  }
183 
187  uint16_t getY() const {
188  return static_cast<uint16_t>(point.getY());
189  }
190 
195  FeatureResponseType averageRegionColor(const Offset& offset, const Region& region, uint8_t channel) const {
196 
197  assert(region.getX() >= 0);
198  assert(region.getY() >= 0);
199 
200  assert(image->hasIntegratedColor());
201 
202  const int width = std::max(1, region.getX());
203  const int height = std::max(1, region.getY());
204 
205  int x = getX() + offset.getX();
206  int y = getY() + offset.getY();
207 
208  int leftX = x - width;
209  int rightX = x + width;
210  int upperY = y - height;
211  int lowerY = y + height;
212 
213  if (leftX < 0 || rightX >= image->getWidth() || upperY < 0 || lowerY >= image->getHeight()) {
214  return std::numeric_limits<double>::quiet_NaN();
215  }
216 
217  assert(inImage(x, y));
218 
219  Point upperLeft(leftX, upperY);
220  Point upperRight(rightX, upperY);
221  Point lowerLeft(leftX, lowerY);
222  Point lowerRight(rightX, lowerY);
223 
224  FeatureResponseType lowerRightPixel = getColor(lowerRight, channel);
225  FeatureResponseType lowerLeftPixel = getColor(lowerLeft, channel);
226  FeatureResponseType upperRightPixel = getColor(upperRight, channel);
227  FeatureResponseType upperLeftPixel = getColor(upperLeft, channel);
228 
229  if (isnan(lowerRightPixel) || isnan(lowerLeftPixel) || isnan(upperRightPixel) || isnan(upperLeftPixel))
230  return std::numeric_limits<double>::quiet_NaN();
231 
232  FeatureResponseType sum = (lowerRightPixel - upperRightPixel) + (upperLeftPixel - lowerLeftPixel);
233 
234  return sum;
235  }
236 
241  FeatureResponseType averageRegionDepth(const Offset& offset, const Region& region) const {
242  assert(region.getX() >= 0);
243  assert(region.getY() >= 0);
244 
245  assert(image->hasIntegratedDepth());
246 
247  const int width = std::max(1, region.getX());
248  const int height = std::max(1, region.getY());
249 
250  int x = getX() + offset.getX();
251  int y = getY() + offset.getY();
252 
253  int leftX = x - width;
254  int rightX = x + width;
255  int upperY = y - height;
256  int lowerY = y + height;
257 
258  if (leftX < 0 || rightX >= image->getWidth() || upperY < 0 || lowerY >= image->getHeight()) {
259  return std::numeric_limits<double>::quiet_NaN();
260  }
261 
262  assert(inImage(x, y));
263 
264  Point upperLeft(leftX, upperY);
265  Point upperRight(rightX, upperY);
266  Point lowerLeft(leftX, lowerY);
267  Point lowerRight(rightX, lowerY);
268 
269  int upperLeftValid = getDepthValid(upperLeft);
270  int upperRightValid = getDepthValid(upperRight);
271  int lowerRightValid = getDepthValid(lowerRight);
272  int lowerLeftValid = getDepthValid(lowerLeft);
273 
274  int numValid = (lowerRightValid - upperRightValid) + (upperLeftValid - lowerLeftValid);
275  assert(numValid >= 0);
276 
277  if (numValid == 0) {
278  return std::numeric_limits<double>::quiet_NaN();
279  }
280 
281  const int lowerRightDepth = getDepth(lowerRight).getIntValue();
282  const int lowerLeftDepth = getDepth(lowerLeft).getIntValue();
283  const int upperRightDepth = getDepth(upperRight).getIntValue();
284  const int upperLeftDepth = getDepth(upperLeft).getIntValue();
285 
286  int sum = (lowerRightDepth - upperRightDepth) + (upperLeftDepth - lowerLeftDepth);
287  FeatureResponseType feat = sum / static_cast<FeatureResponseType>(1000);
288  return (feat / numValid);
289  }
290 
294  LabelType getLabel() const {
295  return label;
296  }
297 
301  const Depth getDepth() const {
302  return depth;
303  }
304 
308  WeightType getWeight() const {
309  return 1;
310  }
311 
315  HorizontalFlipSetting getHorFlipSetting() const
316  {
317  return horFlipSetting;
318  }
319 
323  void setHorFlipSetting(HorizontalFlipSetting setting)
324  {
325  horFlipSetting = setting;
326  }
327 
328 private:
329  const RGBDImage* image;
330  LabelType label;
331  Point point;
332  Depth depth;
333  HorizontalFlipSetting horFlipSetting;
334 
335  float getColor(const Point& pos, uint8_t channel) const {
336  if (!inImage(pos)) {
337  return std::numeric_limits<float>::quiet_NaN();
338  }
339  assert(image->hasIntegratedColor());
340  return image->getColor(pos.getX(), pos.getY(), channel);
341  }
342 
343  Depth getDepth(const Point& pos) const {
344  if (!inImage(pos)) {
345  return Depth::INVALID;
346  }
347  assert(image->hasIntegratedDepth());
348  const Depth depth = image->getDepth(pos.getX(), pos.getY());
349  // include zero as it is an integral
350  assert(depth.getIntValue() >= 0);
351  return depth;
352  }
353 
354  int getDepthValid(const Point& pos) const {
355  return image->getDepthValid(pos.getX(), pos.getY());
356  }
357 
358  bool inImage(int x, int y) const {
359  return image->inImage(x, y);
360  }
361 
362  bool inImage(const Point& pos) const {
363  return inImage(pos.getX(), pos.getY());
364  }
365 };
366 
// Kind of split feature: depth-difference vs. color-difference features.
// (Plain enum kept: DEPTH/COLOR are used unqualified by callers.)
367 enum FeatureType {
368  DEPTH = 0, COLOR = 1
369 };
370 
378 
379 public:
380 
390  ImageFeatureFunction(FeatureType featureType,
391  const Offset& offset1,
392  const Region& region1,
393  const uint8_t channel1,
394  const Offset& offset2,
395  const Region& region2,
396  const uint8_t channel2) :
397  featureType(featureType),
398  offset1(offset1),
399  region1(region1),
400  channel1(channel1),
401  offset2(offset2),
402  region2(region2),
403  channel2(channel2) {
404  if (offset1 == offset2) {
405  throw std::runtime_error("illegal feature: offset1 equals offset2");
406  }
407  assert(isValid());
408  }
409 
411  featureType(), offset1(), region1(), channel1(), offset2(), region2(), channel2() {
412  }
413 
417  int getSortKey() const {
418  int32_t sortKey = 0;
419  sortKey |= static_cast<uint8_t>(getType() & 0x03) << 30; // 2 bit for the type
420  sortKey |= static_cast<uint8_t>(getChannel1() & 0x0F) << 26; // 4 bit for channel1
421  sortKey |= static_cast<uint8_t>(getChannel2() & 0x0F) << 22; // 4 bit for channel2
422  sortKey |= static_cast<uint8_t>((getOffset1().getY() + 127) & 0xFF) << 14; // 8 bit for offset1.y
423  sortKey |= static_cast<uint8_t>((getOffset1().getX() + 127) & 0xFF) << 6; // 8 bit for offset1.x
424  return sortKey;
425  }
426 
430  FeatureType getType() const {
431  return featureType;
432  }
433 
437  std::string getTypeString() const {
438  switch (featureType) {
439  case COLOR:
440  return "color";
441  case DEPTH:
442  return "depth";
443  default:
444  throw std::runtime_error("unknown feature");
445  }
446  }
447 
451  bool isValid() const {
452  return (offset1 != offset2);
453  }
454 
458  FeatureResponseType calculateFeatureResponse(const PixelInstance& instance, bool flipRegion = false) const {
459  assert(isValid());
460  switch (featureType) {
461  case DEPTH:
462  return calculateDepthFeature(instance, flipRegion);
463  case COLOR:
464  return calculateColorFeature(instance, flipRegion);
465  default:
466  assert(false);
467  break;
468  }
469  return 0;
470  }
471 
475  const Offset& getOffset1() const {
476  return offset1;
477  }
478 
482  const Region& getRegion1() const {
483  return region1;
484  }
485 
489  uint8_t getChannel1() const {
490  return channel1;
491  }
492 
496  const Offset& getOffset2() const {
497  return offset2;
498  }
499 
503  const Region& getRegion2() const {
504  return region2;
505  }
506 
510  uint8_t getChannel2() const {
511  return channel2;
512  }
513 
517  bool operator!=(const ImageFeatureFunction& other) const {
518  return !(*this == other);
519  }
520 
524  bool operator==(const ImageFeatureFunction& other) const;
525 
526 private:
527  FeatureType featureType;
528 
529  Offset offset1;
530  Region region1;
531  uint8_t channel1;
532 
533  Offset offset2;
534  Region region2;
535  uint8_t channel2;
536 
537  FeatureResponseType calculateColorFeature(const PixelInstance& instance, bool flipRegion) const {
538 
539  const Depth depth = instance.getDepth();
540  if (!depth.isValid()) {
541  return std::numeric_limits<double>::quiet_NaN();
542  }
543 
544  FeatureResponseType a;
545  if (flipRegion)
546  a = instance.averageRegionColor(Offset(-offset1.getX(),offset1.getY()).normalize(depth), region1.normalize(depth),
547  channel1);
548  else
549  a = instance.averageRegionColor(offset1.normalize(depth), region1.normalize(depth),
550  channel1);
551  if (isnan(a))
552  return a;
553 
554  FeatureResponseType b;
555  if (flipRegion)
556  b = instance.averageRegionColor(Offset(-offset2.getX(),offset2.getY()).normalize(depth), region2.normalize(depth),
557  channel2);
558  else
559  b = instance.averageRegionColor(offset2.normalize(depth), region2.normalize(depth),
560  channel2);
561  if (isnan(b))
562  return b;
563 
564  return (a - b);
565  }
566 
567  FeatureResponseType calculateDepthFeature(const PixelInstance& instance, bool flipRegion) const {
568 
569  const Depth depth = instance.getDepth();
570  if (!depth.isValid()) {
571  return std::numeric_limits<double>::quiet_NaN();
572  }
573 
574  FeatureResponseType a;
575  if (flipRegion)
576  a = instance.averageRegionDepth(Offset(-offset1.getX(),offset1.getY()).normalize(depth), region1.normalize(depth));
577  else
578  a = instance.averageRegionDepth(offset1.normalize(depth), region1.normalize(depth));
579  if (isnan(a)) {
580  return a;
581  }
582 
583  FeatureResponseType b;
584  if (flipRegion)
585  b = instance.averageRegionDepth(Offset(-offset2.getX(),offset2.getY()).normalize(depth), region2.normalize(depth));
586  else
587  b = instance.averageRegionDepth(offset2.normalize(depth), region2.normalize(depth));
588  if (isnan(b)) {
589  return b;
590  }
591 
592  assert(a > 0);
593  assert(b > 0);
594 
595  return (a - b);
596  }
597 
598 };
599 
608 template<class memory_space>
610 
611 public:
612 
616 private:
617 
620  m_features(features), m_thresholds(thresholds) {
621  }
622 
623 public:
624 
// Allocate storage: m_features has 11 rows, one per feature parameter
// (type, 2x offset x/y, 2x region x/y, 2 channels — see the 11 accessors
// below); m_thresholds is numThresholds x numFeatures.
// NOTE(review): the member declarations of m_features/m_thresholds
// appear to have been stripped by extraction — verify against the
// original header.
628  explicit ImageFeaturesAndThresholds(size_t numFeatures, size_t numThresholds,
629  boost::shared_ptr<cuv::allocator> allocator) :
630  m_features(11, numFeatures, allocator), m_thresholds(numThresholds, numFeatures, allocator) {
631  }
632 
636  template<class other_memory_space>
638  m_features(other.features().copy()), m_thresholds(other.thresholds().copy()) {
639  }
640 
644  template<class other_memory_space>
646  m_features = other.features().copy();
647  m_thresholds = other.thresholds().copy();
648  return (*this);
649  }
650 
656  }
657 
662  return m_features;
663  }
664 
669  return m_features;
670  }
671 
676  return m_features[cuv::indices[0][cuv::index_range()]];
677  }
678 
683  return m_features[cuv::indices[1][cuv::index_range()]];
684  }
685 
690  return m_features[cuv::indices[2][cuv::index_range()]];
691  }
692 
697  return m_features[cuv::indices[3][cuv::index_range()]];
698  }
699 
704  return m_features[cuv::indices[4][cuv::index_range()]];
705  }
706 
711  return m_features[cuv::indices[5][cuv::index_range()]];
712  }
713 
718  return m_features[cuv::indices[6][cuv::index_range()]];
719  }
720 
725  return m_features[cuv::indices[7][cuv::index_range()]];
726  }
727 
732  return m_features[cuv::indices[8][cuv::index_range()]];
733  }
734 
739  return m_features[cuv::indices[9][cuv::index_range()]];
740  }
741 
746  return m_features[cuv::indices[10][cuv::index_range()]];
747  }
748 
753  return this->m_thresholds;
754  }
755 
760  return m_features[cuv::indices[0][cuv::index_range()]];
761  }
762 
767  return m_features[cuv::indices[1][cuv::index_range()]];
768  }
769 
774  return m_features[cuv::indices[2][cuv::index_range()]];
775  }
776 
781  return m_features[cuv::indices[3][cuv::index_range()]];
782  }
783 
788  return m_features[cuv::indices[4][cuv::index_range()]];
789  }
790 
795  return m_features[cuv::indices[5][cuv::index_range()]];
796  }
797 
802  return m_features[cuv::indices[6][cuv::index_range()]];
803  }
804 
809  return m_features[cuv::indices[7][cuv::index_range()]];
810  }
811 
816  return m_features[cuv::indices[8][cuv::index_range()]];
817  }
818 
823  return m_features[cuv::indices[9][cuv::index_range()]];
824  }
825 
830  return m_features[cuv::indices[10][cuv::index_range()]];
831  }
832 
837  return this->m_thresholds;
838  }
839 
// @return the threshold candidate threshNr of feature featNr
843  double getThreshold(size_t threshNr, size_t featNr) const {
844  return m_thresholds(threshNr, featNr);
845  }
846 
// Store the parameters of `feature` into column `feat` of the
// structure-of-arrays feature storage; the round-trip assert at the
// end guarantees setFeatureFunction/getFeatureFunction stay in sync.
850  void setFeatureFunction(size_t feat, const ImageFeatureFunction& feature) {
851 
852  types()(feat) = static_cast<int8_t>(feature.getType());
853 
854  offset1X()(feat) = feature.getOffset1().getX();
855  offset1Y()(feat) = feature.getOffset1().getY();
856  offset2X()(feat) = feature.getOffset2().getX();
857  offset2Y()(feat) = feature.getOffset2().getY();
858 
859  region1X()(feat) = feature.getRegion1().getX();
860  region1Y()(feat) = feature.getRegion1().getY();
861  region2X()(feat) = feature.getRegion2().getX();
862  region2Y()(feat) = feature.getRegion2().getY();
863 
864  channel1()(feat) = feature.getChannel1();
865  channel2()(feat) = feature.getChannel2();
866 
867  assert(getFeatureFunction(feat) == feature);
868  }
869 
874  const Offset offset1(offset1X()(feat), offset1Y()(feat));
875  const Offset offset2(offset2X()(feat), offset2Y()(feat));
876  const Offset region1(region1X()(feat), region1Y()(feat));
877  const Offset region2(region2X()(feat), region2Y()(feat));
878  return ImageFeatureFunction(static_cast<FeatureType>(static_cast<int8_t>(types()(feat))),
879  offset1, region1,
880  channel1()(feat),
881  offset2, region2,
882  channel2()(feat));
883  }
884 
885 };
886 
// Per-sample training data in structure-of-arrays layout: one
// 6 x numSamples array whose rows are aliased by the typed pointers
// below (row 0: depths, 1: sampleX, 2: sampleY, 3: imageNumbers,
// 4: labels, 5: horFlipSetting — see the constructors).
// NOTE(review): the declarations of the underlying `data` array and the
// `imageNumbers` pointer appear to have been lost in extraction (every
// constructor initializes both) — restore them from the original header.
895 template<class memory_space>
896 class Samples {
897 
898 public:
901  float* depths;
902  int* sampleX;
903  int* sampleY;
905  uint8_t* labels;
906  HorizontalFlipSetting* horFlipSetting;
// Copy constructor: copies/shares the underlying array and re-derives
// the row pointers from the new array (they must not alias `samples`).
911  Samples(const Samples& samples) :
912  data(samples.data),
913  depths(reinterpret_cast<float*>(data[cuv::indices[0][cuv::index_range()]].ptr())),
914  sampleX(reinterpret_cast<int*>(data[cuv::indices[1][cuv::index_range()]].ptr())),
915  sampleY(reinterpret_cast<int*>(data[cuv::indices[2][cuv::index_range()]].ptr())),
916  imageNumbers(reinterpret_cast<int*>(data[cuv::indices[3][cuv::index_range()]].ptr())),
917  labels(reinterpret_cast<uint8_t*>(data[cuv::indices[4][cuv::index_range()]].ptr())),
918  horFlipSetting(reinterpret_cast<HorizontalFlipSetting*>(data[cuv::indices[5][cuv::index_range()]].ptr())){
919  }
920 
// Cross-memory-space copy (e.g. host -> device) on the given CUDA stream.
924  template<class T>
925  Samples(const Samples<T>& samples, cudaStream_t stream) :
926  data(samples.data, stream),
927  depths(reinterpret_cast<float*>(data[cuv::indices[0][cuv::index_range()]].ptr())),
928  sampleX(reinterpret_cast<int*>(data[cuv::indices[1][cuv::index_range()]].ptr())),
929  sampleY(reinterpret_cast<int*>(data[cuv::indices[2][cuv::index_range()]].ptr())),
930  imageNumbers(reinterpret_cast<int*>(data[cuv::indices[3][cuv::index_range()]].ptr())),
931  labels(reinterpret_cast<uint8_t*>(data[cuv::indices[4][cuv::index_range()]].ptr())),
932  horFlipSetting(reinterpret_cast<HorizontalFlipSetting*>(data[cuv::indices[5][cuv::index_range()]].ptr())){
933  }
934 
// Allocate uninitialized storage for numSamples samples (6 rows).
938  Samples(size_t numSamples, boost::shared_ptr<cuv::allocator>& allocator) :
939  data(6, numSamples, allocator),
940  depths(reinterpret_cast<float*>(data[cuv::indices[0][cuv::index_range()]].ptr())),
941  sampleX(reinterpret_cast<int*>(data[cuv::indices[1][cuv::index_range()]].ptr())),
942  sampleY(reinterpret_cast<int*>(data[cuv::indices[2][cuv::index_range()]].ptr())),
943  imageNumbers(reinterpret_cast<int*>(data[cuv::indices[3][cuv::index_range()]].ptr())),
944  labels(reinterpret_cast<uint8_t*>(data[cuv::indices[4][cuv::index_range()]].ptr())),
945  horFlipSetting(reinterpret_cast<HorizontalFlipSetting*>(data[cuv::indices[5][cuv::index_range()]].ptr()))
946  {
// Sanity-check the assumed contiguous row layout of `data`.
947  assert_equals(imageNumbers, data.ptr() + 3 * numSamples);
948  assert_equals(labels, reinterpret_cast<uint8_t*>(data.ptr() + 4 * numSamples));
949  }
950 
951 };
952 
962 public:
963  // box_radius: > 0, half the box side length to uniformly sample
964  // (dx,dy) offsets from.
965 
// Set up the per-tree evaluator: one pooled CUDA allocator per buffer
// kind (named for diagnostics), then initialize the CUDA device.
// Image dimensions start at 0 and are presumably set later in prepare()
// — TODO(review): confirm, the relevant code is not visible here.
969  ImageFeatureEvaluation(const size_t treeId, const TrainingConfiguration& configuration) :
970  treeId(treeId), configuration(configuration),
971  imageWidth(0), imageHeight(0),
972  sampleDataAllocator(boost::make_shared<cuv::pooled_cuda_allocator>("sampleData")),
973  featuresAllocator(boost::make_shared<cuv::pooled_cuda_allocator>("feature")),
974  keysIndicesAllocator(boost::make_shared<cuv::pooled_cuda_allocator>("keysIndices")),
975  scoresAllocator(boost::make_shared<cuv::pooled_cuda_allocator>("scores")),
976  countersAllocator(boost::make_shared<cuv::pooled_cuda_allocator>("counters")),
977  featureResponsesAllocator(boost::make_shared<cuv::pooled_cuda_allocator>("featureResponses")) {
// Feature sampling is meaningless with empty boxes/regions; fail fast.
978  assert(configuration.getBoxRadius() > 0);
979  assert(configuration.getRegionSize() > 0);
980 
981  initDevice();
982  }
983 
987  std::vector<SplitFunction<PixelInstance, ImageFeatureFunction> > evaluateBestSplits(RandomSource& randomSource,
988  const std::vector<std::pair<boost::shared_ptr<RandomTree<PixelInstance, ImageFeatureFunction> >,
989  std::vector<const PixelInstance*> > >& samplesPerNode);
990 
994  std::vector<std::vector<const PixelInstance*> > prepare(const std::vector<const PixelInstance*>& samples,
996 
1000  std::vector<std::vector<const PixelInstance*> > prepare(const std::vector<const PixelInstance*>& samples,
1001  RandomTree<PixelInstance, ImageFeatureFunction>& node, cuv::dev_memory_space, bool keepMutexLocked = true);
1002 
1007  const std::vector<const PixelInstance*>& batches,
1008  int seed, const bool sort, cuv::host_memory_space);
1009 
1014  const std::vector<const PixelInstance*>& batches,
1015  int seed, const bool sort, cuv::dev_memory_space);
1016 
1020  template<class memory_space>
1021  void sortFeatures(ImageFeaturesAndThresholds<memory_space>& featuresAndThresholds,
1022  const cuv::ndarray<int, memory_space>& keysIndices) const;
1023 
1027  template<class memory_space>
1030  const std::vector<std::vector<const PixelInstance*> >& batches,
1031  const ImageFeaturesAndThresholds<memory_space>& featuresAndThresholds,
1033 
1037  template<class memory_space>
1040  const ImageFeaturesAndThresholds<memory_space>& featuresAndThresholds,
1041  const cuv::ndarray<WeightType, memory_space>& histogram);
1042 
1043 private:
1044 
1045  void selectDevice();
1046 
1047  void initDevice();
1048 
1049  void copyFeaturesToDevice();
1050 
1051  Samples<cuv::dev_memory_space> copySamplesToDevice(const std::vector<const PixelInstance*>& samples,
1052  cudaStream_t stream);
1053 
1054  const ImageFeatureFunction sampleFeature(RandomSource& randomSource,
1055  const std::vector<const PixelInstance*>&) const;
1056 
1057  const size_t treeId;
1058  const TrainingConfiguration& configuration;
1059 
1060  unsigned int imageWidth;
1061  unsigned int imageHeight;
1062 
1063  boost::shared_ptr<cuv::allocator> sampleDataAllocator;
1064  boost::shared_ptr<cuv::allocator> featuresAllocator;
1065  boost::shared_ptr<cuv::allocator> keysIndicesAllocator;
1066  boost::shared_ptr<cuv::allocator> scoresAllocator;
1067  boost::shared_ptr<cuv::allocator> countersAllocator;
1068  boost::shared_ptr<cuv::allocator> featureResponsesAllocator;
1069 };
1070 
1076 public:
1077 
1082  RandomTreeImage(int id, const TrainingConfiguration& configuration);
1083 
1090  const TrainingConfiguration& configuration,
1091  const cuv::ndarray<WeightType, cuv::host_memory_space>& classLabelPriorDistribution);
1092 
1096  void train(const std::vector<LabeledRGBDImage>& trainLabelImages,
1097  RandomSource& randomSource, size_t subsampleCount, size_t numLabels);
1098 
1102  void test(const RGBDImage* image, LabelImage& prediction) const;
1103 
1109  void normalizeHistograms(const double histogramBias);
1110 
// @return the trained decision tree (shared ownership, may be null
// before training — TODO(review): confirm against train()).
1114  const boost::shared_ptr<RandomTree<PixelInstance, ImageFeatureFunction> >& getTree() const {
1115  return tree;
1116  }
1117 
1122  return classLabelPriorDistribution;
1123  }
1124 
// @return the identifier of this tree within the forest
1128  size_t getId() const {
1129  return id;
1130  }
1131 
1135  bool shouldIgnoreLabel(const LabelType& label) const;
1136 
1137 private:
1138 
1139  void doTrain(RandomSource& randomSource, size_t numClasses,
1140  std::vector<const PixelInstance*>& subsamples);
1141 
1142  bool finishedTraining;
1143  size_t id;
1144 
1145  const TrainingConfiguration configuration;
1146 
1147  boost::shared_ptr<RandomTree<PixelInstance, ImageFeatureFunction> > tree;
1148 
1149  cuv::ndarray<WeightType, cuv::host_memory_space> classLabelPriorDistribution;
1150 
1151  void calculateLabelPriorDistribution(const std::vector<LabeledRGBDImage>& trainLabelImages);
1152 
1153  std::vector<PixelInstance> subsampleTrainingDataPixelUniform(
1154  const std::vector<LabeledRGBDImage>& trainLabelImages,
1155  RandomSource& randomSource, size_t subsampleCount) const;
1156 
1157  std::vector<PixelInstance> subsampleTrainingDataClassUniform(
1158  const std::vector<LabeledRGBDImage>& trainLabelImages,
1159  RandomSource& randomSource, size_t subsampleCount) const;
1160 
1161 };
1162 
1163 }
1164 
// Stream-output (debug/logging) operators, implemented in the .cpp file.
1165 std::ostream& operator<<(std::ostream& os, const curfil::RandomTreeImage& tree);
1166 
1167 std::ostream& operator<<(std::ostream& os, const curfil::ImageFeatureFunction& featureFunction);
1168 
1169 std::ostream& operator<<(std::ostream& os, const curfil::XY& xy);
1170 
1171 #endif
1172