Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ In Claude Desktop, try asking:
| `--max-connections` | Maximum connection pool size | `1` |
| `--timeout` | Connection timeout (seconds) | `30` |
| `--log-level` | Log level (error/warn/info/debug) | `info` |
| `--read-only` | Enforce read-only mode (rejects `sql_exec`) | `false` |

## 🛠️ Available Tools

Expand Down
3 changes: 2 additions & 1 deletion README_cn.md
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ cargo install --git https://github.com/rbatis/rbdc-mcp.git
| `--max-connections` | 最大连接池大小 | `1` |
| `--timeout` | 连接超时时间(秒) | `30` |
| `--log-level` | 日志级别(error/warn/info/debug) | `info` |
| `--read-only` | 强制只读模式(拒绝 `sql_exec`) | `false` |

## 🛠️ 可用工具

Expand All @@ -153,4 +154,4 @@ cargo install --git https://github.com/rbatis/rbdc-mcp.git

## 许可证

Apache-2.0
Apache-2.0
7 changes: 6 additions & 1 deletion src/db_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use rbdc_pool_fast::FastPool;
use rbs::Value;
use std::sync::Arc;
use std::time::Duration;
use crate::sql_guard::is_read_only_sql;

/// Supported database types
#[derive(Debug, Clone)]
Expand Down Expand Up @@ -73,6 +74,10 @@ impl DatabaseManager {

/// Execute query and return result set
pub async fn execute_query(&self, sql: &str, params: Vec<Value>) -> Result<Value> {
if !is_read_only_sql(sql) {
return Err(anyhow!("Read-only query validation failed"));
}

let mut conn = self.pool.get().await
.map_err(|e| anyhow!("Failed to get database connection: {}", e))?;
let result = conn.get_values(sql, params).await
Expand Down Expand Up @@ -121,4 +126,4 @@ impl DatabaseManager {

Ok(())
}
}
}
21 changes: 19 additions & 2 deletions src/handler.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::sync::Arc;
use crate::db_manager::DatabaseManager;
use crate::sql_guard::is_read_only_sql;

