Skip to content

Commit 67a7b01

Browse files
RobBuchananCompPhysrprospero
authored andcommitted
fix: Iterable graph bug on Graph::process empty/non-empty edges (#2313)
1 parent 5da9c0a commit 67a7b01

File tree

7 files changed

+156
-2
lines changed

7 files changed

+156
-2
lines changed

src/nodes/graph.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ NodeConstants::ProcessResult Graph::process()
4646
// Check each node for output edges - any that have zero output edges need to be run()
4747
auto terminalNodeResult = NodeConstants::ProcessResult::Unchanged;
4848
for (auto &&[nodeName, node] : nodes_)
49-
if (!node->outputEdges().empty())
49+
if (node->outputEdges().empty())
5050
{
5151
switch (node->run())
5252
{

src/nodes/iterableGraph.cpp

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,8 @@ bool IterableGraph::addEdge(const EdgeDefinition &definition)
6666

6767
loopEdges_.emplace_back(LoopEdge::makeLoopEdge(edge.release(), proxyInputs()));
6868

69+
addOutputLoopEdge(definition.sourceOutput, loopEdges_.back().get());
70+
6971
return true;
7072
}
7173

@@ -82,6 +84,8 @@ bool IterableGraph::removeEdge(const EdgeDefinition &definition)
8284
releaseLoopBack(definition.targetInput);
8385
else
8486
return false;
87+
88+
removeOutputLoopEdge(definition.sourceOutput, static_cast<Edge *>(loopEdge));
8589
}
8690
return true;
8791
}
@@ -114,6 +118,42 @@ LoopEdge *IterableGraph::findLoopEdge(const EdgeDefinition &definition) const
114118
return {};
115119
}
116120

121+
// Add edge to node map
122+
Edge *IterableGraph::addOutputLoopEdge(std::string_view sourceOutput, Edge *edge)
123+
{
124+
auto &outputEdgeMap = loopBacks()->outputEdges();
125+
auto outputEdges = outputEdgeMap.find(sourceOutput);
126+
127+
// If source not in edge map, insert it, otherwise push new edge to current source node
128+
if (outputEdges != outputEdgeMap.end())
129+
{
130+
outputEdges->second.push_back(edge);
131+
return edge;
132+
}
133+
else
134+
{
135+
outputEdgeMap.insert({sourceOutput, {edge}});
136+
return edge;
137+
}
138+
139+
return nullptr;
140+
}
141+
142+
// Remove edge from node map
143+
Edge *IterableGraph::removeOutputLoopEdge(std::string_view sourceOutput, Edge *edge)
144+
{
145+
auto outputEdgeMap = loopBacks()->outputEdges();
146+
auto outputEdges = outputEdgeMap.find(sourceOutput);
147+
if (outputEdges != outputEdgeMap.end())
148+
{
149+
auto removedEdge = std::remove(outputEdges->second.begin(), outputEdges->second.end(), edge);
150+
outputEdges->second.erase(removedEdge, outputEdges->second.end());
151+
return edge;
152+
}
153+
154+
return nullptr;
155+
}
156+
117157
/*
118158
* Processing & Validity
119159
*/

src/nodes/iterableGraph.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,12 @@ class IterableGraph : public Graph
6464
// Find loop edge between nodes
6565
LoopEdge *findLoopEdge(const EdgeDefinition &definition) const;
6666

67+
private:
68+
// Add edge to node map
69+
Edge *addOutputLoopEdge(std::string_view sourceOutput, Edge *edge);
70+
// Remove edge from node map
71+
Edge *removeOutputLoopEdge(std::string_view sourceOutput, Edge *edge);
72+
6773
/*
6874
* Processing & Validity
6975
*/

src/nodes/loopBack.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,9 @@ NodeConstants::ProcessResult LoopBacksNode::run()
4646
return status;
4747
}
4848

49+
// Get the outgoing edges from this node
50+
Node::EdgeMap &LoopBacksNode::outputEdges() { return loopEdges_; }
51+
4952
// Flag that the node data needs to be updated
5053
void LoopBacksNode::setUpdateRequired()
5154
{

src/nodes/loopBack.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,12 @@ class LoopBacksNode : public Node
3232
// Run the node, retrieving dependent inputs as necessary
3333
NodeConstants::ProcessResult run() override;
3434

35+
// Get the outgoing edges from this node
36+
Node::EdgeMap &outputEdges() override;
37+
38+
private:
39+
Node::EdgeMap loopEdges_;
40+
3541
public:
3642
// Flag that the node data needs to be updated
3743
void setUpdateRequired() override;

src/nodes/node.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -292,7 +292,7 @@ class Node : public Serialisable<>
292292
// Get the incoming edges to this node
293293
EdgeMap &inputEdges();
294294
// Get the outgoing edges from this node
295-
EdgeMap &outputEdges();
295+
virtual EdgeMap &outputEdges();
296296
// Mark incoming edges to the specified parameter as needing a re-pull
297297
void markIncomingEdgesForPull(const ParameterBase *toParameter) const;
298298
// Returns the node parent graph

tests/nodes/loop.cpp

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,105 @@ class IterableGraphTest : public ::testing::Test
7373
IterableGraph *loop_{nullptr};
7474
};
7575

