카테고리 없음

mqrrt

정지홍 2025. 8. 21. 17:01

1. mqrrt.cpp

#include "mqrrt/mqrrt.hpp"

#include <algorithm>
#include <cmath>
#include <deque>
#include <random>
#include <stdexcept>

#include <nav2_costmap_2d/cost_values.hpp>

namespace mqrrt
{

// =================== lifecycle ===================
/**
 * @brief Lifecycle "configure" hook: bind to the parent node and costmap and load parameters.
 *
 * Stores the parent node and costmap handles, snapshots the costmap's world-frame
 * extent into wx_min_/wy_min_/wx_max_/wy_max_, and declares and reads the planner
 * parameters under the "<name>." prefix. The current member values act as defaults.
 *
 * @param parent       Weak handle to the owning lifecycle node.
 * @param name         Plugin instance name, used as the parameter namespace prefix.
 * @param costmap_ros  Costmap wrapper used for bounds and collision checks.
 * @throws std::runtime_error if the parent node cannot be locked or costmap_ros is null.
 */
void MQRRT::configure(
    const rclcpp_lifecycle::LifecycleNode::WeakPtr & parent,
    std::string name,
    std::shared_ptr<tf2_ros::Buffer> /*tf*/,
    std::shared_ptr<nav2_costmap_2d::Costmap2DROS> costmap_ros)
{
  node_w_ = parent;
  node_   = node_w_.lock();
  // Fail loudly instead of dereferencing a null handle further down.
  if (!node_) {
    throw std::runtime_error("[MQ-RRT*] configure: unable to lock parent lifecycle node");
  }
  if (!costmap_ros) {
    throw std::runtime_error("[MQ-RRT*] configure: costmap_ros is null");
  }
  logger_ = node_->get_logger();
  costmap_ros_ = costmap_ros;

  // Snapshot of the costmap extent in world coordinates (used as sampling bounds).
  auto* cm = costmap_ros_->getCostmap();
  wx_min_ = cm->getOriginX();
  wy_min_ = cm->getOriginY();
  wx_max_ = wx_min_ + cm->getSizeInMetersX();
  wy_max_ = wy_min_ + cm->getSizeInMetersY();

  // Declare each parameter only if it is not declared yet, so that configuring the
  // plugin a second time (e.g. after a lifecycle cleanup) does not throw.
  const auto declare_once = [this](const std::string& key, const auto& default_value) {
    if (!node_->has_parameter(key)) {
      node_->declare_parameter(key, default_value);
    }
  };

  // parameters (default values already set)
  declare_once(name + ".n_max",          n_max_);
  declare_once(name + ".step_size",      step_size_);
  declare_once(name + ".r_near",         r_near_);
  declare_once(name + ".r_commit",       r_commit_);
  declare_once(name + ".d_dichotomy",    d_dichotomy_);
  declare_once(name + ".ancestry_depth", ancestry_depth_);

  node_->get_parameter(name + ".n_max",          n_max_);
  node_->get_parameter(name + ".step_size",      step_size_);
  node_->get_parameter(name + ".r_near",         r_near_);
  node_->get_parameter(name + ".r_commit",       r_commit_);
  node_->get_parameter(name + ".d_dichotomy",    d_dichotomy_);
  node_->get_parameter(name + ".ancestry_depth", ancestry_depth_);

  RCLCPP_INFO(logger_, "[MQ-RRT*] configured: n_max=%d step=%.2f r_near=%.2f r_commit=%.2f d_dich=%.3f depth=%d",
              n_max_, step_size_, r_near_, r_commit_, d_dichotomy_, ancestry_depth_);
}

// Remaining lifecycle hooks are intentionally empty: these overrides perform no
// work on cleanup, activation or deactivation.
void MQRRT::cleanup()   {}
void MQRRT::activate()  {}
void MQRRT::deactivate(){}

// =================== KD-tree helpers ===================
/**
 * @brief Rebuild the KD-tree backing store and index from the current tree.
 *
 * The point cloud is emptied and refilled with the (x, y) position of every node
 * in tree_, in order, so cloud index i corresponds to tree_[i]. A new KD-tree is
 * then allocated over that cloud and its index is built.
 */
void MQRRT::rebuildKD()
{
  auto& points = cloud_.pts;
  points.clear();
  points.reserve(tree_.size());

  for (size_t i = 0; i < tree_.size(); ++i) {
    const auto& position = tree_[i].pose.pose.position;
    points.push_back({position.x, position.y});
  }

  kdtree_.reset(new KDTree(2, cloud_, {10 /* max leaf */}));
  kdtree_->buildIndex();
}

/**
 * @brief Find the tree node closest to (x, y) using the KD-tree.
 *
 * @param x         Query x coordinate.
 * @param y         Query y coordinate.
 * @param out_dist  Receives the Euclidean distance to the nearest node, or
 *                  +infinity when there is nothing to search.
 * @return Index of the nearest node. Returns 0 when the cloud is empty or the
 *         KD-tree has not been built; callers must not index an empty tree with it.
 */
size_t MQRRT::nearest(double x, double y, double& out_dist) const
{
  out_dist = std::numeric_limits<double>::infinity();
  // Guard the missing index as well as the empty cloud, matching radiusSearch().
  if (!kdtree_ || cloud_.pts.empty()) return 0;

  size_t ret_index = static_cast<size_t>(-1);
  double out_dist_sqr = std::numeric_limits<double>::infinity();
  nanoflann::KNNResultSet<double> resultSet(1);
  resultSet.init(&ret_index, &out_dist_sqr);
  const double query[2] = {x, y};
  // unique_ptr does not propagate constness to the pointee, so no const_cast is needed.
  kdtree_->findNeighbors(resultSet, query, nanoflann::SearchParams());
  // The result set holds a squared distance; convert to a plain distance.
  out_dist = std::sqrt(out_dist_sqr);
  return ret_index;
}

void MQRRT::radiusSearch(double x, double y, double radius,
                         std::vector<std::pair<size_t,double>>& out) const
{
  out.clear();
  if (!kdtree_) return;
  const double rs = radius*radius;
  std::vector<std::pair<size_t, double>> ret_matches;
  nanoflann::RadiusResultSet<double> resultSet(rs, ret_matches);
  double query[2] = {x, y};
  const_cast<KDTree*>(kdtree_.get())->findNeighbors(resultSet, query, nanoflann::SearchParams());
  out = std::move(ret_matches);
}

// =================== tree ops ===================
/**
 * @brief Append a node to the tree and refresh the KD-tree.
 *
 * @param p           Pose of the new node.
 * @param parent_idx  Index of the parent node in tree_, or a negative value for a
 *                    root node (cost 0, no parent link).
 * @return Index of the newly added node in tree_.
 *
 * For a non-negative parent the node cost is the parent's cost plus the Euclidean
 * distance between the two positions, and the new index is recorded in the
 * parent's children list. The KD-tree is rebuilt from scratch after every insert.
 */
size_t MQRRT::addNode(const geometry_msgs::msg::PoseStamped& p, int parent_idx)
{
  const size_t idx = tree_.size();

  Node n;
  n.pose = p;
  n.parent = parent_idx;
  if (parent_idx >= 0) {
    const auto& parent_pos = tree_[parent_idx].pose.pose.position;
    n.cost = tree_[parent_idx].cost +
             dist(parent_pos.x, parent_pos.y, p.pose.position.x, p.pose.position.y);
    tree_[parent_idx].children.push_back(static_cast<int>(idx));
  } else {
    n.cost = 0.0;
  }
  tree_.push_back(std::move(n));

  // nanoflann version in many distros doesn't support incremental addPoints(); rebuild instead.
  // rebuildKD() refills cloud_.pts from tree_, so no separate push into the cloud is needed.
  rebuildKD();
  return idx;
}

/**
 * @brief Re-parent @p child under @p new_parent and update subtree costs.
 *
 * @param new_parent  Index of the node that becomes the parent.
 * @param child       Index of the node being moved.
 *
 * The child is removed from its previous parent's children list (if it had a
 * parent), linked under the new parent, and the resulting change in its cost is
 * applied to the child and all of its descendants via propagateCost().
 */
void MQRRT::reconnect(size_t new_parent, size_t child)
{
  const int child_id = static_cast<int>(child);

  // Detach from the previous parent, if any.
  const int old_parent = tree_[child].parent;
  if (old_parent >= 0) {
    auto& siblings = tree_[old_parent].children;
    siblings.erase(std::remove(siblings.begin(), siblings.end(), child_id), siblings.end());
  }

  // Attach under the new parent.
  tree_[child].parent = static_cast<int>(new_parent);
  tree_[new_parent].children.push_back(child_id);

  // Shift the whole subtree by the difference between the new and old cost.
  const auto& from = tree_[new_parent].pose.pose.position;
  const auto& to   = tree_[child].pose.pose.position;
  const double cost_via_new_parent =
      tree_[new_parent].cost + dist(from.x, from.y, to.x, to.y);
  propagateCost(child, cost_via_new_parent - tree_[child].cost);
}

/**
 * @brief Add @p delta to the cost of @p start and every node below it.
 *
 * @param start  Index of the subtree root.
 * @param delta  Amount added to each visited node's cost.
 *
 * Walks the subtree through the children lists using an explicit work list;
 * each reachable node is adjusted exactly once as long as the links form a tree.
 */
void MQRRT::propagateCost(size_t start, double delta)
{
  std::vector<size_t> pending{start};
  while (!pending.empty()) {
    const size_t current = pending.back();
    pending.pop_back();

    tree_[current].cost += delta;
    for (const int child : tree_[current].children) {
      pending.push_back(static_cast<size_t>(child));
    }
  }
}

/**
 * @brief Test whether node @p q lies in the subtree rooted at @p root.
 *
 * @param root  Index of the subtree root.
 * @param q     Index of the node to look for.
 * @return true if @p q is @p root itself or reachable from it through children
 *         links; false otherwise.
 */
bool MQRRT::isDescendant(size_t root, size_t q) const
{
  std::vector<size_t> pending{root};
  while (!pending.empty()) {
    const size_t current = pending.back();
    pending.pop_back();

    if (current == q) {
      return true;
    }
    for (const int child : tree_[current].children) {
      pending.push_back(static_cast<size_t>(child));
    }
  }
  return false;
}

/**
 * @brief Gather ancestors of the neighbour nodes, up to (ancestry_depth - 1) levels up.
 *
 * @param X_near          Neighbour set; only the node index (pair.first) is used.
 * @param ancestry_depth  Depth setting; values <= 1 yield no ancestors.
 * @param out             Cleared, then filled with distinct ancestor indices in
 *                        the order they are first encountered.
 *
 * Duplicates are removed among the ancestors only; a neighbour that is also an
 * ancestor of another neighbour can still appear in @p out.
 */
void MQRRT::ancestryFromNeighbors(
    const std::vector<std::pair<size_t,double>>& X_near,
    int ancestry_depth,
    std::vector<size_t>& out) const
{
  out.clear();
  if (ancestry_depth <= 1) {
    return;
  }

  std::unordered_set<size_t> visited;
  for (const auto& neighbor : X_near) {
    size_t node = neighbor.first;
    int remaining = ancestry_depth - 1;
    while (remaining-- > 0) {
      const int parent = tree_[node].parent;
      if (parent < 0) {
        break;  // reached a root
      }
      node = static_cast<size_t>(parent);
      const bool first_time = visited.insert(node).second;
      if (first_time) {
        out.push_back(node);
      }
    }
  }
}

/**
 * @brief Pick the cheapest collision-free parent for a new pose.
 *
 * @param candidates  Candidate parent indices.
 * @param x_nearest   Fallback parent, used as the initial best without a collision
 *                    check in this function.
 * @param x_rand      Pose of the node to be connected.
 * @return Index of the candidate with the lowest (node cost + straight-line
 *         distance to @p x_rand) among those with a collision-free connection,
 *         or @p x_nearest if none is strictly cheaper than it.
 */
size_t MQRRT::chooseBestParent(
    const std::vector<size_t>& candidates,
    size_t x_nearest,
    const geometry_msgs::msg::PoseStamped& x_rand) const
{
  const auto& target = x_rand.pose.position;

  // Cost of reaching the target through node idx.
  const auto cost_via = [&](size_t idx) {
    const auto& from = tree_[idx].pose.pose.position;
    return tree_[idx].cost + dist(from.x, from.y, target.x, target.y);
  };

  size_t best_idx = x_nearest;
  double best_cost = cost_via(x_nearest);

  for (const size_t idx : candidates) {
    const auto& from = tree_[idx].pose.pose.position;
    if (isCollisionFree(from.x, from.y, target.x, target.y)) {
      const double candidate_cost = cost_via(idx);
      if (candidate_cost < best_cost) {
        best_idx = idx;
        best_cost = candidate_cost;
      }
    }
  }
  return best_idx;
}

// ========== Remove-tips (Alg. 6) ==========
/**
 * @brief Bisection search for a "corner" pair that shortcuts around x_parent.
 *
 * Two point pairs are maintained. The "allow" pair starts with both points at
 * x_parent; the "forbid" pair starts at (Parent(x_parent), x_rand). On every
 * iteration the midpoints between allow and forbid are taken on each side: if the
 * segment between the two midpoints is collision-free the allow pair moves to the
 * midpoints, otherwise the forbid pair does. The loop stops when half the distance
 * from x_allow_A to the forbid line drops to D_dichotomy or after 32 iterations.
 *
 * @param x_parent_idx  Index of the chosen parent node.
 * @param x_rand        Pose of the node being connected.
 * @param D_dichotomy   Termination threshold for the bisection.
 * @return The final allow pair with orientation.w set to 1 when x_allow_A moved
 *         away from x_parent; otherwise a pair of "empty" poses marked by
 *         orientation.w == 0 (also returned when x_parent has no parent).
 */
std::pair<geometry_msgs::msg::PoseStamped, geometry_msgs::msg::PoseStamped>
MQRRT::removeTips(size_t x_parent_idx,
                  const geometry_msgs::msg::PoseStamped& x_rand,
                  double D_dichotomy) const
{
  geometry_msgs::msg::PoseStamped empty; // default-constructed; we use orientation.w=0 as "empty" marker
  empty.pose.orientation.w = 0.0;

  // only meaningful if x_parent has an actual parent
  int pp = tree_[x_parent_idx].parent;
  if (pp < 0) return {empty, empty};

  auto x_parent = tree_[x_parent_idx].pose;
  auto x_ppose  = tree_[(size_t)pp].pose;

  // allow pair: known-good endpoints; forbid pair: the outer limits of the search
  auto x_allow_A  = x_parent;
  auto x_allow_B  = x_parent;
  auto x_forbid_A = x_ppose;
  auto x_forbid_B = x_rand;

  // Perpendicular distance from P to the infinite line through A and B
  // (0 when A and B practically coincide).
  auto point_to_line_dist = [](const geometry_msgs::msg::PoseStamped& P,
                               const geometry_msgs::msg::PoseStamped& A,
                               const geometry_msgs::msg::PoseStamped& B)->double{
    const double x0=P.pose.position.x, y0=P.pose.position.y;
    const double x1=A.pose.position.x, y1=A.pose.position.y;
    const double x2=B.pose.position.x, y2=B.pose.position.y;
    double num = std::abs((y2-y1)*x0 - (x2-x1)*y0 + x2*y1 - y2*x1);
    double den = std::hypot(y2-y1, x2-x1);
    return (den <= 1e-9)? 0.0 : num/den;
  };

  double D = point_to_line_dist(x_allow_A, x_forbid_A, x_forbid_B) * 0.5;
  int it = 0;
  // Bisect until the remaining gap is below the threshold; 32 iterations is a hard cap.
  while (D > D_dichotomy && it < 32) {
    auto x_mid_A = midpoint(x_allow_A, x_forbid_A);
    auto x_mid_B = midpoint(x_allow_B, x_forbid_B);
    if (isCollisionFree(x_mid_A.pose.position.x, x_mid_A.pose.position.y,
                        x_mid_B.pose.position.x, x_mid_B.pose.position.y)) {
      // midpoint segment is free: push the allowed pair outwards
      x_allow_A = x_mid_A;
      x_allow_B = x_mid_B;
    } else {
      // midpoint segment collides: pull the forbidden pair inwards
      x_forbid_A = x_mid_A;
      x_forbid_B = x_mid_B;
    }
    D = point_to_line_dist(x_allow_A, x_forbid_A, x_forbid_B) * 0.5;
    ++it;
  }

  // Exact comparison is intentional: x_allow_A is either still the copied x_parent
  // position or has been replaced by a midpoint.
  if (x_allow_A.pose.position.x != x_parent.pose.position.x ||
      x_allow_A.pose.position.y != x_parent.pose.position.y) {
    // mark valid by setting w=1
    x_allow_A.pose.orientation.w = 1.0;
    x_allow_B.pose.orientation.w = 1.0;
    return {x_allow_A, x_allow_B};
  } else {
    return {empty, empty};
  }
}

// ========== CreateNodes (Alg. 7) ==========
/**
 * @brief Derive two assist nodes from a corner pair by bisection.
 *
 * Stage 1 bisects between the two corners: the A end advances to the midpoint
 * while the segment Parent(x_parent) -> midpoint is collision-free, otherwise the
 * B end retreats to it. The resulting A becomes x_create_A.
 * Stage 2 bisects between x_create_A and the remaining B: B moves to the midpoint
 * while the segment x_rand -> midpoint is collision-free, otherwise x_create_A
 * moves to it. The resulting B becomes x_create_B.
 * Each stage stops when its two ends are within D_dichotomy or after 32 iterations.
 *
 * @param x_parent_idx  Index of the chosen parent node.
 * @param x_rand        Pose of the node being connected.
 * @param x_corner_A    First corner; orientation.w == 0 marks it as "empty".
 * @param x_corner_B    Second corner; orientation.w == 0 marks it as "empty".
 * @param D_dichotomy   Termination threshold for both bisections.
 * @return (x_create_A, x_create_B) with orientation.w set to 1, or a pair of
 *         "empty" poses (orientation.w == 0) when either corner is empty or
 *         x_parent has no parent.
 */
std::pair<geometry_msgs::msg::PoseStamped, geometry_msgs::msg::PoseStamped>
MQRRT::createNodes(size_t x_parent_idx,
                   const geometry_msgs::msg::PoseStamped& x_rand,
                   const geometry_msgs::msg::PoseStamped& x_corner_A,
                   const geometry_msgs::msg::PoseStamped& x_corner_B,
                   double D_dichotomy) const
{
  geometry_msgs::msg::PoseStamped empty; empty.pose.orientation.w = 0.0;

  if (x_corner_A.pose.orientation.w == 0.0 || x_corner_B.pose.orientation.w == 0.0)
    return {empty, empty};

  // Do not rely on the corner markers alone: a root x_parent has parent == -1,
  // which must never be cast to an index.
  const int grandparent = tree_[x_parent_idx].parent;
  if (grandparent < 0)
    return {empty, empty};

  auto x_forbid_A = tree_[static_cast<size_t>(grandparent)].pose; // Parent(x_parent)
  auto A = x_corner_A;
  auto B = x_corner_B;

  // Stage 1: move A towards B as far as Parent(x_parent) can still see it.
  int it = 0;
  while (dist(A.pose.position.x, A.pose.position.y,
              B.pose.position.x, B.pose.position.y) > D_dichotomy && it < 32) {
    auto mid = midpoint(A, B);
    if (isCollisionFree(x_forbid_A.pose.position.x, x_forbid_A.pose.position.y,
                        mid.pose.position.x, mid.pose.position.y)) {
      A = mid;
    } else {
      B = mid;
    }
    ++it;
  }
  auto x_create_A = A;

  // Stage 2: move B towards x_create_A as far as x_rand can still see it.
  auto x_forbid_B = x_rand;
  it = 0;
  while (dist(x_create_A.pose.position.x, x_create_A.pose.position.y,
              B.pose.position.x, B.pose.position.y) > D_dichotomy && it < 32) {
    auto mid = midpoint(x_create_A, B);
    if (isCollisionFree(x_forbid_B.pose.position.x, x_forbid_B.pose.position.y,
                        mid.pose.position.x, mid.pose.position.y)) {
      B = mid;
    } else {
      x_create_A = mid;
    }
    ++it;
  }
  auto x_create_B = B;

  // mark both results as valid
  x_create_A.pose.orientation.w = 1.0;
  x_create_B.pose.orientation.w = 1.0;
  return {x_create_A, x_create_B};
}

// ========== Q-RRT* Rewire (Alg.5) ==========
/**
 * @brief Rewire neighbours through the new node or its ancestry set.
 *
 * For every node q in @p X_near, each member of {x_rand} U X_parent is tried as a
 * replacement parent. A candidate is skipped when it is q itself, when it lies in
 * q's subtree (re-parenting would create a cycle), or when the straight segment
 * to q is not collision-free. q is reconnected to the candidate giving the lowest
 * cost, provided it beats q's current cost by more than 1e-9.
 *
 * @param x_rand_idx  Index of the newly inserted node.
 * @param X_near      Neighbour set; only the node index (pair.first) is used.
 * @param X_parent    Ancestor indices gathered for the new node.
 */
void MQRRT::rewireQ(
    size_t x_rand_idx,
    const std::vector<std::pair<size_t,double>>& X_near,
    const std::vector<size_t>& X_parent)
{
  // The candidate set {x_rand} U X_parent is the same for every neighbour:
  // build it once instead of once per loop iteration.
  std::vector<size_t> candidates;
  candidates.reserve(X_parent.size() + 1);
  candidates.push_back(x_rand_idx);
  candidates.insert(candidates.end(), X_parent.begin(), X_parent.end());

  for (const auto& nm : X_near) {
    const size_t q = nm.first;
    // A root's parent (-1) maps to SIZE_MAX here; it only serves as a
    // "no change" sentinel and is never used as an index.
    const size_t current_parent = static_cast<size_t>(tree_[q].parent);
    double best = tree_[q].cost; // current cost baseline
    size_t best_parent = current_parent;

    for (const size_t mid : candidates) {
      if (mid == q) continue;
      if (isDescendant(q, mid)) continue; // avoid cycles
      if (!isCollisionFree(tree_[mid].pose.pose.position.x, tree_[mid].pose.pose.position.y,
                           tree_[q].pose.pose.position.x,   tree_[q].pose.pose.position.y))
        continue;

      const double c = tree_[mid].cost +
        dist(tree_[mid].pose.pose.position.x, tree_[mid].pose.pose.position.y,
             tree_[q].pose.pose.position.x,   tree_[q].pose.pose.position.y);
      if (c + 1e-9 < best) {
        best = c;
        best_parent = mid;
      }
    }

    if (best_parent != current_parent) {
      reconnect(best_parent, q);
    }
  }
}

// =================== collision check ===================
/**
 * @brief Check whether the straight segment (x0, y0) -> (x1, y1) is traversable.
 *
 * Both endpoints are converted to map cells; the segment is then sampled at
 * max(|dx|, |dy|) + 1 evenly spaced cells (rounded to the nearest cell).
 *
 * @return false if either endpoint fails the world-to-map conversion, if a
 *         sampled cell falls outside the grid, or if any sampled cell has a cost
 *         >= LETHAL_OBSTACLE or equal to NO_INFORMATION; true otherwise.
 */
bool MQRRT::isCollisionFree(double x0, double y0, double x1, double y1) const
{
  auto* costmap = costmap_ros_->getCostmap();

  unsigned int start_mx, start_my, end_mx, end_my;
  const bool endpoints_on_map =
      costmap->worldToMap(x0, y0, start_mx, start_my) &&
      costmap->worldToMap(x1, y1, end_mx, end_my);
  if (!endpoints_on_map) {
    return false;
  }

  // A cell is traversable when it is neither lethal nor unknown.
  const auto traversable = [](unsigned char cost) {
    return cost < nav2_costmap_2d::LETHAL_OBSTACLE &&
           cost != nav2_costmap_2d::NO_INFORMATION;
  };

  const int delta_x = static_cast<int>(end_mx) - static_cast<int>(start_mx);
  const int delta_y = static_cast<int>(end_my) - static_cast<int>(start_my);
  const int steps = std::max(std::abs(delta_x), std::abs(delta_y));

  // Both endpoints share a cell: a single lookup decides (and avoids dividing by 0).
  if (steps == 0) {
    return traversable(costmap->getCost(start_mx, start_my));
  }

  const int size_x = static_cast<int>(costmap->getSizeInCellsX());
  const int size_y = static_cast<int>(costmap->getSizeInCellsY());
  for (int i = 0; i <= steps; ++i) {
    const double t = static_cast<double>(i) / static_cast<double>(steps);
    const int mx = static_cast<int>(std::round(start_mx + t * delta_x));
    const int my = static_cast<int>(std::round(start_my + t * delta_y));

    const bool inside_grid = mx >= 0 && my >= 0 && mx < size_x && my < size_y;
    if (!inside_grid) {
      return false;
    }
    const unsigned char cost =
        costmap->getCost(static_cast<unsigned int>(mx), static_cast<unsigned int>(my));
    if (!traversable(cost)) {
      return false;
    }
  }
  return true;
}

// =================== sampling ===================
/**
 * @brief Draw a pose uniformly from the stored map bounds.
 *
 * The result is a copy of start_ (so it carries the start header) whose x and y
 * are drawn from [wx_min_, wx_max_) and [wy_min_, wy_max_), with
 * orientation.w set to 1. Uses a per-thread Mersenne Twister seeded from
 * std::random_device on first use.
 */
geometry_msgs::msg::PoseStamped MQRRT::sampleUniform() const
{
  static thread_local std::mt19937 rng(std::random_device{}());

  geometry_msgs::msg::PoseStamped sample = start_;
  // x is drawn before y, keeping the order of random draws fixed.
  sample.pose.position.x = std::uniform_real_distribution<double>(wx_min_, wx_max_)(rng);
  sample.pose.position.y = std::uniform_real_distribution<double>(wy_min_, wy_max_)(rng);
  sample.pose.orientation.w = 1.0;
  return sample;
}

/**
 * @brief "Spare sampling": mix uniform samples with samples near existing nodes.
 *
 * With probability 0.5 (or whenever there is no node to pick from) a uniform
 * sample is returned. Otherwise a node is picked uniformly among the first
 * @p num_nodes tree nodes and Gaussian noise with standard deviation
 * r_commit / 2 is added to its x and y; the result is clamped to the map bounds.
 *
 * @param r_commit   Scale of the local sampling region.
 * @param num_nodes  Number of leading tree nodes eligible for picking; values
 *                   larger than the tree size are clamped to it.
 * @return Sampled pose (copy of start_ with x, y replaced, orientation.w = 1).
 */
geometry_msgs::msg::PoseStamped MQRRT::sampleSpare(
    int /*iter*/, double r_commit, size_t num_nodes) const
{
  static thread_local std::mt19937 rng(std::random_device{}());
  std::bernoulli_distribution coin(0.5);

  // Never index past the tree, and never build a [0, SIZE_MAX] range from num_nodes == 0.
  const size_t pickable = std::min(num_nodes, tree_.size());

  if (pickable == 0 || coin(rng)) {
    return sampleUniform();
  }

  std::uniform_int_distribution<size_t> pick(0, pickable - 1);
  const size_t i = pick(rng);
  std::normal_distribution<double> dx(0.0, r_commit*0.5);
  std::normal_distribution<double> dy(0.0, r_commit*0.5);

  geometry_msgs::msg::PoseStamped p = start_;
  p.pose.position.x = tree_[i].pose.pose.position.x + dx(rng);
  p.pose.position.y = tree_[i].pose.pose.position.y + dy(rng);
  // clamp to map bounds
  p.pose.position.x = std::min(std::max(p.pose.position.x, wx_min_), wx_max_);
  p.pose.position.y = std::min(std::max(p.pose.position.y, wy_min_), wy_max_);
  p.pose.orientation.w = 1.0;
  return p;
}

// =================== createPlan (Alg.4 main loop) ===================
/**
 * @brief Plan a path from @p start to @p goal (main MQ-RRT* loop, Alg. 4).
 *
 * Each of up to n_max_ iterations:
 *  1. picks the goal with probability S(N1) = 1 / (1 + exp(-(N1 - 5))), otherwise a
 *     "spare" sample (N1_ counts successful extensions, reset every 10th iteration);
 *  2. steers from the nearest node towards the sample by at most step_size_;
 *  3. skips the iteration if that segment is not collision-free;
 *  4. chooses the best parent among the neighbours within r_near_ and their ancestors;
 *  5. optionally inserts two assist nodes (removeTips + createNodes) before the new node;
 *  6. returns a densified path once the new node is within step_size_ of the goal;
 *  7. otherwise rewires the neighbours.
 *
 * @param start  Start pose; its header is used for the path and every pose in it.
 * @param goal   Goal pose.
 * @return Path from the start to the node that came within step_size_ of the goal
 *         (the goal pose itself is not appended), or an empty path on failure.
 */
nav_msgs::msg::Path MQRRT::createPlan(
    const geometry_msgs::msg::PoseStamped & start,
    const geometry_msgs::msg::PoseStamped & goal)
{
  nav_msgs::msg::Path path;
  path.header = start.header;
  start_ = start; goal_ = goal;

  // Refresh the sampling bounds: the values captured in configure() are only a
  // snapshot and go stale if the costmap is resized or its origin moves afterwards.
  {
    auto* cm = costmap_ros_->getCostmap();
    wx_min_ = cm->getOriginX();
    wy_min_ = cm->getOriginY();
    wx_max_ = wx_min_ + cm->getSizeInMetersX();
    wy_max_ = wy_min_ + cm->getSizeInMetersY();
  }

  // reset state
  tree_.clear(); cloud_.pts.clear(); kdtree_.reset();
  N1_ = 0;

  // init tree with start
  (void)addNode(start, -1);

  static thread_local std::mt19937 rng(std::random_device{}());
  std::uniform_real_distribution<double> U(0.0, 1.0);

  // Insert intermediate poses so consecutive poses are at most step_size_/2 apart.
  // Every original vertex is kept exactly, so the densified path follows the
  // collision-checked tree edges and ends on the last tree node.
  const auto densify = [this](nav_msgs::msg::Path& inout){
    if (inout.poses.size() < 2) return;
    const double step = step_size_ * 0.5;
    if (!(step > 0.0)) return;  // non-positive step_size_: leave the path as is
    nav_msgs::msg::Path out; out.header = inout.header;
    out.poses.push_back(inout.poses.front());
    for (size_t i = 1; i < inout.poses.size(); ++i) {
      // Copy the segment start: a reference into out.poses would dangle as soon
      // as push_back reallocates the vector.
      const geometry_msgs::msg::PoseStamped seg_start = out.poses.back();
      const auto& A = seg_start.pose.position;
      const auto& B = inout.poses[i].pose.position;
      const double dx = B.x - A.x, dy = B.y - A.y, d = std::hypot(dx, dy);
      if (d < 1e-6) continue;
      const int n = std::max(1, static_cast<int>(std::ceil(d / step)));
      for (int k = 1; k <= n; ++k) {
        auto p = seg_start;
        if (k == n) {
          // land exactly on the vertex, free of rounding error
          p.pose.position.x = B.x;
          p.pose.position.y = B.y;
        } else {
          const double t = static_cast<double>(k) / static_cast<double>(n);
          p.pose.position.x = A.x + dx * t;
          p.pose.position.y = A.y + dy * t;
        }
        p.header = out.header;
        p.pose.orientation.w = 1.0;
        out.poses.push_back(p);
      }
    }
    inout = std::move(out);
  };

  // Path length [m] and accumulated absolute heading change [deg].
  const auto stats = [](const nav_msgs::msg::Path& p){
    double L = 0.0; double turn = 0.0;
    bool has_prev = false; double prev_yaw = 0.0;
    for (size_t i = 1; i < p.poses.size(); ++i) {
      const auto& A = p.poses[i-1].pose.position;
      const auto& B = p.poses[i].pose.position;
      const double dx = B.x - A.x, dy = B.y - A.y, d = std::hypot(dx, dy);
      L += d;
      const double yaw = std::atan2(dy, dx);
      if (has_prev) {
        // wrap the heading difference into [-pi, pi]
        const double dth = std::remainder(yaw - prev_yaw, 2*M_PI);
        turn += std::abs(dth);
      } else {
        has_prev = true;
      }
      prev_yaw = yaw;
    }
    return std::make_pair(L, turn*180.0/M_PI);
  };

  // main loop
  for (int i=1; i<=n_max_; ++i) {

    // (1) dynamic goal bias via sigmoid S(N1) = 1 / (1 + exp(-(N1-5)))
    const double S = 1.0 / (1.0 + std::exp(-(static_cast<double>(N1_) - 5.0)));

    geometry_msgs::msg::PoseStamped x_rand;
    if (U(rng) < S) {
      x_rand = goal_;
    } else {
      x_rand = sampleSpare(i, r_commit_, tree_.size());
    }

    // every 10th iteration, reset the counter (as in the figure comment)
    if (i % 10 == 0) N1_ = 0;

    // (2) nearest
    double nd; const size_t x_nearest = nearest(x_rand.pose.position.x, x_rand.pose.position.y, nd);

    // (3) steer towards x_rand with step_size_ (optional)
    geometry_msgs::msg::PoseStamped x_new = x_rand;
    const double dx = x_rand.pose.position.x - tree_[x_nearest].pose.pose.position.x;
    const double dy = x_rand.pose.position.y - tree_[x_nearest].pose.pose.position.y;
    const double d  = std::hypot(dx,dy);
    if (d > step_size_ && d > 1e-9) {
      x_new = tree_[x_nearest].pose;
      x_new.pose.position.x += dx * (step_size_ / d);
      x_new.pose.position.y += dy * (step_size_ / d);
    }

    // (4) collision check
    if (!isCollisionFree(tree_[x_nearest].pose.pose.position.x, tree_[x_nearest].pose.pose.position.y,
                         x_new.pose.position.x, x_new.pose.position.y)) {
      continue;
    }
    ++N1_; // success streak

    // (5) X_near within r_near_
    std::vector<std::pair<size_t,double>> X_near;
    radiusSearch(x_new.pose.position.x, x_new.pose.position.y, r_near_, X_near);

    // (6) ancestry candidates from neighbors
    std::vector<size_t> X_parent;
    ancestryFromNeighbors(X_near, ancestry_depth_, X_parent);

    // (7) choose parent among {X_near ∪ X_parent}
    std::vector<size_t> candidates; candidates.reserve(X_near.size() + X_parent.size());
    for (const auto& m: X_near) candidates.push_back(m.first);
    candidates.insert(candidates.end(), X_parent.begin(), X_parent.end());
    const size_t x_parent = chooseBestParent(candidates, x_nearest, x_new);

    // (8) Remove-tips + CreateNodes to generate two assist nodes near obstacle
    auto [x_corner_A, x_corner_B] = removeTips(x_parent, x_new, d_dichotomy_);
    auto [x_create_A, x_create_B] = createNodes(x_parent, x_new, x_corner_A, x_corner_B, d_dichotomy_);

    size_t x_rand_idx = (size_t)-1;
    if (x_create_A.pose.orientation.w != 0.0 && x_create_B.pose.orientation.w != 0.0) {
      // add two assist nodes and connect:
      size_t a = addNode(x_create_A, (int)tree_[x_parent].parent); // edge (Parent(x_parent), x_create_A)
      size_t b = addNode(x_create_B, (int)a);                      // edge (x_create_A, x_create_B)
      x_rand_idx = addNode(x_new, (int)b);                         // edge (x_create_B, x_rand)
    } else {
      x_rand_idx = addNode(x_new, (int)x_parent);                  // edge (x_parent, x_rand)
    }

    // (9) goal reached?
    if (dist(x_new.pose.position.x, x_new.pose.position.y,
             goal_.pose.position.x, goal_.pose.position.y) <= step_size_) {
      // reconstruct path: walk parent links from the new node back to the root
      std::vector<size_t> chain;
      for (size_t cur = x_rand_idx; cur != (size_t)-1; ) {
        chain.push_back(cur);
        const int p = tree_[cur].parent;
        if (p < 0) break;
        cur = (size_t)p;
      }
      std::reverse(chain.begin(), chain.end());
      path.poses.clear();
      path.poses.reserve(chain.size());
      for (const size_t idx : chain) {
        auto p = tree_[idx].pose;
        p.pose.orientation.w = 1.0;
        p.header = start.header;
        path.poses.push_back(p);
      }

      densify(path);

      const auto [len,deg] = stats(path);
      RCLCPP_INFO(logger_, "[MQ-RRT*] path length=%.3f m, turn=%.1f deg, nodes=%zu",
                  len, deg, tree_.size());
      return path;
    }

    // (10) rewire using Q-RRT* style
    rewireQ(x_rand_idx, X_near, X_parent);
  }

  // failed: report it instead of silently returning an empty path
  RCLCPP_WARN(logger_, "[MQ-RRT*] no path found within %d iterations (nodes=%zu)",
              n_max_, tree_.size());
  return path;
}

} // namespace mqrrt