use rmcp::{
ErrorData as McpError, ServerHandler,
Expand All @@ -12,6 +13,7 @@ use rmcp::{
#[derive(Clone)]
pub struct RbdcDatabaseHandler {
db_manager: Arc<DatabaseManager>,
read_only: bool,
tool_router: ToolRouter<RbdcDatabaseHandler>,
}

Expand All @@ -36,9 +38,10 @@ pub struct SqlExecParams {
// Use tool_router macro to generate the tool router
#[tool_router]
impl RbdcDatabaseHandler {
pub fn new(db_manager: Arc<DatabaseManager>) -> Self {
pub fn new(db_manager: Arc<DatabaseManager>, read_only: bool) -> Self {
Self {
db_manager,
read_only,
tool_router: Self::tool_router(),
}
}
Expand All @@ -55,6 +58,13 @@ impl RbdcDatabaseHandler {
_context: RequestContext<RoleServer>,
Parameters(params): Parameters<SqlQueryParams>,
) -> Result<CallToolResult, McpError> {
if !is_read_only_sql(&params.sql) {
return Err(McpError::invalid_params(
"sql_query only accepts single read-only SQL statements".to_string(),
None,
));
}

// Convert parameter types from serde_json::Value to rbs::Value
let rbs_params = self.convert_params(&params.params);

Expand All @@ -74,6 +84,13 @@ impl RbdcDatabaseHandler {
_context: RequestContext<RoleServer>,
Parameters(params): Parameters<SqlExecParams>,
) -> Result<CallToolResult, McpError> {
if self.read_only {
return Err(McpError::invalid_params(
"sql_exec is disabled when server is started with --read-only".to_string(),
None,
));
}

// Convert parameter types from serde_json::Value to rbs::Value
let rbs_params = self.convert_params(&params.params);

Expand Down Expand Up @@ -126,4 +143,4 @@ impl ServerHandler for RbdcDatabaseHandler {
) -> Result<InitializeResult, McpError> {
Ok(self.get_info())
}
}
}
10 changes: 8 additions & 2 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use tracing_subscriber::{EnvFilter};

mod db_manager;
mod handler;
mod sql_guard;

use crate::db_manager::DatabaseManager;
use crate::handler::RbdcDatabaseHandler;
Expand All @@ -31,6 +32,10 @@ struct Args {
/// Log level
#[arg(long, default_value = "info")]
log_level: String,

/// Enforce read-only server mode (blocks sql_exec)
#[arg(long, default_value_t = false)]
read_only: bool,
}

#[tokio::main]
Expand All @@ -50,6 +55,7 @@ async fn main() -> Result<(), anyhow::Error> {

info!("Starting RBDC MCP Server");
info!("Database URL: {}", args.database_url);
info!("Read-only mode: {}", args.read_only);

// Create database manager
let db_manager = DatabaseManager::new(&args.database_url)
Expand All @@ -68,7 +74,7 @@ async fn main() -> Result<(), anyhow::Error> {
info!("Database connection test successful");

// Create RBDC database handler
let handler = RbdcDatabaseHandler::new(Arc::new(db_manager));
let handler = RbdcDatabaseHandler::new(Arc::new(db_manager), args.read_only);

info!("Starting RBDC MCP Server...");

Expand All @@ -79,4 +85,4 @@ async fn main() -> Result<(), anyhow::Error> {

service.waiting().await?;
Ok(())
}
}
213 changes: 213 additions & 0 deletions src/sql_guard.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,213 @@
use std::borrow::Cow;

const READ_ONLY_START_KEYWORDS: &[&str] = &[
"SELECT",
"SHOW",
"DESC",
"DESCRIBE",
"EXPLAIN",
"WITH",
];

const WRITE_KEYWORDS: &[&str] = &[
"INSERT",
"UPDATE",
"DELETE",
"UPSERT",
"REPLACE",
"MERGE",
"CREATE",
"ALTER",
"DROP",
"TRUNCATE",
"GRANT",
"REVOKE",
"COMMIT",
"ROLLBACK",
"BEGIN",
"START",
"VACUUM",
"ANALYZE",
"ATTACH",
"DETACH",
"PRAGMA",
"EXEC",
"EXECUTE",
"CALL",
"DO",
"SET",
"USE",
"LOCK",
"UNLOCK",
];

fn is_ident_char(c: char) -> bool {
c.is_ascii_alphanumeric() || c == '_'
}

fn is_escaped(sql: &[char], pos: usize) -> bool {
if pos == 0 {
return false;
}
let mut backslashes = 0usize;
let mut i = pos;
while i > 0 {
i -= 1;
if sql[i] == '\\' {
backslashes += 1;
} else {
break;
}
}
backslashes % 2 == 1
}

fn collect_uppercase_words_and_semicolons(sql: &str) -> (Vec<String>, usize) {
let chars: Vec<char> = sql.chars().collect();
let mut i = 0usize;
let mut words = Vec::new();
let mut semicolons = 0usize;

while i < chars.len() {
let c = chars[i];
let next = if i + 1 < chars.len() {
Some(chars[i + 1])
} else {
None
};

if c.is_whitespace() {
i += 1;
continue;
}

if c == '-' && next == Some('-') {
i += 2;
while i < chars.len() && chars[i] != '\n' {
i += 1;
}
continue;
}

if c == '/' && next == Some('*') {
i += 2;
while i + 1 < chars.len() {
if chars[i] == '*' && chars[i + 1] == '/' {
i += 2;
break;
}
i += 1;
}
continue;
}

if c == '\'' || c == '"' || c == '`' {
let quote = c;
i += 1;
while i < chars.len() {
if chars[i] == quote && !is_escaped(&chars, i) {
i += 1;
break;
}
i += 1;
}
continue;
}

if c == ';' {
semicolons += 1;
i += 1;
continue;
}

if is_ident_char(c) {
let start = i;
i += 1;
while i < chars.len() && is_ident_char(chars[i]) {
i += 1;
}
let token: Cow<'_, str> = chars[start..i].iter().collect::<String>().into();
words.push(token.to_ascii_uppercase());
continue;
}

i += 1;
}

(words, semicolons)
}

pub fn is_read_only_sql(sql: &str) -> bool {
let trimmed = sql.trim();
if trimmed.is_empty() {
return false;
}

let (words, semicolons) = collect_uppercase_words_and_semicolons(trimmed);
if words.is_empty() {
return false;
}

if semicolons > 1 {
return false;
}

if semicolons == 1 && !trimmed.ends_with(';') {
return false;
}

let first = words.first().map(|s| s.as_str()).unwrap_or_default();
if !READ_ONLY_START_KEYWORDS.contains(&first) {
return false;
}

for word in &words {
if WRITE_KEYWORDS.contains(&word.as_str()) {
return false;
}
}

true
}

#[cfg(test)]
mod tests {
use super::is_read_only_sql;

#[test]
fn allows_plain_select() {
assert!(is_read_only_sql("SELECT * FROM users"));
}

#[test]
fn allows_select_with_string_semicolon() {
assert!(is_read_only_sql("SELECT ';' AS semi"));
}

#[test]
fn allows_explain_select() {
assert!(is_read_only_sql("EXPLAIN SELECT * FROM users"));
}

#[test]
fn rejects_delete() {
assert!(!is_read_only_sql("DELETE FROM users WHERE id = 1"));
}

#[test]
fn rejects_comment_prefixed_delete() {
assert!(!is_read_only_sql("-- note\nDELETE FROM users"));
}

#[test]
fn rejects_multi_statement() {
assert!(!is_read_only_sql("SELECT * FROM users; DELETE FROM users"));
}

#[test]
fn rejects_write_cte() {
assert!(!is_read_only_sql(
"WITH moved AS (DELETE FROM users RETURNING *) SELECT * FROM moved"
));
}
}