diff --git a/src/search/dag_classic/node.cc b/src/search/dag_classic/node.cc index 1ad4762f0e..6136707970 100644 --- a/src/search/dag_classic/node.cc +++ b/src/search/dag_classic/node.cc @@ -343,54 +343,60 @@ void Node::CancelScoreUpdate(uint32_t multivisit) { n_in_flight_.fetch_sub(multivisit, std::memory_order_acq_rel); } -void LowNode::FinalizeScoreUpdate(float v, float d, float m, - uint32_t multivisit) { +double LowNode::FinalizeScoreUpdate(double v, double d, float m, + uint32_t multivisit) { assert(edges_); + // Increment N. + n_ += multivisit; + // Recompute Q. - wl_ += multivisit * (v - wl_) / (n_ + multivisit); - d_ += multivisit * (d - d_) / (n_ + multivisit); - m_ += multivisit * (m - m_) / (n_ + multivisit); + double divisor = 1.0 / n_; + wl_ += multivisit * (v - wl_) * divisor; + d_ += multivisit * (d - d_) * divisor; + m_ += multivisit * (m - m_) * static_cast<float>(divisor); assert(WLDMInvariantsHold()); - - // Increment N. - n_ += multivisit; + return divisor; } -void LowNode::AdjustForTerminal(float v, float d, float m, +void LowNode::AdjustForTerminal(double v, double d, float m, double divisor, uint32_t multivisit) { assert(static_cast<uint32_t>(multivisit) <= n_); // Recompute Q. - wl_ += multivisit * v / n_; - d_ += multivisit * d / n_; - m_ += multivisit * m / n_; + wl_ += multivisit * v * divisor; + d_ += multivisit * d * divisor; + m_ += multivisit * m * static_cast<float>(divisor); assert(WLDMInvariantsHold()); } -void Node::FinalizeScoreUpdate(float v, float d, float m, uint32_t multivisit) { +double Node::FinalizeScoreUpdate(double v, double d, float m, + uint32_t multivisit) { + // Increment N. + n_ += multivisit; + // Recompute Q.
- wl_ += multivisit * (v - wl_) / (n_ + multivisit); - d_ += multivisit * (d - d_) / (n_ + multivisit); - m_ += multivisit * (m - m_) / (n_ + multivisit); + double divisor = 1.0 / n_; + wl_ += multivisit * (v - wl_) * divisor; + d_ += multivisit * (d - d_) * divisor; + m_ += multivisit * (m - m_) * static_cast<float>(divisor); assert(WLDMInvariantsHold()); - - // Increment N. - n_ += multivisit; // Decrement virtual loss. assert(GetNInFlight() >= (uint32_t)multivisit); n_in_flight_.fetch_sub(multivisit, std::memory_order_acq_rel); + return divisor; } -void Node::AdjustForTerminal(float v, float d, float m, uint32_t multivisit) { +void Node::AdjustForTerminal(double v, double d, float m, double divisor, + uint32_t multivisit) { assert(static_cast<uint32_t>(multivisit) <= n_); // Recompute Q. - wl_ += multivisit * v / n_; - d_ += multivisit * d / n_; - m_ += multivisit * m / n_; + wl_ += multivisit * v * divisor; + d_ += multivisit * d * divisor; + m_ += multivisit * m * static_cast<float>(divisor); assert(WLDMInvariantsHold()); } @@ -610,10 +616,10 @@ void Node::SortEdges() const { low_node_->SortEdges(); } -static constexpr float wld_tolerance = 0.000001f; +static constexpr double wld_tolerance = 0.000001f; static constexpr float m_tolerance = 0.000001f; -static bool WLDMInvariantsHold(float wl, float d, float m) { +static bool WLDMInvariantsHold(double wl, double d, float m) { return -(1.0f + wld_tolerance) < wl && wl < (1.0f + wld_tolerance) && // -(0.0f + wld_tolerance) < d && d < (1.0f + wld_tolerance) && // -(0.0f + m_tolerance) < m && // diff --git a/src/search/dag_classic/node.h b/src/search/dag_classic/node.h index b74a64a9d4..dac151f00d 100644 --- a/src/search/dag_classic/node.h +++ b/src/search/dag_classic/node.h @@ -207,8 +207,8 @@ class Edge { }; struct Eval { - float wl; - float d; + double wl; + double d; float ml; }; @@ -301,11 +301,11 @@ class Node { // Returns n + n_in_flight.
int GetNStarted() const { return n_ + GetNInFlight(); } - float GetQ(float draw_score) const { return wl_ + draw_score * d_; } + double GetQ(double draw_score) const { return wl_ + draw_score * d_; } // Returns node eval, i.e. average subtree V for non-terminal node and -1/0/1 // for terminal nodes. - float GetWL() const { return wl_; } - float GetD() const { return d_; } + double GetWL() const { return wl_; } + double GetD() const { return d_; } float GetM() const { return m_; } // Returns whether the node is known to be draw/lose/win. @@ -335,9 +335,9 @@ class Node { // * Q (weighted average of all V in a subtree) // * N (+=multivisit) // * N-in-flight (-=multivisit) - void FinalizeScoreUpdate(float v, float d, float m, uint32_t multivisit); + double FinalizeScoreUpdate(double v, double d, float m, uint32_t multivisit); // Like FinalizeScoreUpdate, but it updates n existing visits by delta amount. - void AdjustForTerminal(float v, float d, float m, uint32_t multivisit); + void AdjustForTerminal(double v, double d, float m, double divisor, uint32_t multivisit); // When search decides to treat one visit as several (in case of collisions // or visiting terminal nodes several times), it amplifies the visit by // incrementing n_in_flight. @@ -546,8 +546,8 @@ class LowNode { // Returns node eval, i.e. average subtree V for non-terminal node and -1/0/1 // for terminal nodes. - float GetWL() const { return wl_; } - float GetD() const { return d_; } + double GetWL() const { return wl_; } + double GetD() const { return d_; } float GetM() const { return m_; } // Returns whether the node is known to be draw/loss/win. 
@@ -574,9 +574,9 @@ class LowNode { // * Q (weighted average of all V in a subtree) // * N (+=multivisit) // * N-in-flight (-=multivisit) - void FinalizeScoreUpdate(float v, float d, float m, uint32_t multivisit); + double FinalizeScoreUpdate(double v, double d, float m, uint32_t multivisit); // Like FinalizeScoreUpdate, but it updates n existing visits by delta amount. - void AdjustForTerminal(float v, float d, float m, uint32_t multivisit); + void AdjustForTerminal(double v, double d, float m, double divisor, uint32_t multivisit); // Deletes all children. void ReleaseChildren(); @@ -691,13 +691,13 @@ class EdgeAndNode { Node* node() const { return node_; } // Proxy functions for easier access to node/edge. - float GetQ(float default_q, float draw_score) const { + double GetQ(double default_q, double draw_score) const { return (node_ && node_->GetN() > 0) ? node_->GetQ(draw_score) : default_q; } - float GetWL(float default_wl) const { + double GetWL(double default_wl) const { return (node_ && node_->GetN() > 0) ? node_->GetWL() : default_wl; } - float GetD(float default_d) const { + double GetD(double default_d) const { return (node_ && node_->GetN() > 0) ? node_->GetD() : default_d; } float GetM(float default_m) const { diff --git a/src/search/dag_classic/search.cc b/src/search/dag_classic/search.cc index 6861371dc6..d40ef260f5 100644 --- a/src/search/dag_classic/search.cc +++ b/src/search/dag_classic/search.cc @@ -271,7 +271,8 @@ void ApplyDirichletNoise(LowNode* node, float eps, double alpha) { namespace { // WDL conversion formula based on random walk model. -inline double WDLRescale(float& v, float& d, float wdl_rescale_ratio, +template <typename T> +inline double WDLRescale(T& v, T& d, float wdl_rescale_ratio, float wdl_rescale_diff, float sign, bool invert, float max_reasonable_s) { if (invert) { @@ -476,7 +477,7 @@ void Search::RecordNPSStartTime() { } // Root is depth 0, i.e. even depth.
-float Search::GetDrawScore(bool is_odd_depth) const { +double Search::GetDrawScore(bool is_odd_depth) const { return (is_odd_depth == played_history_.IsBlackToMove() ? params_.GetDrawScore() : -params_.GetDrawScore()); @@ -2181,8 +2182,8 @@ void SearchWorker::DoBackupUpdate() { } bool SearchWorker::MaybeAdjustForTerminalOrTransposition( - Node* n, const std::shared_ptr<LowNode>& nl, float& v, float& d, float& m, - uint32_t& n_to_fix, float& v_delta, float& d_delta, float& m_delta, + Node* n, const std::shared_ptr<LowNode>& nl, double& v, double& d, float& m, + uint32_t& n_to_fix, double& v_delta, double& d_delta, float& m_delta, bool& update_parent_bounds) const { if (n->IsTerminal()) { v = n->GetWL(); @@ -2193,7 +2194,7 @@ bool SearchWorker::MaybeAdjustForTerminalOrTransposition( } // Use information from transposition or a new terminal. - if (nl->IsTransposition() || nl->IsTerminal() || n->GetN() < nl->GetN()) { + if (nl->IsTransposition() || nl->IsTerminal() || n->GetN() + 1 < nl->GetN()) { // Adapt information from low node to node by flipping Q sign, bounds, // result and incrementing m. v = -nl->GetWL(); @@ -2269,19 +2270,18 @@ void SearchWorker::DoBackupUpdateSingleNode( auto update_parent_bounds = params_.GetStickyEndgames() && n->IsTerminal() && !n->GetN(); const auto& nl = n->GetLowNode(); - float v = 0.0f; - float d = 0.0f; + double v = 0.0; + double d = 0.0; float m = 0.0f; uint32_t n_to_fix = 0; - float v_delta = 0.0f; - float d_delta = 0.0f; + double v_delta = 0.0; + double d_delta = 0.0; float m_delta = 0.0f; // Update the low node at the start of the backup path first, but only visit // it the first time that backup sees it. if (nl && nl->GetN() == 0) { - nl->FinalizeScoreUpdate(nl->GetWL(), nl->GetD(), nl->GetM(), - node_to_process.multivisit); + nl->FinalizeScoreUpdate(nl->GetWL(), nl->GetD(), nl->GetM(), 1); } if (nr >= 2) { @@ -2303,9 +2303,9 @@ void SearchWorker::DoBackupUpdateSingleNode( // Backup V value up to a root. After 1 visit, V = Q.
for (auto it = path.crbegin(); it != path.crend(); /* ++it in the body */) { - n->FinalizeScoreUpdate(v, d, m, node_to_process.multivisit); + auto divisor = n->FinalizeScoreUpdate(v, d, m, 1); if (n_to_fix > 0 && !n->IsTerminal()) { - n->AdjustForTerminal(v_delta, d_delta, m_delta, n_to_fix); + n->AdjustForTerminal(v_delta, d_delta, m_delta, divisor, n_to_fix); } // Stop delta update on repetition "terminal" and propagate a draw above @@ -2335,9 +2335,9 @@ void SearchWorker::DoBackupUpdateSingleNode( m = pl->GetM(); n_to_fix = 0; } - pl->FinalizeScoreUpdate(v, d, m, node_to_process.multivisit); + divisor = pl->FinalizeScoreUpdate(v, d, m, 1); if (n_to_fix > 0) { - pl->AdjustForTerminal(v_delta, d_delta, m_delta, n_to_fix); + pl->AdjustForTerminal(v_delta, d_delta, m_delta, divisor, n_to_fix); } bool old_update_parent_bounds = update_parent_bounds; @@ -2375,18 +2375,17 @@ void SearchWorker::DoBackupUpdateSingleNode( nr = pr; nm = pm; } - search_->total_playouts_ += node_to_process.multivisit; + search_->total_playouts_ += 1; if (node_to_process.nn_queried && !node_to_process.is_cache_hit) { search_->network_evaluations_++; } - search_->cum_depth_ += - node_to_process.path.size() * node_to_process.multivisit; + search_->cum_depth_ += node_to_process.path.size(); search_->max_depth_ = std::max(search_->max_depth_, (uint16_t)node_to_process.path.size()); } bool SearchWorker::MaybeSetBounds(Node* p, float m, uint32_t* n_to_fix, - float* v_delta, float* d_delta, + double* v_delta, double* d_delta, float* m_delta) const { auto losing_m = 0.0f; auto prefer_tb = false; diff --git a/src/search/dag_classic/search.h b/src/search/dag_classic/search.h index a60fb7e705..b179a21ad2 100644 --- a/src/search/dag_classic/search.h +++ b/src/search/dag_classic/search.h @@ -143,7 +143,7 @@ class Search { // Returns the draw score at the root of the search. At odd depth pass true to // the value of @is_odd_depth to change the sign of the draw score. 
// Depth of a root node is 0 (even number). - float GetDrawScore(bool is_odd_depth) const; + double GetDrawScore(bool is_odd_depth) const; mutable Mutex counters_mutex_ ACQUIRED_AFTER(nodes_mutex_); // Tells all threads to stop. @@ -448,14 +448,14 @@ class SearchWorker { // Return true if adjustment happened. bool MaybeAdjustForTerminalOrTransposition(Node* n, const std::shared_ptr<LowNode>& nl, - float& v, float& d, float& m, - uint32_t& n_to_fix, float& v_delta, - float& d_delta, float& m_delta, + double& v, double& d, float& m, + uint32_t& n_to_fix, double& v_delta, + double& d_delta, float& m_delta, bool& update_parent_bounds) const; void DoBackupUpdateSingleNode(const NodeToProcess& node_to_process); // Returns whether a node's bounds were set based on its children. - bool MaybeSetBounds(Node* p, float m, uint32_t* n_to_fix, float* v_delta, - float* d_delta, float* m_delta) const; + bool MaybeSetBounds(Node* p, float m, uint32_t* n_to_fix, double* v_delta, + double* d_delta, float* m_delta) const; void PickNodesToExtend(int collision_limit); void PickNodesToExtendTask(const BackupPath& path, int collision_limit, PositionHistory& history,