Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 31 additions & 25 deletions src/search/dag_classic/node.cc
Original file line number Diff line number Diff line change
Expand Up @@ -343,54 +343,60 @@ void Node::CancelScoreUpdate(uint32_t multivisit) {
n_in_flight_.fetch_sub(multivisit, std::memory_order_acq_rel);
}

void LowNode::FinalizeScoreUpdate(float v, float d, float m,
uint32_t multivisit) {
double LowNode::FinalizeScoreUpdate(double v, double d, float m,
uint32_t multivisit) {
assert(edges_);
// Increment N.
n_ += multivisit;

// Recompute Q.
wl_ += multivisit * (v - wl_) / (n_ + multivisit);
d_ += multivisit * (d - d_) / (n_ + multivisit);
m_ += multivisit * (m - m_) / (n_ + multivisit);
double divisor = 1.0 / n_;
wl_ += multivisit * (v - wl_) * divisor;
d_ += multivisit * (d - d_) * divisor;
m_ += multivisit * (m - m_) * static_cast<float>(divisor);

assert(WLDMInvariantsHold());

// Increment N.
n_ += multivisit;
return divisor;
}

void LowNode::AdjustForTerminal(float v, float d, float m,
void LowNode::AdjustForTerminal(double v, double d, float m, double divisor,
uint32_t multivisit) {
assert(static_cast<uint32_t>(multivisit) <= n_);

// Recompute Q.
wl_ += multivisit * v / n_;
d_ += multivisit * d / n_;
m_ += multivisit * m / n_;
wl_ += multivisit * v * divisor;
d_ += multivisit * d * divisor;
m_ += multivisit * m * static_cast<float>(divisor);

assert(WLDMInvariantsHold());
}

void Node::FinalizeScoreUpdate(float v, float d, float m, uint32_t multivisit) {
double Node::FinalizeScoreUpdate(double v, double d, float m,
uint32_t multivisit) {
// Increment N.
n_ += multivisit;

// Recompute Q.
wl_ += multivisit * (v - wl_) / (n_ + multivisit);
d_ += multivisit * (d - d_) / (n_ + multivisit);
m_ += multivisit * (m - m_) / (n_ + multivisit);
double divisor = 1.0 / n_;
wl_ += multivisit * (v - wl_) * divisor;
d_ += multivisit * (d - d_) * divisor;
m_ += multivisit * (m - m_) * static_cast<float>(divisor);

assert(WLDMInvariantsHold());

// Increment N.
n_ += multivisit;
// Decrement virtual loss.
assert(GetNInFlight() >= (uint32_t)multivisit);
n_in_flight_.fetch_sub(multivisit, std::memory_order_acq_rel);
return divisor;
}

void Node::AdjustForTerminal(float v, float d, float m, uint32_t multivisit) {
void Node::AdjustForTerminal(double v, double d, float m, double divisor,
uint32_t multivisit) {
assert(static_cast<uint32_t>(multivisit) <= n_);

// Recompute Q.
wl_ += multivisit * v / n_;
d_ += multivisit * d / n_;
m_ += multivisit * m / n_;
wl_ += multivisit * v * divisor;
d_ += multivisit * d * divisor;
m_ += multivisit * m * static_cast<float>(divisor);

assert(WLDMInvariantsHold());
}
Expand Down Expand Up @@ -610,10 +616,10 @@ void Node::SortEdges() const {
low_node_->SortEdges();
}

static constexpr float wld_tolerance = 0.000001f;
static constexpr double wld_tolerance = 0.000001f;
Copy link

Copilot AI Feb 13, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The variable wld_tolerance is declared as double but initialized with a float literal suffix f. For consistency with the type declaration, use a double literal without the f suffix: 0.000001 instead of 0.000001f.

Suggested change
static constexpr double wld_tolerance = 0.000001f;
static constexpr double wld_tolerance = 0.000001;

Copilot uses AI. Check for mistakes.
static constexpr float m_tolerance = 0.000001f;

static bool WLDMInvariantsHold(float wl, float d, float m) {
static bool WLDMInvariantsHold(double wl, double d, float m) {
return -(1.0f + wld_tolerance) < wl && wl < (1.0f + wld_tolerance) && //
-(0.0f + wld_tolerance) < d && d < (1.0f + wld_tolerance) && //
-(0.0f + m_tolerance) < m && //
Expand Down
28 changes: 14 additions & 14 deletions src/search/dag_classic/node.h
Original file line number Diff line number Diff line change
Expand Up @@ -207,8 +207,8 @@ class Edge {
};

struct Eval {
float wl;
float d;
double wl;
double d;
float ml;
};

Expand Down Expand Up @@ -301,11 +301,11 @@ class Node {
// Returns n + n_in_flight.
int GetNStarted() const { return n_ + GetNInFlight(); }

float GetQ(float draw_score) const { return wl_ + draw_score * d_; }
double GetQ(double draw_score) const { return wl_ + draw_score * d_; }
// Returns node eval, i.e. average subtree V for non-terminal node and -1/0/1
// for terminal nodes.
float GetWL() const { return wl_; }
float GetD() const { return d_; }
double GetWL() const { return wl_; }
double GetD() const { return d_; }
float GetM() const { return m_; }

// Returns whether the node is known to be draw/lose/win.
Expand Down Expand Up @@ -335,9 +335,9 @@ class Node {
// * Q (weighted average of all V in a subtree)
// * N (+=multivisit)
// * N-in-flight (-=multivisit)
void FinalizeScoreUpdate(float v, float d, float m, uint32_t multivisit);
double FinalizeScoreUpdate(double v, double d, float m, uint32_t multivisit);
// Like FinalizeScoreUpdate, but it updates n existing visits by delta amount.
void AdjustForTerminal(float v, float d, float m, uint32_t multivisit);
void AdjustForTerminal(double v, double d, float m, double divisor, uint32_t multivisit);
// When search decides to treat one visit as several (in case of collisions
// or visiting terminal nodes several times), it amplifies the visit by
// incrementing n_in_flight.
Expand Down Expand Up @@ -546,8 +546,8 @@ class LowNode {

// Returns node eval, i.e. average subtree V for non-terminal node and -1/0/1
// for terminal nodes.
float GetWL() const { return wl_; }
float GetD() const { return d_; }
double GetWL() const { return wl_; }
double GetD() const { return d_; }
float GetM() const { return m_; }

// Returns whether the node is known to be draw/loss/win.
Expand All @@ -574,9 +574,9 @@ class LowNode {
// * Q (weighted average of all V in a subtree)
// * N (+=multivisit)
// * N-in-flight (-=multivisit)
void FinalizeScoreUpdate(float v, float d, float m, uint32_t multivisit);
double FinalizeScoreUpdate(double v, double d, float m, uint32_t multivisit);
// Like FinalizeScoreUpdate, but it updates n existing visits by delta amount.
void AdjustForTerminal(float v, float d, float m, uint32_t multivisit);
void AdjustForTerminal(double v, double d, float m, double divisor, uint32_t multivisit);

// Deletes all children.
void ReleaseChildren();
Expand Down Expand Up @@ -691,13 +691,13 @@ class EdgeAndNode {
Node* node() const { return node_; }

// Proxy functions for easier access to node/edge.
float GetQ(float default_q, float draw_score) const {
double GetQ(double default_q, double draw_score) const {
return (node_ && node_->GetN() > 0) ? node_->GetQ(draw_score) : default_q;
}
float GetWL(float default_wl) const {
double GetWL(double default_wl) const {
return (node_ && node_->GetN() > 0) ? node_->GetWL() : default_wl;
}
float GetD(float default_d) const {
double GetD(double default_d) const {
return (node_ && node_->GetN() > 0) ? node_->GetD() : default_d;
}
float GetM(float default_m) const {
Expand Down
37 changes: 18 additions & 19 deletions src/search/dag_classic/search.cc
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,8 @@ void ApplyDirichletNoise(LowNode* node, float eps, double alpha) {

namespace {
// WDL conversion formula based on random walk model.
inline double WDLRescale(float& v, float& d, float wdl_rescale_ratio,
template <typename T>
inline double WDLRescale(T& v, T& d, float wdl_rescale_ratio,
float wdl_rescale_diff, float sign, bool invert,
float max_reasonable_s) {
if (invert) {
Expand Down Expand Up @@ -476,7 +477,7 @@ void Search::RecordNPSStartTime() {
}

// Root is depth 0, i.e. even depth.
float Search::GetDrawScore(bool is_odd_depth) const {
double Search::GetDrawScore(bool is_odd_depth) const {
return (is_odd_depth == played_history_.IsBlackToMove()
? params_.GetDrawScore()
: -params_.GetDrawScore());
Expand Down Expand Up @@ -2181,8 +2182,8 @@ void SearchWorker::DoBackupUpdate() {
}

bool SearchWorker::MaybeAdjustForTerminalOrTransposition(
Node* n, const std::shared_ptr<LowNode>& nl, float& v, float& d, float& m,
uint32_t& n_to_fix, float& v_delta, float& d_delta, float& m_delta,
Node* n, const std::shared_ptr<LowNode>& nl, double& v, double& d, float& m,
uint32_t& n_to_fix, double& v_delta, double& d_delta, float& m_delta,
bool& update_parent_bounds) const {
if (n->IsTerminal()) {
v = n->GetWL();
Expand All @@ -2193,7 +2194,7 @@ bool SearchWorker::MaybeAdjustForTerminalOrTransposition(
}

// Use information from transposition or a new terminal.
if (nl->IsTransposition() || nl->IsTerminal() || n->GetN() < nl->GetN()) {
if (nl->IsTransposition() || nl->IsTerminal() || n->GetN() + 1 < nl->GetN()) {
// Adapt information from low node to node by flipping Q sign, bounds,
// result and incrementing m.
v = -nl->GetWL();
Expand Down Expand Up @@ -2269,19 +2270,18 @@ void SearchWorker::DoBackupUpdateSingleNode(
auto update_parent_bounds =
params_.GetStickyEndgames() && n->IsTerminal() && !n->GetN();
const auto& nl = n->GetLowNode();
float v = 0.0f;
float d = 0.0f;
double v = 0.0;
double d = 0.0;
float m = 0.0f;
uint32_t n_to_fix = 0;
float v_delta = 0.0f;
float d_delta = 0.0f;
double v_delta = 0.0;
double d_delta = 0.0;
float m_delta = 0.0f;

// Update the low node at the start of the backup path first, but only visit
// it the first time that backup sees it.
if (nl && nl->GetN() == 0) {
nl->FinalizeScoreUpdate(nl->GetWL(), nl->GetD(), nl->GetM(),
node_to_process.multivisit);
nl->FinalizeScoreUpdate(nl->GetWL(), nl->GetD(), nl->GetM(), 1);
}

if (nr >= 2) {
Expand All @@ -2303,9 +2303,9 @@ void SearchWorker::DoBackupUpdateSingleNode(
// Backup V value up to a root. After 1 visit, V = Q.
for (auto it = path.crbegin(); it != path.crend();
/* ++it in the body */) {
n->FinalizeScoreUpdate(v, d, m, node_to_process.multivisit);
auto divisor = n->FinalizeScoreUpdate(v, d, m, 1);
if (n_to_fix > 0 && !n->IsTerminal()) {
n->AdjustForTerminal(v_delta, d_delta, m_delta, n_to_fix);
n->AdjustForTerminal(v_delta, d_delta, m_delta, divisor, n_to_fix);
}

// Stop delta update on repetition "terminal" and propagate a draw above
Expand Down Expand Up @@ -2335,9 +2335,9 @@ void SearchWorker::DoBackupUpdateSingleNode(
m = pl->GetM();
n_to_fix = 0;
}
pl->FinalizeScoreUpdate(v, d, m, node_to_process.multivisit);
divisor = pl->FinalizeScoreUpdate(v, d, m, 1);
if (n_to_fix > 0) {
pl->AdjustForTerminal(v_delta, d_delta, m_delta, n_to_fix);
pl->AdjustForTerminal(v_delta, d_delta, m_delta, divisor, n_to_fix);
}

bool old_update_parent_bounds = update_parent_bounds;
Expand Down Expand Up @@ -2375,18 +2375,17 @@ void SearchWorker::DoBackupUpdateSingleNode(
nr = pr;
nm = pm;
}
search_->total_playouts_ += node_to_process.multivisit;
search_->total_playouts_ += 1;
if (node_to_process.nn_queried && !node_to_process.is_cache_hit) {
search_->network_evaluations_++;
}
search_->cum_depth_ +=
node_to_process.path.size() * node_to_process.multivisit;
search_->cum_depth_ += node_to_process.path.size();
search_->max_depth_ =
std::max(search_->max_depth_, (uint16_t)node_to_process.path.size());
}

bool SearchWorker::MaybeSetBounds(Node* p, float m, uint32_t* n_to_fix,
float* v_delta, float* d_delta,
double* v_delta, double* d_delta,
float* m_delta) const {
auto losing_m = 0.0f;
auto prefer_tb = false;
Expand Down
12 changes: 6 additions & 6 deletions src/search/dag_classic/search.h
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ class Search {
// Returns the draw score at the root of the search. At odd depth pass true to
// the value of @is_odd_depth to change the sign of the draw score.
// Depth of a root node is 0 (even number).
float GetDrawScore(bool is_odd_depth) const;
double GetDrawScore(bool is_odd_depth) const;

mutable Mutex counters_mutex_ ACQUIRED_AFTER(nodes_mutex_);
// Tells all threads to stop.
Expand Down Expand Up @@ -448,14 +448,14 @@ class SearchWorker {
// Return true if adjustment happened.
bool MaybeAdjustForTerminalOrTransposition(Node* n,
const std::shared_ptr<LowNode>& nl,
float& v, float& d, float& m,
uint32_t& n_to_fix, float& v_delta,
float& d_delta, float& m_delta,
double& v, double& d, float& m,
uint32_t& n_to_fix, double& v_delta,
double& d_delta, float& m_delta,
bool& update_parent_bounds) const;
void DoBackupUpdateSingleNode(const NodeToProcess& node_to_process);
// Returns whether a node's bounds were set based on its children.
bool MaybeSetBounds(Node* p, float m, uint32_t* n_to_fix, float* v_delta,
float* d_delta, float* m_delta) const;
bool MaybeSetBounds(Node* p, float m, uint32_t* n_to_fix, double* v_delta,
double* d_delta, float* m_delta) const;
void PickNodesToExtend(int collision_limit);
void PickNodesToExtendTask(const BackupPath& path, int collision_limit,
PositionHistory& history,
Expand Down