Skip to content

Commit 22cb3f4

Browse files
committed
fix: prevent SQL injection in workflow custom SQL execution
1 parent 72cecff commit 22cb3f4

File tree

1 file changed

+60
-0
lines changed

1 file changed

+60
-0
lines changed

backend/domain/memory/database/service/database_impl.go

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,10 @@ import (
2424
"errors"
2525
"fmt"
2626
"math/rand"
27+
"regexp"
2728
"runtime/debug"
2829
"strconv"
30+
"strings"
2931
"time"
3032

3133
"github.com/tealeg/xlsx/v3"
@@ -1045,6 +1047,10 @@ func (d databaseService) executeCustomSQL(ctx context.Context, req *ExecuteSQLRe
10451047
return nil, fmt.Errorf("SQL is empty")
10461048
}
10471049

1050+
if err := validateCustomSQL(*req.SQL); err != nil {
1051+
return nil, fmt.Errorf("SQL validation failed: %v", err)
1052+
}
1053+
10481054
operation, err := sqlparser.New().GetSQLOperation(*req.SQL)
10491055
if err != nil {
10501056
return nil, err
@@ -1076,6 +1082,10 @@ func (d databaseService) executeCustomSQL(ctx context.Context, req *ExecuteSQLRe
10761082
if err != nil {
10771083
return nil, fmt.Errorf("parse sql failed: %v", err)
10781084
}
1085+
1086+
if err := validateParsedSQL(parsedSQL); err != nil {
1087+
return nil, fmt.Errorf("SQL validation failed: %v", err)
1088+
}
10791089
// add rw mode
10801090
if tableInfo.RwMode == table.BotTableRWMode_LimitedReadWrite && len(req.UserID) != 0 {
10811091
switch operation {
@@ -2189,3 +2199,53 @@ func generateComplexCond(ctx context.Context, req *ExecuteSQLRequest, mode table
21892199
return nil, nil
21902200

21912201
}
2202+
2203+
var allowedTableNamePattern = regexp.MustCompile(`^table_\d+$`)
2204+
2205+
2206+
func validateCustomSQL(sql string) error {
2207+
upperSQL := strings.ToUpper(sql)
2208+
2209+
if strings.Contains(upperSQL, "UNION") {
2210+
return fmt.Errorf("UNION queries are not allowed")
2211+
}
2212+
2213+
dangerousTables := []string{"INFORMATION_SCHEMA", "MYSQL.", "PERFORMANCE_SCHEMA", "SYS."}
2214+
for _, t := range dangerousTables {
2215+
if strings.Contains(upperSQL, t) {
2216+
return fmt.Errorf("access to system tables is not allowed")
2217+
}
2218+
}
2219+
2220+
dangerousFuncs := []string{
2221+
"CURRENT_USER", "USER()", "SESSION_USER", "SYSTEM_USER",
2222+
"LOAD_FILE", "INTO OUTFILE", "INTO DUMPFILE",
2223+
"BENCHMARK(", "SLEEP(",
2224+
}
2225+
for _, f := range dangerousFuncs {
2226+
if strings.Contains(upperSQL, f) {
2227+
return fmt.Errorf("dangerous function %s is not allowed", f)
2228+
}
2229+
}
2230+
2231+
return nil
2232+
}
2233+
2234+
func validateParsedSQL(parsedSQL string) error {
2235+
tableNamePattern := regexp.MustCompile(`(?i)\b(FROM|JOIN|INTO|UPDATE)\s+` + "`?" + `(\w+)` + "`?")
2236+
matches := tableNamePattern.FindAllStringSubmatch(parsedSQL, -1)
2237+
2238+
for _, match := range matches {
2239+
if len(match) >= 3 {
2240+
tableName := match[2]
2241+
if tableName == "dual" || tableName == "DUAL" {
2242+
continue
2243+
}
2244+
if !allowedTableNamePattern.MatchString(tableName) {
2245+
return fmt.Errorf("invalid table name: %s, only table_<id> format is allowed", tableName)
2246+
}
2247+
}
2248+
}
2249+
2250+
return nil
2251+
}

0 commit comments

Comments
 (0)