// Export mqrrt::MQRRT through pluginlib with nav2_core::GlobalPlanner as its base class.
#include <pluginlib/class_list_macros.hpp>
PLUGINLIB_EXPORT_CLASS(mqrrt::MQRRT, nav2_core::GlobalPlanner)

 

 

 

2. mqrrt.hpp

#ifndef MQRRT__MQRRT_HPP_
#define MQRRT__MQRRT_HPP_

#include <nav2_core/global_planner.hpp>
#include <geometry_msgs/msg/pose_stamped.hpp>
#include <nav_msgs/msg/path.hpp>
#include <rclcpp/rclcpp.hpp>

#include <tf2_ros/buffer.h>
#include <nav2_costmap_2d/costmap_2d_ros.hpp>

#include <nanoflann.hpp>
#include <vector>
#include <utility>
#include <limits>
#include <unordered_set>

namespace mqrrt
{

// ===================== KD-Tree backing store =====================
/// 2-D point container exposing the dataset-adaptor interface used by the KD-tree
/// (kdtree_get_point_count / kdtree_get_pt / kdtree_get_bbox).
struct PointCloud {
  /// A single planar point.
  struct Point { double x, y; };

  /// Stored points; index i is the point id handed to the KD-tree.
  std::vector<Point> pts;

  /// Number of stored points.
  inline size_t kdtree_get_point_count() const { return pts.size(); }