76+
TEST_F(IterableGraphTest, BasicNonLoopingSeries)
77+
{
78+
CoreData coreData;
79+
Dissolve dissolve(coreData);
80+
auto root = std::make_unique<DissolveGraph>(dissolve);
81+
auto loop = dynamic_cast<IterableGraph *>(root->createNode("Iterator", "Iterator"));
82+
auto i = dynamic_cast<NumberNode *>(root->createNode("Number", "i"));
83+
auto a = dynamic_cast<AddNode *>(loop->createNode("Add", "a"));
84+
auto b = dynamic_cast<AddNode *>(loop->createNode("Add", "b"));
85+
auto c = dynamic_cast<AddNode *>(loop->createNode("Add", "c"));
86+
ASSERT_TRUE(loop->setOption<Number>("N", 1));
87+
EXPECT_TRUE(root->addEdge({"i", "X", "Iterator", "I"}));
88+
89+
/*
90+
* i, 1 -> itA, 1 + 1 = 2
91+
*/
92+
93+
// This should actually result in the sole internal node not being run
94+
i->setOption<Number>("X", 1);
95+
a->setInput<Number>("Y", 1);
96+
b->setInput<Number>("Y", 1);
97+
c->setInput<Number>("Y", 1);
98+
EXPECT_TRUE(loop->addEdge({"Inputs", "I", "a", "X"}));
99+
EXPECT_EQ(loop->run(), NodeConstants::ProcessResult::Success);
100+
auto res1 = a->getOutputValue<Number>("Result").asInteger();
101+
EXPECT_TRUE(a->versionIndex() == 0);
102+
EXPECT_TRUE(loop->versionIndex() == 0);
103+
ASSERT_EQ(res1, 2);
104+
105+
// No loopbacks occur
106+
EXPECT_TRUE(loop->loopBacks()->versionIndex() == -1);
107+
108+
/*
109+
* i, 1 -> itA, 1 + 1 = 2 -> itB, 1 + 2 = 3
110+
*/
111+
EXPECT_TRUE(loop->addEdge({"a", "Result", "b", "X"}));
112+
ASSERT_TRUE(loop->setOption<Number>("N", 1));
113+
EXPECT_EQ(loop->run(), NodeConstants::ProcessResult::Success);
114+
auto res2 = b->getOutputValue<Number>("Result").asInteger();
115+
EXPECT_TRUE(a->versionIndex() == 1);
116+
EXPECT_TRUE(b->versionIndex() == 0);
117+
EXPECT_TRUE(loop->versionIndex() == 1);
118+
ASSERT_EQ(res2, 3);
119+
120+
// No loopbacks occur
121+
EXPECT_TRUE(loop->loopBacks()->versionIndex() == -1);
122+
123+
/*
124+
* i, 1 -> itA, 1 + 1 = 2 -> itB, 1 + 2 = 3 -> itC, 1 + 3 = 4
125+
*/
126+
EXPECT_TRUE(loop->addEdge({"b", "Result", "c", "X"}));
127+
ASSERT_TRUE(loop->setOption<Number>("N", 1));
128+
EXPECT_EQ(loop->run(), NodeConstants::ProcessResult::Success);
129+
auto res3 = c->getOutputValue<Number>("Result").asInteger();
130+
EXPECT_TRUE(a->versionIndex() == 2);
131+
EXPECT_TRUE(b->versionIndex() == 1);
132+
EXPECT_TRUE(c->versionIndex() == 0);
133+
EXPECT_TRUE(loop->versionIndex() == 2);
134+
ASSERT_EQ(res3, 4);
135+
136+
// No loopbacks occur
137+
EXPECT_TRUE(loop->loopBacks()->versionIndex() == -1);
138+
139+
// Run 100 times
140+
ASSERT_TRUE(loop->setOption<Number>("N", 100));
141+
EXPECT_EQ(loop->run(), NodeConstants::ProcessResult::Success);
142+
ASSERT_EQ(a->getOutputValue<Number>("Result").asInteger(), 2);
143+
ASSERT_EQ(b->getOutputValue<Number>("Result").asInteger(), 3);
144+
ASSERT_EQ(c->getOutputValue<Number>("Result").asInteger(), 4);
145+
EXPECT_TRUE(a->versionIndex() == 3);
146+
EXPECT_TRUE(b->versionIndex() == 2);
147+
EXPECT_TRUE(c->versionIndex() == 1);
148+
EXPECT_TRUE(loop->versionIndex() == 3);
149+
150+
// No loopbacks occur
151+
EXPECT_TRUE(loop->loopBacks()->versionIndex() == -1);
152+
153+
// Upstream node change
154+
i->setOption<Number>("X", 2);
155+
156+
/*
157+
* i, 2 -> itA, 1 + 2 = 3 -> itB, 1 + 3 = 4 -> itC, 1 + 4 = 5
158+
*/
159+
160+
// Run again
161+
ASSERT_TRUE(loop->setOption<Number>("N", 1));
162+
EXPECT_EQ(loop->run(), NodeConstants::ProcessResult::Success);
163+
ASSERT_EQ(a->getOutputValue<Number>("Result").asInteger(), 3);
164+
ASSERT_EQ(b->getOutputValue<Number>("Result").asInteger(), 4);
165+
ASSERT_EQ(c->getOutputValue<Number>("Result").asInteger(), 5);
166+
EXPECT_TRUE(a->versionIndex() == 4);
167+
EXPECT_TRUE(b->versionIndex() == 3);
168+
EXPECT_TRUE(c->versionIndex() == 2);
169+
EXPECT_TRUE(loop->versionIndex() == 4);
170+
171+
// No loopbacks occur
172+
EXPECT_TRUE(loop->loopBacks()->versionIndex() == -1);
173+
}
174+
76175
TEST_F(IterableGraphTest, NoRun)
77176
{
78177
createGraph();

0 commit comments

Comments
 (0)