Skip to content

Commit 7b4dd41

Browse files
Check for pool max runners in CreateInstance tx
This change moves the check for max runners within the CreateInstance function, which will check that the pool max runners is not yet reached within a transaction before creating a new instance. Signed-off-by: Gabriel Adrian Samfira <gsamfira@cloudbasesolutions.com>
1 parent 26dfb7a commit 7b4dd41

File tree

6 files changed

+95
-51
lines changed

6 files changed

+95
-51
lines changed

config/config.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -569,7 +569,7 @@ func (s *SQLite) Validate() error {
569569
}
570570

571571
func (s *SQLite) ConnectionString() (string, error) {
572-
connectionString := fmt.Sprintf("%s?_journal_mode=WAL&_foreign_keys=ON", s.DBFile)
572+
connectionString := fmt.Sprintf("%s?_journal_mode=WAL&_foreign_keys=ON&_txlock=immediate", s.DBFile)
573573
if s.BusyTimeoutSeconds > 0 {
574574
timeout := s.BusyTimeoutSeconds * 1000
575575
connectionString = fmt.Sprintf("%s&_busy_timeout=%d", connectionString, timeout)

config/config_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -387,13 +387,13 @@ func TestGormParams(t *testing.T) {
387387
dbType, uri, err := cfg.GormParams()
388388
require.Nil(t, err)
389389
require.Equal(t, SQLiteBackend, dbType)
390-
require.Equal(t, filepath.Join(dir, "garm.db?_journal_mode=WAL&_foreign_keys=ON"), uri)
390+
require.Equal(t, filepath.Join(dir, "garm.db?_journal_mode=WAL&_foreign_keys=ON&_txlock=immediate"), uri)
391391

392392
cfg.SQLite.BusyTimeoutSeconds = 5
393393
dbType, uri, err = cfg.GormParams()
394394
require.Nil(t, err)
395395
require.Equal(t, SQLiteBackend, dbType)
396-
require.Equal(t, filepath.Join(dir, "garm.db?_journal_mode=WAL&_foreign_keys=ON&_busy_timeout=5000"), uri)
396+
require.Equal(t, filepath.Join(dir, "garm.db?_journal_mode=WAL&_foreign_keys=ON&_txlock=immediate&_busy_timeout=5000"), uri)
397397

398398
cfg.DbBackend = MySQLBackend
399399
cfg.MySQL = getMySQLDefaultConfig()

database/sql/instances.go

Lines changed: 56 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@ package sql
1717
import (
1818
"context"
1919
"encoding/json"
20+
"fmt"
2021
"log/slog"
22+
"math"
2123

2224
"github.com/google/uuid"
2325
"github.com/pkg/errors"
@@ -30,57 +32,76 @@ import (
3032
"github.com/cloudbase/garm/params"
3133
)
3234

33-
func (s *sqlDatabase) CreateInstance(_ context.Context, poolID string, param params.CreateInstanceParams) (instance params.Instance, err error) {
35+
func (s *sqlDatabase) CreateInstance(ctx context.Context, poolID string, param params.CreateInstanceParams) (instance params.Instance, err error) {
3436
s.writeMux.Lock()
3537
defer s.writeMux.Unlock()
3638

37-
pool, err := s.getPoolByID(s.conn, poolID)
38-
if err != nil {
39-
return params.Instance{}, errors.Wrap(err, "fetching pool")
40-
}
41-
4239
defer func() {
4340
if err == nil {
4441
s.sendNotify(common.InstanceEntityType, common.CreateOperation, instance)
4542
}
4643
}()
4744

48-
var labels datatypes.JSON
49-
if len(param.AditionalLabels) > 0 {
50-
labels, err = json.Marshal(param.AditionalLabels)
45+
err = s.conn.Transaction(func(tx *gorm.DB) error {
46+
pool, err := s.getPoolByID(tx, poolID)
5147
if err != nil {
52-
return params.Instance{}, errors.Wrap(err, "marshalling labels")
48+
return errors.Wrap(err, "fetching pool")
49+
}
50+
var cnt int64
51+
q := tx.Model(&Instance{}).Where("pool_id = ?", pool.ID).Count(&cnt)
52+
if q.Error != nil {
53+
return fmt.Errorf("error fetching instance count: %w", q.Error)
5354
}
54-
}
5555

56-
var secret []byte
57-
if len(param.JitConfiguration) > 0 {
58-
secret, err = s.marshalAndSeal(param.JitConfiguration)
59-
if err != nil {
60-
return params.Instance{}, errors.Wrap(err, "marshalling jit config")
56+
var maxRunners int64
57+
if pool.MaxRunners > math.MaxInt64 {
58+
maxRunners = math.MaxInt64
59+
} else {
60+
maxRunners = int64(pool.MaxRunners)
61+
}
62+
if cnt >= maxRunners {
63+
return runnerErrors.NewConflictError("max runners reached for pool %s", pool.ID)
64+
}
65+
var labels datatypes.JSON
66+
if len(param.AditionalLabels) > 0 {
67+
labels, err = json.Marshal(param.AditionalLabels)
68+
if err != nil {
69+
return errors.Wrap(err, "marshalling labels")
70+
}
6171
}
62-
}
6372

64-
newInstance := Instance{
65-
Pool: pool,
66-
Name: param.Name,
67-
Status: param.Status,
68-
RunnerStatus: param.RunnerStatus,
69-
OSType: param.OSType,
70-
OSArch: param.OSArch,
71-
CallbackURL: param.CallbackURL,
72-
MetadataURL: param.MetadataURL,
73-
GitHubRunnerGroup: param.GitHubRunnerGroup,
74-
JitConfiguration: secret,
75-
AditionalLabels: labels,
76-
AgentID: param.AgentID,
77-
}
78-
q := s.conn.Create(&newInstance)
79-
if q.Error != nil {
80-
return params.Instance{}, errors.Wrap(q.Error, "creating instance")
73+
var secret []byte
74+
if len(param.JitConfiguration) > 0 {
75+
secret, err = s.marshalAndSeal(param.JitConfiguration)
76+
if err != nil {
77+
return errors.Wrap(err, "marshalling jit config")
78+
}
79+
}
80+
newInstance := Instance{
81+
Pool: pool,
82+
Name: param.Name,
83+
Status: param.Status,
84+
RunnerStatus: param.RunnerStatus,
85+
OSType: param.OSType,
86+
OSArch: param.OSArch,
87+
CallbackURL: param.CallbackURL,
88+
MetadataURL: param.MetadataURL,
89+
GitHubRunnerGroup: param.GitHubRunnerGroup,
90+
JitConfiguration: secret,
91+
AditionalLabels: labels,
92+
AgentID: param.AgentID,
93+
}
94+
q = tx.Create(&newInstance)
95+
if q.Error != nil {
96+
return errors.Wrap(q.Error, "creating instance")
97+
}
98+
return nil
99+
})
100+
if err != nil {
101+
return params.Instance{}, errors.Wrap(err, "creating instance")
81102
}
82103

83-
return s.sqlToParamsInstance(newInstance)
104+
return s.GetInstanceByName(ctx, param.Name)
84105
}
85106

86107
func (s *sqlDatabase) getPoolInstanceByName(poolID string, instanceName string) (Instance, error) {

database/sql/instances_test.go

Lines changed: 35 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -204,17 +204,48 @@ func (s *InstancesTestSuite) TestCreateInstance() {
204204
func (s *InstancesTestSuite) TestCreateInstanceInvalidPoolID() {
205205
_, err := s.Store.CreateInstance(s.adminCtx, "dummy-pool-id", params.CreateInstanceParams{})
206206

207-
s.Require().Equal("fetching pool: parsing id: invalid request", err.Error())
207+
s.Require().Equal("creating instance: fetching pool: parsing id: invalid request", err.Error())
208+
}
209+
210+
func (s *InstancesTestSuite) TestCreateInstanceMaxRunnersReached() {
211+
// Pool has MaxRunners=4 and already has 3 instances
212+
// Create one more to reach the limit
213+
_, err := s.Store.CreateInstance(s.adminCtx, s.Fixtures.Pool.ID, params.CreateInstanceParams{
214+
Name: "test-instance-4",
215+
OSType: "linux",
216+
OSArch: "amd64",
217+
CallbackURL: "https://garm.example.com/",
218+
Status: commonParams.InstanceRunning,
219+
RunnerStatus: params.RunnerIdle,
220+
})
221+
s.Require().Nil(err)
222+
223+
// Now try to create a 5th instance, which should fail
224+
_, err = s.Store.CreateInstance(s.adminCtx, s.Fixtures.Pool.ID, params.CreateInstanceParams{
225+
Name: "test-instance-5",
226+
OSType: "linux",
227+
OSArch: "amd64",
228+
CallbackURL: "https://garm.example.com/",
229+
Status: commonParams.InstanceRunning,
230+
RunnerStatus: params.RunnerIdle,
231+
})
232+
233+
s.Require().NotNil(err)
234+
s.Require().Contains(err.Error(), "max runners reached")
208235
}
209236

210237
func (s *InstancesTestSuite) TestCreateInstanceDBCreateErr() {
211238
pool := s.Fixtures.Pool
212239

240+
s.Fixtures.SQLMock.ExpectBegin()
213241
s.Fixtures.SQLMock.
214242
ExpectQuery(regexp.QuoteMeta("SELECT * FROM `pools` WHERE id = ? AND `pools`.`deleted_at` IS NULL ORDER BY `pools`.`id` LIMIT ?")).
215243
WithArgs(pool.ID, 1).
216-
WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(pool.ID))
217-
s.Fixtures.SQLMock.ExpectBegin()
244+
WillReturnRows(sqlmock.NewRows([]string{"id", "max_runners"}).AddRow(pool.ID, pool.MaxRunners))
245+
s.Fixtures.SQLMock.
246+
ExpectQuery(regexp.QuoteMeta("SELECT count(*) FROM `instances` WHERE pool_id = ? AND `instances`.`deleted_at` IS NULL")).
247+
WithArgs(pool.ID).
248+
WillReturnRows(sqlmock.NewRows([]string{"count(*)"}).AddRow(0))
218249
s.Fixtures.SQLMock.
219250
ExpectExec("INSERT INTO `pools`").
220251
WillReturnResult(sqlmock.NewResult(1, 1))
@@ -227,7 +258,7 @@ func (s *InstancesTestSuite) TestCreateInstanceDBCreateErr() {
227258

228259
s.assertSQLMockExpectations()
229260
s.Require().NotNil(err)
230-
s.Require().Equal("creating instance: mocked insert instance error", err.Error())
261+
s.Require().Equal("creating instance: creating instance: mocked insert instance error", err.Error())
231262
}
232263

233264
func (s *InstancesTestSuite) TestGetPoolInstanceByName() {

database/watcher/watcher_store_test.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,7 @@ func (s *WatcherStoreTestSuite) TestInstanceWatcher() {
167167
ProviderName: "test-provider",
168168
Image: "test-image",
169169
Flavor: "test-flavor",
170+
MaxRunners: 100,
170171
OSType: commonParams.Linux,
171172
OSArch: commonParams.Amd64,
172173
Tags: []string{"test-tag"},

runner/pool/pool.go

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1118,15 +1118,6 @@ func (r *basePoolManager) addRunnerToPool(pool params.Pool, aditionalLabels []st
11181118
return fmt.Errorf("pool %s is disabled", pool.ID)
11191119
}
11201120

1121-
poolInstanceCount, err := r.store.PoolInstanceCount(r.ctx, pool.ID)
1122-
if err != nil {
1123-
return fmt.Errorf("failed to list pool instances: %w", err)
1124-
}
1125-
1126-
if poolInstanceCount >= int64(pool.MaxRunners) {
1127-
return fmt.Errorf("max workers (%d) reached for pool %s", pool.MaxRunners, pool.ID)
1128-
}
1129-
11301121
if err := r.AddRunner(r.ctx, pool.ID, aditionalLabels); err != nil {
11311122
return fmt.Errorf("failed to add new instance for pool %s: %s", pool.ID, err)
11321123
}

0 commit comments

Comments
 (0)