  /// Coordinate of point @p idx: x for dim == 0, y for any other dim.
  inline double kdtree_get_pt(const size_t idx, int dim) const {
    const Point& point = pts[idx];
    if (dim == 0) {
      return point.x;
    }
    return point.y;
  }

  /// No precomputed bounding box is provided; always returns false.
  template <class BBOX> bool kdtree_get_bbox(BBOX&) const { return false; }
};

// 2-dimensional nanoflann KD-tree over PointCloud, using the L2_Simple_Adaptor
// metric with double-precision coordinates.
using KDTree = nanoflann::KDTreeSingleIndexAdaptor<
    nanoflann::L2_Simple_Adaptor<double, PointCloud>,
    PointCloud, 2>;

// ============================ Node ===============================
/// One vertex of the search tree.
struct Node {
  /// Pose of this vertex.
  geometry_msgs::msg::PoseStamped pose;
  /// Index of the parent in the tree vector; -1 means "no parent" (root).
  int parent{-1};
  /// Accumulated path cost from the root to this vertex.
  double cost{0.0};
  /// Indices of the direct children in the tree vector.
  std::vector<int> children;
};

// ========================= Planner ==============================
class MQRRT : public nav2_core::GlobalPlanner
{
public:
  MQRRT() = default;
  ~MQRRT() override = default;

