Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions bin/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,6 @@ install(PROGRAMS ${PROJECT_BINARY_DIR}/bin/nrngui ${PROJECT_BINARY_DIR}/bin/neur
${PROJECT_BINARY_DIR}/bin/nrnivmodl DESTINATION ${NRN_INSTALL_DATA_PREFIX}bin)

install(FILES ${PROJECT_BINARY_DIR}/bin/nrnmech_makefile DESTINATION ${NRN_INSTALL_DATA_PREFIX}bin)
install(PROGRAMS ${CMAKE_CURRENT_BINARY_DIR}/sortspike ${CMAKE_CURRENT_SOURCE_DIR}/mkthreadsafe
${PROJECT_BINARY_DIR}/bin/nrnpyenv.sh ${CMAKE_CURRENT_SOURCE_DIR}/set_nrnpyenv.sh
install(PROGRAMS ${CMAKE_CURRENT_BINARY_DIR}/sortspike ${PROJECT_BINARY_DIR}/bin/nrnpyenv.sh
${CMAKE_CURRENT_SOURCE_DIR}/set_nrnpyenv.sh
DESTINATION ${NRN_INSTALL_DATA_PREFIX}bin)
82 changes: 0 additions & 82 deletions bin/mkthreadsafe

This file was deleted.

1 change: 1 addition & 0 deletions src/nmodl/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ add_subdirectory(utils)
add_subdirectory(visitors)
add_subdirectory(pybind)
add_subdirectory(solver)
add_subdirectory(mkthreadsafe)

# =============================================================================
# NMODL sources
Expand Down
3 changes: 3 additions & 0 deletions src/nmodl/codegen/codegen_naming.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,9 @@ static constexpr char NET_SEND_METHOD[] = "net_send";
/// nrn_pointing function in nmodl
static constexpr char NRN_POINTING_METHOD[] = "nrn_pointing";

/// state_discontinuity function in nmodl
static constexpr char NRN_STATE_DISC_METHOD[] = "state_discontinuity";

/// artificial cell keyword in nmodl
static constexpr char ARTIFICIAL_CELL[] = "ARTIFICIAL_CELL";

Expand Down
7 changes: 7 additions & 0 deletions src/nmodl/mkthreadsafe/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# =============================================================================
# mkthreadsafe executable
# =============================================================================

add_executable(mkthreadsafe main.cpp)
target_link_libraries(mkthreadsafe CLI11::CLI11 lexer printer util visitor)
install(TARGETS mkthreadsafe DESTINATION ${NMODL_INSTALL_DIR_SUFFIX}bin/)
61 changes: 61 additions & 0 deletions src/nmodl/mkthreadsafe/main.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
/*
* Copyright 2025 EPFL.
* See the top-level LICENSE file for details.
*
* SPDX-License-Identifier: Apache-2.0
*/

#include <filesystem>

#include <CLI/CLI.hpp>

#include "ast/program.hpp"
#include "config/config.h"
#include "parser/nmodl_driver.hpp"
#include "utils/logger.hpp"
#include "visitors/nmodl_visitor.hpp"
#include "visitors/threadsafe_visitor.hpp"

/**
* Standalone mkthreadsafe program for NMODL.
*/

using namespace nmodl;

int main(int argc, const char* argv[]) {
CLI::App app{fmt::format("NMODL-mkthreadsafe : Standalone mkthreadsafe for NMODL({})",
Version::to_string())};

std::vector<std::string> mod_files;

bool convert_globals = false;

bool convert_verbatim = false;

app.add_option("file", mod_files, "One or more NMODL files")
->required()
->check(CLI::ExistingFile);

app.add_flag("--global",
convert_globals,
"Automatically mark threadsafe despite the use of GLOBAL");

app.add_flag("--verbatim",
convert_verbatim,
"Automatically mark threadsafe despite the use of VERBATIM");

CLI11_PARSE(app, argc, argv);


for (const auto& f: mod_files) {
parser::NmodlDriver driver;
const auto& ast = driver.parse_file(f);

logger->info("Running Threadsafe visitor on file {}", f);
visitor::ThreadsafeVisitor(convert_globals, convert_verbatim).visit_program(*ast);
logger->info("Writing AST to NMODL transformation to {}", f);
visitor::NmodlPrintVisitor(f).visit_program(*ast);
}

return 0;
}
1 change: 1 addition & 0 deletions src/nmodl/visitors/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ add_library(
sympy_conductance_visitor.cpp
sympy_replace_solutions_visitor.cpp
sympy_solver_visitor.cpp
threadsafe_visitor.cpp
units_visitor.cpp
var_usage_visitor.cpp
verbatim_var_rename_visitor.cpp
Expand Down
116 changes: 116 additions & 0 deletions src/nmodl/visitors/threadsafe_visitor.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
/*
* Copyright 2025 EPFL.
* See the top-level LICENSE file for details.
*
* SPDX-License-Identifier: Apache-2.0
*/

#include "visitors/threadsafe_visitor.hpp"

#include "ast/all.hpp"
#include "utils/logger.hpp"
#include "visitors/visitor_utils.hpp"
#include "codegen/codegen_naming.hpp"


namespace nmodl {
namespace visitor {
void ThreadsafeVisitor::visit_function_call(ast::FunctionCall& node) {
const auto& name = node.get_node_name();
if (not in_net_receive_block and name == codegen::naming::NRN_STATE_DISC_METHOD) {
can_be_made_threadsafe = false;
}
node.visit_children(*this);
}

void ThreadsafeVisitor::visit_net_receive_block(ast::NetReceiveBlock& node) {
in_net_receive_block = true;
node.visit_children(*this);
in_net_receive_block = false;
}
void ThreadsafeVisitor::visit_initial_block(ast::InitialBlock& node) {
in_initial_block = true;
node.visit_children(*this);
in_initial_block = false;
}

void ThreadsafeVisitor::visit_param_block(ast::ParamBlock& node) {
in_parameter_block = true;
node.visit_children(*this);
in_parameter_block = false;
}

void ThreadsafeVisitor::visit_assigned_block(ast::AssignedBlock& node) {
in_assigned_block = true;
node.visit_children(*this);
in_assigned_block = false;
}


void ThreadsafeVisitor::visit_binary_expression(ast::BinaryExpression& node) {
const auto& lhs = node.get_lhs();
const auto& converted_lhs = std::dynamic_pointer_cast<ast::VarName>(lhs);
const bool condition = not in_initial_block and not in_parameter_block and
not in_assigned_block;
if (condition and converted_lhs and global_vars.count(converted_lhs->get_node_name()) > 0) {
can_be_made_threadsafe = false;
}
node.visit_children(*this);
}

void ThreadsafeVisitor::visit_program(ast::Program& node) {
// short circuit
const auto& threadsafe_vars = collect_nodes(node, {ast::AstNodeType::THREAD_SAFE});
if (threadsafe_vars.size() > 0) {
logger->info("Mechanism already marked as thread-safe, nothing to do");
return;
}

const auto& unsafe_constructs = collect_nodes(node,
{ast::AstNodeType::EXTERNAL,
ast::AstNodeType::LINEAR_BLOCK,
ast::AstNodeType::DISCRETE_BLOCK});
if (unsafe_constructs.size() > 0) {
can_be_made_threadsafe = false;
}

// I don't get this one for now. Can it be made threadsafe?
const auto& pointer_vars = collect_nodes(node, {ast::AstNodeType::POINTER});
if (pointer_vars.size() > 0) {
can_be_made_threadsafe = false;
}

const auto& verbatim_blocks = collect_nodes(node, {ast::AstNodeType::VERBATIM});

if (verbatim_blocks.size() > 0 and not convert_verbatim) {
can_be_made_threadsafe = false;
}

const auto& global_vars_local = collect_nodes(node, {ast::AstNodeType::GLOBAL_VAR});
for (const auto& var: global_vars_local) {
const auto& conv = std::dynamic_pointer_cast<ast::GlobalVar>(var);
global_vars.insert(conv->get_node_name());
}

node.visit_children(*this);

if (not can_be_made_threadsafe) {
logger->warn("Mechanism cannot be made thread safe");
return;
}

const auto& neuron_block = collect_nodes(node, {ast::AstNodeType::NEURON_BLOCK});
if (not neuron_block.size()) {
// TODO insert the block
return;
}
logger->info("Will insert THREADSAFE block");
const auto& converted = std::dynamic_pointer_cast<ast::NeuronBlock>(neuron_block[0]);
const auto& statement_block = converted->get_statement_block();
auto expr_statement = std::make_shared<ast::ExpressionStatement>(
std::make_shared<ast::Name>(std::make_shared<ast::String>("THREADSAFE")));
statement_block->emplace_back_statement(expr_statement);
}

} // namespace visitor
} // namespace nmodl
58 changes: 58 additions & 0 deletions src/nmodl/visitors/threadsafe_visitor.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
/*
* Copyright 2025 EPFL.
* See the top-level LICENSE file for details.
*
* SPDX-License-Identifier: Apache-2.0
*/

#pragma once

/**
* \file
* \brief \copybrief nmodl::visitor::ThreadsafeVisitor
*/

#include <string>
#include <unordered_set>
#include "visitors/ast_visitor.hpp"

namespace nmodl {
namespace visitor {

/**
* \addtogroup visitor_classes
* \{
*/

/**
* \class ThreadsafeVisitor
* \brief Visitor used for making a mod file threadsafe
*/
class ThreadsafeVisitor: public AstVisitor {
private:
bool in_net_receive_block = false;
bool in_initial_block = false;
bool in_parameter_block = false;
bool in_assigned_block = false;
bool can_be_made_threadsafe = true;
std::unordered_set<std::string> global_vars;
bool convert_globals = false;
bool convert_verbatim = false;

public:
void visit_program(ast::Program& node) override;
void visit_assigned_block(ast::AssignedBlock&) override;
void visit_function_call(ast::FunctionCall&) override;
void visit_initial_block(ast::InitialBlock&) override;
void visit_net_receive_block(ast::NetReceiveBlock&) override;
void visit_param_block(ast::ParamBlock&) override;
void visit_binary_expression(ast::BinaryExpression&) override;
explicit ThreadsafeVisitor(bool convert_globals = false, bool convert_verbatim = false)
: convert_globals(convert_globals)
, convert_verbatim(convert_verbatim){};
};

/** \} */ // end of visitor_classes

} // namespace visitor
} // namespace nmodl
Loading