Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 100 additions & 32 deletions go/internal/feast/onlinestore/cassandraonlinestore.go
Original file line number Diff line number Diff line change
Expand Up @@ -214,8 +214,8 @@ func extractCassandraConfig(onlineStoreConfig map[string]any) (*CassandraConfig,
readBatchSize = legacyBatchSize
log.Warn().Msg("key_batch_size is deprecated, please use read_batch_size instead")
} else {
readBatchSize = 100.0
log.Warn().Msg("read_batch_size not specified, defaulting to batches of size 100")
readBatchSize = 10.0
log.Warn().Msg("read_batch_size not specified, defaulting to batches of size 10")
}
}
cassandraConfig.readBatchSize = int(readBatchSize.(float64))
Expand Down Expand Up @@ -342,25 +342,23 @@ func (c *CassandraOnlineStore) getFqTableName(keySpace string, project string, f
return fmt.Sprintf(`"%s"."%s"`, keySpace, dbTableName), nil
}

func (c *CassandraOnlineStore) getSingleKeyCQLStatement(tableName string, featureNames []string) string {
// this prevents fetching unnecessary features
quotedFeatureNames := make([]string, len(featureNames))
for i, featureName := range featureNames {
quotedFeatureNames[i] = fmt.Sprintf(`'%s'`, featureName)
func (c *CassandraOnlineStore) getSingleKeyCQLStatement(tableName string, numFeatures int) string {
featurePlaceholders := make([]string, numFeatures)
for i := 0; i < numFeatures; i++ {
featurePlaceholders[i] = "?"
}

return fmt.Sprintf(
`SELECT "entity_key", "feature_name", "event_ts", "value" FROM %s WHERE "entity_key" = ? AND "feature_name" IN (%s)`,
tableName,
strings.Join(quotedFeatureNames, ","),
strings.Join(featurePlaceholders, ","),
)
}

func (c *CassandraOnlineStore) getMultiKeyCQLStatement(tableName string, featureNames []string, nkeys int) string {
// this prevents fetching unnecessary features
quotedFeatureNames := make([]string, len(featureNames))
for i, featureName := range featureNames {
quotedFeatureNames[i] = fmt.Sprintf(`'%s'`, featureName)
func (c *CassandraOnlineStore) getMultiKeyCQLStatement(tableName string, numFeatures int, nkeys int) string {
featurePlaceholders := make([]string, numFeatures)
for i := 0; i < numFeatures; i++ {
featurePlaceholders[i] = "?"
}

keyPlaceholders := make([]string, nkeys)
Expand All @@ -371,7 +369,7 @@ func (c *CassandraOnlineStore) getMultiKeyCQLStatement(tableName string, feature
`SELECT "entity_key", "feature_name", "event_ts", "value" FROM %s WHERE "entity_key" IN (%s) AND "feature_name" IN (%s)`,
tableName,
strings.Join(keyPlaceholders, ","),
strings.Join(quotedFeatureNames, ","),
strings.Join(featurePlaceholders, ","),
)
}

Expand Down Expand Up @@ -425,6 +423,15 @@ type BatchJob struct {
CQLStatement string
}

func buildQueryParams(entityKeys []any, featureNames []string) []any {
params := make([]any, 0, len(entityKeys)+len(featureNames))
params = append(params, entityKeys...)
for _, fn := range featureNames {
params = append(params, fn)
}
return params
}

func (c *CassandraOnlineStore) OnlineReadV2(ctx context.Context, entityKeys []*types.EntityKey, featureViewNames []string, featureNames []string) ([][]FeatureData, error) {
serializedEntityKeys, serializedEntityKeyToIndex, err := c.buildCassandraEntityKeys(entityKeys)
if err != nil {
Expand All @@ -443,38 +450,97 @@ func (c *CassandraOnlineStore) OnlineReadV2(ctx context.Context, entityKeys []*t
return nil, err
}

var cqlForBatch string
cqlForBatch = c.getMultiKeyCQLStatement(tableName, featureNames, len(serializedEntityKeys))

job := BatchJob{
ViewName: featureViewName,
TableName: tableName,
FeatureNames: featureNames,
EntityKeys: serializedEntityKeys,
CQLStatement: cqlForBatch,
results := make([][]FeatureData, len(entityKeys))
for i := range results {
results[i] = make([]FeatureData, len(featureNames))
}

results, err := c.executeBatchV2(ctx, job, serializedEntityKeyToIndex, featureNamesToIdx)
batches := c.createBatches(serializedEntityKeys)

if err != nil {
g, ctx := errgroup.WithContext(ctx)
var mu sync.Mutex

var prevBatchLength int
var cqlStatement string

for i, batch := range batches {
var cqlForBatch string
if i == 0 || len(batch) != prevBatchLength {
cqlForBatch = c.getMultiKeyCQLStatement(tableName, len(featureNames), len(batch))
prevBatchLength = len(batch)
cqlStatement = cqlForBatch
} else {
cqlForBatch = cqlStatement
}

job := BatchJob{
ViewName: featureViewName,
TableName: tableName,
FeatureNames: featureNames,
EntityKeys: batch,
CQLStatement: cqlForBatch,
}

g.Go(func() error {
batchResults, err := c.executeBatchV2(ctx, job, featureNamesToIdx)
if err != nil {
return err
}

mu.Lock()
defer mu.Unlock()
for localIdx, key := range job.EntityKeys {
globalIdx := serializedEntityKeyToIndex[key.(string)]
for featIdx := range featureNames {
results[globalIdx][featIdx] = batchResults[localIdx][featIdx]
}
}
return nil
})
}

if err := g.Wait(); err != nil {
return nil, err
}

for i := range results {
for j, feat := range results[i] {
if feat.Value.Val == nil {
results[i][j] = FeatureData{
Reference: serving.FeatureReferenceV2{
FeatureViewName: featureViewName,
FeatureName: featureNames[j],
},
Value: types.Value{
Val: &types.Value_NullVal{
NullVal: types.Null_NULL,
},
},
}
}
}
}

return results, nil
}

func (c *CassandraOnlineStore) executeBatchV2(
ctx context.Context,
job BatchJob,
serializedEntityKeyToIndex map[string]int,
featureNamesToIdx map[string]int,
) ([][]FeatureData, error) {
localKeyToIndex := make(map[string]int, len(job.EntityKeys))
for i, key := range job.EntityKeys {
localKeyToIndex[key.(string)] = i
}

results := make([][]FeatureData, len(job.EntityKeys))
for i := range results {
results[i] = make([]FeatureData, len(job.FeatureNames))
}

iter := c.session.Query(job.CQLStatement, job.EntityKeys...).WithContext(ctx).Iter()
queryParams := buildQueryParams(job.EntityKeys, job.FeatureNames)
iter := c.session.Query(job.CQLStatement, queryParams...).WithContext(ctx).Iter()
defer iter.Close()

scanner := iter.Scanner()
Expand Down Expand Up @@ -519,9 +585,10 @@ func (c *CassandraOnlineStore) executeBatchV2(
for _, serializedEntityKey := range job.EntityKeys {
for _, featName := range job.FeatureNames {
keyString := serializedEntityKey.(string)
localIdx := localKeyToIndex[keyString]

if featureData, exists := batchFeatures[keyString][featName]; exists {
results[serializedEntityKeyToIndex[keyString]][featureNamesToIdx[featName]] = FeatureData{
results[localIdx][featureNamesToIdx[featName]] = FeatureData{
Reference: serving.FeatureReferenceV2{
FeatureViewName: featureData.Reference.FeatureViewName,
FeatureName: featureData.Reference.FeatureName,
Expand All @@ -532,8 +599,7 @@ func (c *CassandraOnlineStore) executeBatchV2(
},
}
} else {
// TODO: return not found status to differentiate between nulls and not found features
results[serializedEntityKeyToIndex[keyString]][featureNamesToIdx[featName]] = FeatureData{
results[localIdx][featureNamesToIdx[featName]] = FeatureData{
Reference: serving.FeatureReferenceV2{
FeatureViewName: job.ViewName,
FeatureName: featName,
Expand Down Expand Up @@ -592,7 +658,7 @@ func (c *CassandraOnlineStore) OnlineRead(ctx context.Context, entityKeys []*typ
for i, batch := range batches {
var cqlForBatch string
if i == 0 || len(batch) != prevBatchLength {
cqlForBatch = c.getMultiKeyCQLStatement(tableName, currentFeatureNames, len(batch))
cqlForBatch = c.getMultiKeyCQLStatement(tableName, len(currentFeatureNames), len(batch))
prevBatchLength = len(batch)
cqlStatement = cqlForBatch
} else {
Expand Down Expand Up @@ -640,7 +706,8 @@ func (c *CassandraOnlineStore) executeBatch(
results [][]FeatureData,
featureNamesToIdx map[string]int,
) error {
iter := c.session.Query(job.CQLStatement, job.EntityKeys...).WithContext(ctx).Iter()
queryParams := buildQueryParams(job.EntityKeys, job.FeatureNames)
iter := c.session.Query(job.CQLStatement, queryParams...).WithContext(ctx).Iter()
defer iter.Close()

scanner := iter.Scanner()
Expand Down Expand Up @@ -1038,3 +1105,4 @@ func (c *CassandraOnlineStore) GetDataModelType() OnlineStoreDataModel {
func (c *CassandraOnlineStore) GetReadBatchSize() int {
return c.KeyBatchSize
}

Loading