  void configure(
      const rclcpp_lifecycle::LifecycleNode::WeakPtr & parent,
      std::string name,
      std::shared_ptr<tf2_ros::Buffer> tf,
      std::shared_ptr<nav2_costmap_2d::Costmap2DROS> costmap_ros) override;

  void cleanup() override;
  void activate() override;
  void deactivate() override;

  nav_msgs::msg::Path createPlan(
      const geometry_msgs::msg::PoseStamped & start,
      const geometry_msgs::msg::PoseStamped & goal) override;

private:
  // ------------- Core utils -------------
  bool isCollisionFree(double x0, double y0, double x1, double y1) const;
  static double sqr(double v) { return v*v; }
  static double dist(double x0, double y0, double x1, double y1) {
    return std::hypot(x1-x0, y1-y0);
  }
  static geometry_msgs::msg::PoseStamped midpoint(
      const geometry_msgs::msg::PoseStamped& a,
      const geometry_msgs::msg::PoseStamped& b) {
    geometry_msgs::msg::PoseStamped m = a;
    m.pose.position.x = 0.5*(a.pose.position.x + b.pose.position.x);
    m.pose.position.y = 0.5*(a.pose.position.y + b.pose.position.y);
    return m;
  }

  // ------------- KD tree helpers -------------
  void rebuildKD();
  size_t nearest(double x, double y, double& out_dist) const;
  void radiusSearch(double x, double y, double radius,
                    std::vector<std::pair<size_t,double>>& out) const;

  // ------------- Tree ops -------------
  size_t addNode(const geometry_msgs::msg::PoseStamped& p, int parent_idx);
  void   reconnect(size_t new_parent, size_t child);
  void   propagateCost(size_t start, double delta);
  bool   isDescendant(size_t root, size_t q) const;
  double pathCost(size_t idx) const { return tree_[idx].cost; }

  // ------------- MQ-RRT* helpers -------------
  void ancestryFromNeighbors(
      const std::vector<std::pair<size_t,double>>& X_near,
      int ancestry_depth,
      std::vector<size_t>& out) const;

  size_t chooseBestParent(
      const std::vector<size_t>& candidates,
      size_t x_nearest,
      const geometry_msgs::msg::PoseStamped& x_rand) const;

  // Remove-tips (Alg. 6) & CreateNodes (Alg. 7)
  std::pair<geometry_msgs::msg::PoseStamped, geometry_msgs::msg::PoseStamped>
  removeTips(size_t x_parent_idx,
             const geometry_msgs::msg::PoseStamped& x_rand,
             double D_dichotomy) const;

  std::pair<geometry_msgs::msg::PoseStamped, geometry_msgs::msg::PoseStamped>
  createNodes(size_t x_parent_idx,
              const geometry_msgs::msg::PoseStamped& x_rand,
              const geometry_msgs::msg::PoseStamped& x_corner_A,
              const geometry_msgs::msg::PoseStamped& x_corner_B,
              double D_dichotomy) const;

