Advertisement
Not a member of Pastebin yet?
Sign up —
it unlocks many cool features!
- From fa7cefa73a6157dd9f105e9dc5e7e0b16f3253d6 Mon Sep 17 00:00:00 2001
- From: Ryan Curtin <ryan@ratml.org>
- Date: Tue, 9 Jul 2019 20:15:29 -0400
- Subject: [PATCH] Fix cover tree statistic computation.
- ---
- .../core/tree/cover_tree/cover_tree_impl.hpp | 53 ++++++++++---------
- 1 file changed, 28 insertions(+), 25 deletions(-)
- diff --git a/src/mlpack/core/tree/cover_tree/cover_tree_impl.hpp b/src/mlpack/core/tree/cover_tree/cover_tree_impl.hpp
- index d4749027f..a16073431 100644
- --- a/src/mlpack/core/tree/cover_tree/cover_tree_impl.hpp
- +++ b/src/mlpack/core/tree/cover_tree/cover_tree_impl.hpp
- @@ -21,6 +21,18 @@
- namespace mlpack {
- namespace tree {
- +// Build the statistics, bottom-up.
- +template<typename TreeType, typename StatisticType>
- +void BuildStatistics(TreeType* node)
- +{
- + // Recurse first.
- + for (size_t i = 0; i < node->NumChildren(); ++i)
- + BuildStatistics<TreeType, StatisticType>(&node->Child(i));
- +
- + // Now build the statistic.
- + node->Stat() = StatisticType(*node);
- +}
- +
- // Create the cover tree.
- template<
- typename MetricType,
- @@ -88,9 +100,8 @@ CoverTree<MetricType, StatisticType, MatType, RootPointPolicy>::CoverTree(
- {
- children.push_back(&(old->Child(i)));
- - // Set its parent correctly, and rebuild the statistic.
- + // Set its parent correctly.
- old->Child(i).Parent() = this;
- - old->Child(i).Stat() = StatisticType(old->Child(i));
- }
- // Remove all the children so they don't get erased.
- @@ -110,8 +121,9 @@ CoverTree<MetricType, StatisticType, MatType, RootPointPolicy>::CoverTree(
- else
- scale = (int) ceil(log(furthestDescendantDistance) / log(base));
- - // Initialize statistic.
- - stat = StatisticType(*this);
- + // Initialize statistics recursively after the entire tree construction is
- + // complete.
- + BuildStatistics<CoverTree, StatisticType>(this);
- Log::Info << distanceComps << " distance computations during tree "
- << "construction." << std::endl;
- @@ -181,8 +193,6 @@ CoverTree<MetricType, StatisticType, MatType, RootPointPolicy>::CoverTree(
- // Set its parent correctly.
- old->Child(i).Parent() = this;
- - // Rebuild the statistic.
- - old->Child(i).Stat() = StatisticType(old->Child(i));
- }
- // Remove all the children so they don't get erased.
- @@ -202,8 +212,9 @@ CoverTree<MetricType, StatisticType, MatType, RootPointPolicy>::CoverTree(
- else
- scale = (int) ceil(log(furthestDescendantDistance) / log(base));
- - // Initialize statistic.
- - stat = StatisticType(*this);
- + // Initialize statistics recursively after the entire tree construction is
- + // complete.
- + BuildStatistics<CoverTree, StatisticType>(this);
- Log::Info << distanceComps << " distance computations during tree "
- << "construction." << std::endl;
- @@ -272,9 +283,8 @@ CoverTree<MetricType, StatisticType, MatType, RootPointPolicy>::CoverTree(
- {
- children.push_back(&(old->Child(i)));
- - // Set its parent correctly, and rebuild the statistic.
- + // Set its parent correctly.
- old->Child(i).Parent() = this;
- - old->Child(i).Stat() = StatisticType(old->Child(i));
- }
- // Remove all the children so they don't get erased.
- @@ -294,8 +304,9 @@ CoverTree<MetricType, StatisticType, MatType, RootPointPolicy>::CoverTree(
- else
- scale = (int) ceil(log(furthestDescendantDistance) / log(base));
- - // Initialize statistic.
- - stat = StatisticType(*this);
- + // Initialize statistics recursively after the entire tree construction is
- + // complete.
- + BuildStatistics<CoverTree, StatisticType>(this);
- Log::Info << distanceComps << " distance computations during tree "
- << "construction." << std::endl;
- @@ -363,9 +374,8 @@ CoverTree<MetricType, StatisticType, MatType, RootPointPolicy>::CoverTree(
- {
- children.push_back(&(old->Child(i)));
- - // Set its parent correctly, and rebuild the statistic.
- + // Set its parent correctly.
- old->Child(i).Parent() = this;
- - old->Child(i).Stat() = StatisticType(old->Child(i));
- }
- // Remove all the children so they don't get erased.
- @@ -385,8 +395,9 @@ CoverTree<MetricType, StatisticType, MatType, RootPointPolicy>::CoverTree(
- else
- scale = (int) ceil(log(furthestDescendantDistance) / log(base));
- - // Initialize statistic.
- - stat = StatisticType(*this);
- + // Initialize statistics recursively after the entire tree construction is
- + // complete.
- + BuildStatistics<CoverTree, StatisticType>(this);
- Log::Info << distanceComps << " distance computations during tree "
- << "construction." << std::endl;
- @@ -429,15 +440,11 @@ CoverTree<MetricType, StatisticType, MatType, RootPointPolicy>::CoverTree(
- {
- this->scale = INT_MIN;
- numDescendants = 1;
- - stat = StatisticType(*this);
- return;
- }
- // Otherwise, create the children.
- CreateChildren(indices, distances, nearSetSize, farSetSize, usedSetSize);
- -
- - // Initialize statistic.
- - stat = StatisticType(*this);
- }
- // Manually create a cover tree node.
- @@ -472,9 +479,6 @@ CoverTree<MetricType, StatisticType, MatType, RootPointPolicy>::CoverTree(
- // If necessary, create a local metric.
- if (localMetric)
- this->metric = new MetricType();
- -
- - // Initialize the statistic.
- - stat = StatisticType(*this);
- }
- template<
- @@ -1526,11 +1530,10 @@ inline void CoverTree<MetricType, StatisticType, MatType, RootPointPolicy>::
- // Now take its child.
- children.push_back(&(old->Child(0)));
- - // Set its parent and parameters correctly, and rebuild the statistic.
- + // Set its parent and parameters correctly.
- old->Child(0).Parent() = this;
- old->Child(0).ParentDistance() = old->ParentDistance();
- old->Child(0).DistanceComps() = old->DistanceComps();
- - old->Child(0).Stat() = StatisticType(old->Child(0));
- // Remove its child (so it doesn't delete it).
- old->Children().erase(old->Children().begin() + old->Children().size() - 1);
- --
- 2.20.1
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement