Skip to content

Commit dcd21d2

Browse files
Implementation for threads in spec tests
1 parent 6d075a8 commit dcd21d2

File tree

4 files changed

+315
-9
lines changed

4 files changed

+315
-9
lines changed

src/tools/wasm-shell.cpp

Lines changed: 129 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,17 @@ struct Shell {
5353

5454
Options& options;
5555

56+
struct ThreadState {
57+
Name name;
58+
std::vector<WATParser::ScriptEntry> commands;
59+
size_t pc = 0;
60+
bool isSuspended = false;
61+
std::shared_ptr<ModuleRunner> instance = nullptr;
62+
std::shared_ptr<ContData> suspendedCont = nullptr;
63+
bool done = false;
64+
};
65+
std::vector<ThreadState> activeThreads;
66+
5667
Shell(Options& options) : options(options) { buildSpectestModule(); }
5768

5869
Result<> run(WASTScript& script) {
@@ -105,10 +116,119 @@ struct Shell {
105116
// Run threads in a blocking manner for now.
106117
// TODO: yield on blocking instructions e.g. memory.atomic.wait32.
107118
Result<> doThread(ThreadBlock& thread) {
108-
return run(thread.commands);
119+
ThreadState state;
120+
state.name = thread.name;
121+
state.commands = thread.commands;
122+
activeThreads.push_back(std::move(state));
123+
return Ok{};
109124
}
110125

111126
Result<> doWait(Wait& wait) {
127+
bool found = false;
128+
for (auto& t : activeThreads) {
129+
if (t.name == wait.thread) {
130+
found = true;
131+
break;
132+
}
133+
}
134+
if (!found) {
135+
return Err{"wait called for unknown thread"};
136+
}
137+
138+
// Round-robin execution
139+
while (true) {
140+
bool anyProgress = false;
141+
bool targetDone = false;
142+
143+
for (auto& t : activeThreads) {
144+
if (t.done) {
145+
if (t.name == wait.thread)
146+
targetDone = true;
147+
continue;
148+
}
149+
150+
if (t.isSuspended) {
151+
// Check if it's still waiting. WaitQueue sets `isWaiting` to false
152+
// when notified.
153+
bool stillWaiting = t.suspendedCont && t.suspendedCont->isWaiting;
154+
155+
if (!stillWaiting) {
156+
// It was woken up! We need to resume it.
157+
t.isSuspended = false;
158+
Flow flow;
159+
try {
160+
flow = t.instance->resumeContinuation(t.suspendedCont);
161+
} catch (TrapException&) {
162+
std::cerr << "Thread " << t.name << " trapped upon resume\n";
163+
t.done = true;
164+
anyProgress = true;
165+
continue;
166+
} catch (...) {
167+
WASM_UNREACHABLE("unexpected error during resume");
168+
}
169+
t.suspendedCont = nullptr;
170+
171+
if (flow.breakTo == THREAD_SUSPEND_FLOW) {
172+
// Suspended again
173+
t.isSuspended = true;
174+
t.suspendedCont = t.instance->getSuspendedContinuation();
175+
anyProgress = true;
176+
} else if (flow.suspendTag) {
177+
t.instance->clearContinuationStore();
178+
t.done = true; // unhandled suspension
179+
anyProgress = true;
180+
} else {
181+
t.pc++; // Completed the command that originally suspended!
182+
anyProgress = true;
183+
}
184+
}
185+
} else {
186+
// Normal execution of the next command.
187+
if (t.pc < t.commands.size()) {
188+
auto& cmd = t.commands[t.pc].cmd;
189+
if (auto* act = std::get_if<Action>(&cmd)) {
190+
auto result = doAction(*act);
191+
if (std::get_if<ThreadSuspendResult>(&result)) {
192+
t.isSuspended = true;
193+
if (auto* invoke = std::get_if<InvokeAction>(act)) {
194+
t.instance =
195+
instances[invoke->base ? *invoke->base : lastInstance];
196+
t.suspendedCont = t.instance->getSuspendedContinuation();
197+
}
198+
anyProgress = true;
199+
} else {
200+
t.pc++;
201+
anyProgress = true;
202+
}
203+
} else {
204+
// Not an action, just run it (e.g. module instantiation or
205+
// assertions inside thread)
206+
auto res = runCommand(cmd);
207+
if (res.getErr()) {
208+
std::cerr << "Thread " << t.name
209+
<< " error: " << res.getErr()->msg << "\n";
210+
t.done = true;
211+
} else {
212+
t.pc++;
213+
anyProgress = true;
214+
}
215+
}
216+
} else {
217+
t.done = true;
218+
anyProgress = true; // finishing counts as progress
219+
}
220+
}
221+
}
222+
223+
if (targetDone) {
224+
break;
225+
}
226+
227+
if (!anyProgress) {
228+
// Find if target is still suspended
229+
return Err{"deadlock! no threads can make progress"};
230+
}
231+
}
112232
return Ok{};
113233
}
114234

@@ -237,11 +357,13 @@ struct Shell {
237357
struct HostLimitResult {};
238358
struct ExceptionResult {};
239359
struct SuspensionResult {};
360+
struct ThreadSuspendResult {};
240361
using ActionResult = std::variant<Literals,
241362
TrapResult,
242363
HostLimitResult,
243364
ExceptionResult,
244-
SuspensionResult>;
365+
SuspensionResult,
366+
ThreadSuspendResult>;
245367

246368
std::string resultToString(ActionResult& result) {
247369
if (std::get_if<TrapResult>(&result)) {
@@ -252,6 +374,8 @@ struct Shell {
252374
return "exception";
253375
} else if (std::get_if<SuspensionResult>(&result)) {
254376
return "suspension";
377+
} else if (std::get_if<ThreadSuspendResult>(&result)) {
378+
return "thread_suspend";
255379
} else if (auto* vals = std::get_if<Literals>(&result)) {
256380
std::stringstream ss;
257381
ss << *vals;
@@ -281,6 +405,9 @@ struct Shell {
281405
} catch (...) {
282406
WASM_UNREACHABLE("unexpected error");
283407
}
408+
if (flow.breakTo == THREAD_SUSPEND_FLOW) {
409+
return ThreadSuspendResult{};
410+
}
284411
if (flow.suspendTag) {
285412
// This is an unhandled suspension. Handle it here - clear the
286413
// suspension state - so nothing else is affected.

src/wasm-interpreter.h

Lines changed: 138 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,8 @@ struct NonconstantException {};
7171

7272
// Utilities
7373

74-
extern Name RETURN_FLOW, RETURN_CALL_FLOW, NONCONSTANT_FLOW, SUSPEND_FLOW;
74+
extern Name RETURN_FLOW, RETURN_CALL_FLOW, NONCONSTANT_FLOW, SUSPEND_FLOW,
75+
THREAD_SUSPEND_FLOW;
7576

7677
// Stuff that flows around during executing expressions: a literal, or a change
7778
// in control flow.
@@ -87,13 +88,15 @@ class Flow {
8788
: values(std::move(values)), breakTo(breakTo) {}
8889
Flow(Name breakTo, Tag* suspendTag, Literals&& values)
8990
: values(std::move(values)), breakTo(breakTo), suspendTag(suspendTag) {
90-
assert(breakTo == SUSPEND_FLOW);
91+
assert(breakTo == SUSPEND_FLOW || breakTo == THREAD_SUSPEND_FLOW);
9192
}
9293

9394
Literals values;
9495
Name breakTo; // if non-null, a break is going on
9596
Tag* suspendTag = nullptr; // if non-null, breakTo must be SUSPEND_FLOW, and
96-
// this is the tag being suspended
97+
// this is the tag being suspended. If breakTo is
98+
// THREAD_SUSPEND_FLOW, this represents the thread
99+
// suspending and this field is not used.
97100

98101
// A helper function for the common case where there is only one value
99102
const Literal& getSingleValue() {
@@ -281,6 +284,10 @@ struct ContData {
281284
// resume_throw_ref).
282285
Literal exception;
283286

287+
// If set, this continuation was suspended into a wait queue by a thread
288+
// and has not yet been woken up.
289+
bool isWaiting = false;
290+
284291
// Whether we executed. Continuations are one-shot, so they may not be
285292
// executed a second time.
286293
bool executed = false;
@@ -303,6 +310,13 @@ struct ContinuationStore {
303310

304311
// Set when we are resuming execution, that is, re-winding the stack.
305312
bool resuming = false;
313+
314+
// The wait queue for threads waiting on addresses (represented by GCData and
315+
// field index).
316+
std::unordered_map<
317+
std::shared_ptr<GCData>,
318+
std::unordered_map<Index, std::vector<std::shared_ptr<ContData>>>>
319+
waitQueues;
306320
};
307321

308322
// Execute an expression
@@ -2244,13 +2258,90 @@ class ExpressionRunner : public OverriddenVisitor<SubType, Flow> {
22442258
}
22452259

22462260
Flow visitStructWait(StructWait* curr) {
2247-
WASM_UNREACHABLE("struct.wait not implemented");
2248-
return Flow();
2261+
VISIT(ref, curr->ref)
2262+
VISIT(expected, curr->expected)
2263+
VISIT(timeout,
2264+
curr->timeout) // We ignore timeout in the simulation for simplicity
2265+
2266+
auto data = ref.getSingleValue().getGCData();
2267+
if (!data) {
2268+
trap("null ref");
2269+
}
2270+
2271+
auto& field = data->values[curr->index];
2272+
if (field != expected.getSingleValue()) {
2273+
return Literal(int32_t(1)); // not-equal, don't wait
2274+
}
2275+
2276+
if (self()->isResuming()) {
2277+
// We have been notified and resumed.
2278+
// Clear the resume state and continue.
2279+
auto currContinuation = self()->getCurrContinuation();
2280+
assert(curr == currContinuation->resumeExpr);
2281+
self()->continuationStore->resuming = false;
2282+
assert(currContinuation->resumeInfo.empty());
2283+
assert(self()->restoredValuesMap.empty());
2284+
return Literal(int32_t(0)); // ok, woken up
2285+
}
2286+
2287+
// We need to wait. Create a continuation and suspend the thread.
2288+
auto old = self()->getCurrContinuationOrNull();
2289+
if (!old) {
2290+
// Not executing within a continuation, cannot suspend.
2291+
// For wasm-shell simulation, we assume threads are started with
2292+
// ContNew/ContBind.
2293+
return Flow(THREAD_SUSPEND_FLOW); // This will cause a trap up the stack
2294+
// natively if not caught.
2295+
}
2296+
assert(old->executed);
2297+
2298+
auto new_ = std::make_shared<ContData>();
2299+
self()->popCurrContinuation();
2300+
self()->pushCurrContinuation(new_);
2301+
new_->resumeExpr = curr;
2302+
new_->isWaiting = true;
2303+
2304+
self()->continuationStore->waitQueues[data][curr->index].push_back(new_);
2305+
2306+
return Flow(THREAD_SUSPEND_FLOW);
22492307
}
22502308

22512309
Flow visitStructNotify(StructNotify* curr) {
2252-
WASM_UNREACHABLE("struct.notify not implemented");
2253-
return Flow();
2310+
VISIT(ref, curr->ref)
2311+
VISIT(count, curr->count)
2312+
2313+
auto data = ref.getSingleValue().getGCData();
2314+
if (!data) {
2315+
trap("null ref");
2316+
}
2317+
2318+
int32_t countVal = count.getSingleValue().geti32();
2319+
int32_t woken = 0;
2320+
2321+
auto& store = self()->continuationStore;
2322+
auto it1 = store->waitQueues.find(data);
2323+
if (it1 != store->waitQueues.end()) {
2324+
auto& fieldQueues = it1->second;
2325+
auto it2 = fieldQueues.find(curr->index);
2326+
if (it2 != fieldQueues.end()) {
2327+
auto& queue = it2->second;
2328+
while (!queue.empty() && woken < countVal) {
2329+
// The waking thread will be executed by the wasm-shell scheduler.
2330+
// In the reference interpreter, awake continuations should be
2331+
// tracked. Since wasm-shell handles interleaved threads, we don't
2332+
// automatically execute them here. Wait! wasm-shell scheduler needs
2333+
// to know which threads are ready. Our ContinuationStore wait queues
2334+
// structure just pops them. The scheduler wrapper will need a way to
2335+
// track all active threads.
2336+
auto wokeCont = queue.front();
2337+
wokeCont->isWaiting = false;
2338+
queue.erase(queue.begin());
2339+
woken++;
2340+
}
2341+
}
2342+
}
2343+
2344+
return Literal(woken);
22542345
}
22552346

22562347
// Arbitrary deterministic limit on size. If we need to allocate a Literals
@@ -2714,6 +2805,10 @@ class ExpressionRunner : public OverriddenVisitor<SubType, Flow> {
27142805

27152806
virtual void hostLimit(std::string_view why) { WASM_UNREACHABLE("unimp"); }
27162807

2808+
virtual void invokeMain(const std::string& startName) {
2809+
WASM_UNREACHABLE("unimp");
2810+
}
2811+
27172812
virtual void throwException(const WasmException& exn) {
27182813
WASM_UNREACHABLE("unimp");
27192814
}
@@ -3257,6 +3352,42 @@ class ModuleRunnerBase : public ExpressionRunner<SubType> {
32573352

32583353
Flow callExport(Name name) { return callExport(name, Literals()); }
32593354

3355+
std::shared_ptr<ContData> getSuspendedContinuation() {
3356+
return this->getCurrContinuationOrNull();
3357+
}
3358+
3359+
Flow resumeContinuation(std::shared_ptr<ContData> contData,
3360+
Literals arguments = {}) {
3361+
if (contData->executed) {
3362+
this->trap("continuation already executed");
3363+
}
3364+
contData->executed = true;
3365+
3366+
if (contData->resumeArguments.empty()) {
3367+
contData->resumeArguments = arguments;
3368+
}
3369+
3370+
this->pushCurrContinuation(contData);
3371+
this->continuationStore->resuming = true;
3372+
#if WASM_INTERPRETER_DEBUG
3373+
std::cout << this->indent() << "resuming func " << contData->func.getFunc()
3374+
<< '\n';
3375+
#endif
3376+
3377+
Flow ret = contData->func.getFuncData()->doCall(arguments);
3378+
3379+
if (this->isResuming()) {
3380+
// if we didn't suspend again natively, clear resuming flag
3381+
this->continuationStore->resuming = false;
3382+
}
3383+
3384+
if (ret.breakTo != THREAD_SUSPEND_FLOW && !ret.suspendTag) {
3385+
// The coroutine finished normally.
3386+
this->popCurrContinuation();
3387+
}
3388+
return ret;
3389+
}
3390+
32603391
Literal getExportedFunction(Name name) {
32613392
Export* export_ = wasm.getExportOrNull(name);
32623393
if (!export_ || export_->kind != ExternalKind::Function) {

src/wasm/wasm.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ Name RETURN_FLOW("*return:)*");
2828
Name RETURN_CALL_FLOW("*return-call:)*");
2929
Name NONCONSTANT_FLOW("*nonconstant:)*");
3030
Name SUSPEND_FLOW("*suspend:)*");
31+
Name THREAD_SUSPEND_FLOW("*thread_suspend:)*");
3132

3233
namespace BinaryConsts::CustomSections {
3334

0 commit comments

Comments
 (0)