  void rewireQ(
      size_t x_rand_idx,
      const std::vector<std::pair<size_t,double>>& X_near,
      const std::vector<size_t>& X_parent);

  // ------------- Sampling -------------
  geometry_msgs::msg::PoseStamped sampleUniform() const;
  geometry_msgs::msg::PoseStamped sampleSpare(
      int iter, double r_commit, size_t num_nodes) const;

  // ------------- Parameters/State -------------
  rclcpp_lifecycle::LifecycleNode::WeakPtr node_w_;
  rclcpp_lifecycle::LifecycleNode::SharedPtr node_;
  rclcpp::Logger logger_{rclcpp::get_logger("mqrrt")};
  std::shared_ptr<nav2_costmap_2d::Costmap2DROS> costmap_ros_;

  // map bounds in world coords
  double wx_min_{0.0}, wy_min_{0.0}, wx_max_{0.0}, wy_max_{0.0};

  // core params
  int    n_max_{4000};
  double step_size_{0.5};
  double r_near_{1.5};
  double r_commit_{2.0};
  double d_dichotomy_{0.10};
  int    ancestry_depth_{3}; // D_ancestor in the figure

  // goal-bias control
  int N1_{0}; // success counter for dynamic goal bias
  geometry_msgs::msg::PoseStamped start_, goal_;

  // data
  std::vector<Node> tree_;
  PointCloud cloud_;
  std::unique_ptr<KDTree> kdtree_;
};

} // namespace mqrrt

#endif  // MQRRT__MQRRT_HPP_