Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 14 additions & 15 deletions cmd/beast/init.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"database/sql"
"errors"
"fmt"
"github.com/BurntSushi/toml"
_ "github.com/jackc/pgx/v5/stdlib"
"github.com/lib/pq"
"github.com/sdslabs/beastv4/core"
Expand Down Expand Up @@ -125,11 +124,11 @@ func dbUserCheck() (bool, error) {
func initDb() error {
log.Infoln("Initializing database...")

var configuration config.BeastConfig
_, err := toml.DecodeFile(BEAST_GLOBAL_CONFIG, &configuration)
err := config.ReloadBeastConfig()
if err != nil {
return err
}
configuration := config.Cfg.PsqlConf

isPostgres, err := dbUserCheck()
if err != nil {
Expand All @@ -140,7 +139,7 @@ func initDb() error {
if isPostgres {
log.Infoln("Attempting to connect to postgres as postgres super user...")

dsn := fmt.Sprintf("user=%s dbname=%s sslmode=%s", "postgres", "postgres", "disable")
dsn := fmt.Sprintf("user=%s dbname=%s host=%s port=%s sslmode=%s", "postgres", "postgres", configuration.Host, configuration.Port, "disable")

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you can add user adn dbname from config too

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is intended, at this point the db and user dont exist, so have to connect to postgres database as postgres first to create them.

db, err = sql.Open("pgx", dsn)

if err != nil {
Expand All @@ -152,7 +151,7 @@ func initDb() error {
if utils.PromptBinary("Do you use password authentication for the postgres super user?") {
password := utils.PromptSecret("Enter postgres super user password (leave blank if none):")

dsn := fmt.Sprintf("user=%s password=%s dbname=%s sslmode=%s", "postgres", password, "postgres", "disable")
dsn := fmt.Sprintf("user=%s password=%s dbname=%s host=%s port=%s sslmode=%s", "postgres", password, "postgres", configuration.Host, configuration.Port, "disable")

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same

db, err = sql.Open("pgx", dsn)

if err != nil {
Expand All @@ -167,41 +166,41 @@ func initDb() error {
defer db.Close()

var exists int
err = db.QueryRow("SELECT 1 FROM pg_roles WHERE rolname = $1", configuration.PsqlConf.User).Scan(&exists)
err = db.QueryRow("SELECT 1 FROM pg_roles WHERE rolname = $1", configuration.User).Scan(&exists)
if errors.Is(err, sql.ErrNoRows) {
if err = createBeastDbUser(db, &configuration.PsqlConf); err != nil {
if err = createBeastDbUser(db, &configuration); err != nil {
return err
}
} else if err != nil {
return err
} else {
log.Infoln(fmt.Sprintf("User %s already exists", configuration.PsqlConf.User))
log.Infoln(fmt.Sprintf("User %s already exists", configuration.User))
}

log.Infoln(fmt.Sprintf("Changing password for user %s", configuration.PsqlConf.User))
query := fmt.Sprintf("ALTER USER %s WITH PASSWORD %s", pq.QuoteIdentifier(configuration.PsqlConf.User), utils.QuoteLiteral(configuration.PsqlConf.Password))
log.Infoln(fmt.Sprintf("Changing password for user %s", configuration.User))
query := fmt.Sprintf("ALTER USER %s WITH PASSWORD %s", pq.QuoteIdentifier(configuration.User), utils.QuoteLiteral(configuration.Password))
_, err = db.Exec(query)
if err != nil {
return err
}

err = db.QueryRow("SELECT 1 FROM pg_database WHERE datname = $1", configuration.PsqlConf.Dbname).Scan(&exists)
err = db.QueryRow("SELECT 1 FROM pg_database WHERE datname = $1", configuration.Dbname).Scan(&exists)
if errors.Is(err, sql.ErrNoRows) {
if err = createBeastDatabase(db, &configuration.PsqlConf); err != nil {
if err = createBeastDatabase(db, &configuration); err != nil {
return err
}
} else if err != nil {
return err
} else {
log.Infoln(fmt.Sprintf("Database %s already exists", configuration.PsqlConf.Dbname))
log.Infoln(fmt.Sprintf("Database %s already exists", configuration.Dbname))
}

_, err = db.Exec(fmt.Sprintf("ALTER DATABASE %s OWNER TO %s", pq.QuoteIdentifier(configuration.PsqlConf.Dbname), pq.QuoteIdentifier(configuration.PsqlConf.User)))
_, err = db.Exec(fmt.Sprintf("ALTER DATABASE %s OWNER TO %s", pq.QuoteIdentifier(configuration.Dbname), pq.QuoteIdentifier(configuration.User)))
if err != nil {
return err
}

log.Infoln(fmt.Sprintf("%s set as owner of database %s", configuration.PsqlConf.User, configuration.PsqlConf.Dbname))
log.Infoln(fmt.Sprintf("%s set as owner of database %s", configuration.User, configuration.Dbname))
return nil
}

Expand Down
8 changes: 4 additions & 4 deletions cmd/beast/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,13 +76,13 @@ func cleanupDatabaseConnections() {
log.Infoln("Database backup completed successfully")
}

log.Infoln("Terminating database connection...")
log.Infoln("Closing database connection...")

err = database.TerminateDatabaseConnections()
err = database.Close()
if err != nil {
log.Errorln("Unable to terminate database connections:", err)
log.Errorln("Unable to close database connections:", err)
} else {
log.Infoln("Database connections terminated successfully")
log.Infoln("Database connections close successfully")
}
}

Expand Down
4 changes: 2 additions & 2 deletions core/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ const ( //paths
BEAST_SECRETS_DIR string = "secrets"
BEAST_EXAMPLE_DIR string = "_examples"
BEAST_CACHE_DIR string = "cache"
BEAST_BACKUP_DIR string = "backup"
)

const ( //chall types
Expand Down Expand Up @@ -203,11 +204,10 @@ var USER_STATUS = map[string]string{
const (
LEADERBOARD_SIZE = 25
LEADERBOARD_GRAPH_SIZE = 12
SUBMISSIONS_PAGE_SIZE = 10
SUBMISSIONS_PAGE_SIZE = 10
)

var NOTIFICATION_SERVICES = []string{
"slack",
"discord",
}

108 changes: 81 additions & 27 deletions core/database/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@ package database

import (
"crypto/rand"
"database/sql"
"fmt"
"github.com/lib/pq"
"gorm.io/gorm/logger"
"os"
"os/exec"
"path/filepath"
Expand All @@ -25,8 +28,7 @@ var (
)

var (
BEAST_GLOBAL_DIR string = filepath.Join(os.Getenv("HOME"), ".beast")
dbConfig Config
dbConfig Config
)

type Config struct {
Expand Down Expand Up @@ -54,7 +56,16 @@ func LoadDbConfig() {
func ConnectDatabase() error {
LoadDbConfig()
dsn := fmt.Sprintf("user=%s password=%s dbname=%s host=%s port=%s sslmode=%s", dbConfig.PsqlConf.User, dbConfig.PsqlConf.Password, dbConfig.PsqlConf.Dbname, dbConfig.PsqlConf.Host, dbConfig.PsqlConf.Port, dbConfig.PsqlConf.SslMode)
Db, dberr = gorm.Open(postgres.Open(dsn), &gorm.Config{})

Db, dberr = gorm.Open(postgres.Open(dsn), &gorm.Config{
Logger: logger.New(
log.New(),
logger.Config{
LogLevel: logger.Warn,
IgnoreRecordNotFoundError: true,
},
)})

if dberr != nil {
log.Error("Error while initializing the database.", dberr)
return dberr
Expand Down Expand Up @@ -117,6 +128,31 @@ func Init() {
}
}

func Close() error {
if Db == nil {
log.Warnln(fmt.Sprintf("Trying to close database connection when no connection is established..."))
return nil
}

DBMux.Lock()
defer DBMux.Unlock()

sqlDb, err := Db.DB()
if err != nil {
log.Errorln(fmt.Sprintf("Error while closing database connection gracefully: %s, attempting to terminate forcefully", err.Error()))
return TerminateDatabaseConnections()
}

err = sqlDb.Close()
if err != nil {
log.Errorln(fmt.Sprintf("Error while closing database connection gracefully: %s, attempting to terminate forcefully", err.Error()))
return TerminateDatabaseConnections()
}

Db = nil
return nil
}

func BackupAndReset() {
LoadDbConfig()

Expand All @@ -131,7 +167,7 @@ func BackupAndReset() {
return
}

backupPath := filepath.Join(core.BEAST_GLOBAL_DIR, "backup", core.BEAST_REMOTES_DIR)
backupPath := filepath.Join(core.BEAST_GLOBAL_DIR, core.BEAST_BACKUP_DIR, core.BEAST_REMOTES_DIR)
err = utils.CreateIfNotExistDir(backupPath)
if err != nil {
log.Errorf("Error while creating backup directory: %s", err)
Expand All @@ -146,7 +182,7 @@ func BackupAndReset() {
return
}

backupPath = filepath.Join(core.BEAST_GLOBAL_DIR, "backup", core.BEAST_STAGING_DIR)
backupPath = filepath.Join(core.BEAST_GLOBAL_DIR, core.BEAST_BACKUP_DIR, core.BEAST_STAGING_DIR)

err = utils.CreateIfNotExistDir(backupPath)
if err != nil {
Expand All @@ -167,17 +203,29 @@ func BackupDatabase() error {
if dbConfig == (Config{}) {
LoadDbConfig()
}

backupPath := filepath.Join(core.BEAST_GLOBAL_DIR, "backup", "db")
backupPath := filepath.Join(core.BEAST_GLOBAL_DIR, core.BEAST_BACKUP_DIR, "db")
err := utils.CreateIfNotExistDir(backupPath)
if err != nil {
log.Errorf("Error while creating backup directory: %s", err)
return err
}

backupFile := fmt.Sprintf("%s_%s.bak", dbConfig.PsqlConf.Dbname, time.Now().Format("20060102150405"))
cmd := exec.Command("pg_dump", "-U", dbConfig.PsqlConf.User, "-h", dbConfig.PsqlConf.Host, "-p", dbConfig.PsqlConf.Port, "-F", "c", "-f", filepath.Join(backupPath, backupFile), dbConfig.PsqlConf.Dbname)
cmd := exec.Command(
"pg_dump",
"-U",
dbConfig.PsqlConf.User,
"-h", dbConfig.PsqlConf.Host,
"-p",
dbConfig.PsqlConf.Port,
"-F",
"c",
"-f",
filepath.Join(backupPath, backupFile),
dbConfig.PsqlConf.Dbname,
)
cmd.Env = append(os.Environ(), fmt.Sprintf("PGPASSWORD=%s", dbConfig.PsqlConf.Password))

output, err := cmd.CombinedOutput()
if err != nil {
log.Printf("Backup error: %s\n", string(output))
Expand Down Expand Up @@ -206,14 +254,23 @@ func ResetDatabase() error {
return err
}

createCmd := exec.Command("psql", "-U", dbConfig.PsqlConf.User, "-h", dbConfig.PsqlConf.Host, "-p", dbConfig.PsqlConf.Port, "-d", "postgres", "-c", "CREATE DATABASE "+dbConfig.PsqlConf.Dbname+";")
createCmd.Env = append(os.Environ(), fmt.Sprintf("PGPASSWORD=%s", dbConfig.PsqlConf.Password))
dsn := fmt.Sprintf("user=%s password=%s dbname=%s host=%s port=%s sslmode=%s", dbConfig.PsqlConf.User, dbConfig.PsqlConf.Password, "postgres", dbConfig.PsqlConf.Host, dbConfig.PsqlConf.Port, "disable")
db, err := sql.Open("pgx", dsn)
if err != nil {
return fmt.Errorf("unable to connect to database: %s", err)
}
defer db.Close()

output, err = createCmd.CombinedOutput()
_, err = db.Exec(fmt.Sprintf("CREATE DATABASE %s", pq.QuoteIdentifier(dbConfig.PsqlConf.Dbname)))
if err != nil {
log.Printf("Create DB error: %s\n", string(output))
return err
return fmt.Errorf("unable to create database: %s", err)
}

_, err = db.Exec(fmt.Sprintf("ALTER DATABASE %s OWNER TO %s", pq.QuoteIdentifier(dbConfig.PsqlConf.Dbname), pq.QuoteIdentifier(dbConfig.PsqlConf.User)))
if err != nil {
return fmt.Errorf("unable to alter database owner: %s", err)
}

log.Debug("Reset successful.")
return nil
}
Expand All @@ -223,24 +280,21 @@ func TerminateDatabaseConnections() error {
if dbConfig == (Config{}) {
LoadDbConfig()
}
terminateCmd := exec.Command(
"psql",
"-U", dbConfig.PsqlConf.User,
"-h", dbConfig.PsqlConf.Host,
"-p", dbConfig.PsqlConf.Port,
"-d", "postgres",
"-c",
fmt.Sprintf("SELECT pg_terminate_backend(pg_stat_activity.pid) FROM pg_stat_activity WHERE datname = '%s' AND pid <> pg_backend_pid();", dbConfig.PsqlConf.Dbname),
)
terminateCmd.Env = append(os.Environ(), fmt.Sprintf("PGPASSWORD=%s", dbConfig.PsqlConf.Password))

output, err := terminateCmd.CombinedOutput()
outputStr := string(output)
dsn := fmt.Sprintf("host=%s port=%s user=%s password=%s dbname=%s sslmode=%s", dbConfig.PsqlConf.Host, dbConfig.PsqlConf.Port, dbConfig.PsqlConf.User, dbConfig.PsqlConf.Password, "postgres", dbConfig.PsqlConf.SslMode)
db, err := sql.Open("pgx", dsn)

if err != nil {
log.Errorf("Terminate connections error: %s\n", outputStr)
return err
}
log.Debug(outputStr)
defer db.Close()

_, err = db.Exec("SELECT pg_terminate_backend(pg_stat_activity.pid) FROM pg_stat_activity WHERE datname = $1 AND pid <> pg_backend_pid();", dbConfig.PsqlConf.Dbname)
if err != nil {
log.Errorf("Terminate connections error: %s\n", err.Error())
return err
}

return nil
}

Expand Down
3 changes: 1 addition & 2 deletions core/database/hints.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,11 +80,10 @@ func UserHasTakenHint(userID, hintID uint) (bool, error) {

var userHint UserHint
if err := tx.Where("user_id = ? AND hint_id = ?", userID, hintID).First(&userHint).Error; err != nil {
tx.Rollback()
if errors.Is(err, gorm.ErrRecordNotFound) {
tx.Rollback()
return false, nil
}
tx.Rollback()
return false, fmt.Errorf("db_error")
}

Expand Down