@@ -24,8 +24,10 @@ import (
2424 "errors"
2525 "fmt"
2626 "math/rand"
27+ "regexp"
2728 "runtime/debug"
2829 "strconv"
30+ "strings"
2931 "time"
3032
3133 "github.com/tealeg/xlsx/v3"
@@ -1045,6 +1047,10 @@ func (d databaseService) executeCustomSQL(ctx context.Context, req *ExecuteSQLRe
10451047 return nil , fmt .Errorf ("SQL is empty" )
10461048 }
10471049
1050+ if err := validateCustomSQL (* req .SQL ); err != nil {
1051+ return nil , fmt .Errorf ("SQL validation failed: %v" , err )
1052+ }
1053+
10481054 operation , err := sqlparser .New ().GetSQLOperation (* req .SQL )
10491055 if err != nil {
10501056 return nil , err
@@ -1076,6 +1082,10 @@ func (d databaseService) executeCustomSQL(ctx context.Context, req *ExecuteSQLRe
10761082 if err != nil {
10771083 return nil , fmt .Errorf ("parse sql failed: %v" , err )
10781084 }
1085+
1086+ if err := validateParsedSQL (parsedSQL ); err != nil {
1087+ return nil , fmt .Errorf ("SQL validation failed: %v" , err )
1088+ }
10791089 // add rw mode
10801090 if tableInfo .RwMode == table .BotTableRWMode_LimitedReadWrite && len (req .UserID ) != 0 {
10811091 switch operation {
@@ -2189,3 +2199,53 @@ func generateComplexCond(ctx context.Context, req *ExecuteSQLRequest, mode table
21892199 return nil , nil
21902200
21912201}
2202+
2203+ var allowedTableNamePattern = regexp .MustCompile (`^table_\d+$` )
2204+
2205+
2206+ func validateCustomSQL (sql string ) error {
2207+ upperSQL := strings .ToUpper (sql )
2208+
2209+ if strings .Contains (upperSQL , "UNION" ) {
2210+ return fmt .Errorf ("UNION queries are not allowed" )
2211+ }
2212+
2213+ dangerousTables := []string {"INFORMATION_SCHEMA" , "MYSQL." , "PERFORMANCE_SCHEMA" , "SYS." }
2214+ for _ , t := range dangerousTables {
2215+ if strings .Contains (upperSQL , t ) {
2216+ return fmt .Errorf ("access to system tables is not allowed" )
2217+ }
2218+ }
2219+
2220+ dangerousFuncs := []string {
2221+ "CURRENT_USER" , "USER()" , "SESSION_USER" , "SYSTEM_USER" ,
2222+ "LOAD_FILE" , "INTO OUTFILE" , "INTO DUMPFILE" ,
2223+ "BENCHMARK(" , "SLEEP(" ,
2224+ }
2225+ for _ , f := range dangerousFuncs {
2226+ if strings .Contains (upperSQL , f ) {
2227+ return fmt .Errorf ("dangerous function %s is not allowed" , f )
2228+ }
2229+ }
2230+
2231+ return nil
2232+ }
2233+
2234+ func validateParsedSQL (parsedSQL string ) error {
2235+ tableNamePattern := regexp .MustCompile (`(?i)\b(FROM|JOIN|INTO|UPDATE)\s+` + "`?" + `(\w+)` + "`?" )
2236+ matches := tableNamePattern .FindAllStringSubmatch (parsedSQL , - 1 )
2237+
2238+ for _ , match := range matches {
2239+ if len (match ) >= 3 {
2240+ tableName := match [2 ]
2241+ if tableName == "dual" || tableName == "DUAL" {
2242+ continue
2243+ }
2244+ if ! allowedTableNamePattern .MatchString (tableName ) {
2245+ return fmt .Errorf ("invalid table name: %s, only table_<id> format is allowed" , tableName )
2246+ }
2247+ }
2248+ }
2249+
2250+ return nil
2251+ }
0 commit comments