diff --git a/pkg/gateway/server/llmproxy.go b/pkg/gateway/server/llmproxy.go index ba8460fc70..d4b263b5e7 100644 --- a/pkg/gateway/server/llmproxy.go +++ b/pkg/gateway/server/llmproxy.go @@ -113,19 +113,21 @@ func (s *Server) dispatchLLMProxy(req api.Context) error { personalToken = true } - // Check if the user has permission to use this model - if modelID != "" && token.UserID != "" { - userID, err := strconv.ParseUint(token.UserID, 10, 64) + // Fetch auth provider groups once for all checks that need them. + var authProviderGroups []string + if token.UserID != "" { + userIDInt, err := strconv.ParseUint(token.UserID, 10, 64) if err != nil { return fmt.Errorf("failed to parse user ID: %w", err) } - - // Get the user's auth provider groups - authProviderGroups, err := req.GatewayClient.ListGroupIDsForUser(req.Context(), uint(userID)) + authProviderGroups, err = req.GatewayClient.ListGroupIDsForUser(req.Context(), uint(userIDInt)) if err != nil { return fmt.Errorf("failed to get user groups: %w", err) } + } + // Check if the user has permission to use this model + if modelID != "" && token.UserID != "" { hasAccess, err := s.mapHelper.UserHasAccessToModel(&user.DefaultInfo{ UID: token.UserID, Groups: token.UserGroups, @@ -157,6 +159,9 @@ func (s *Server) dispatchLLMProxy(req api.Context) error { userInfo := &user.DefaultInfo{ UID: token.UserID, Groups: token.UserGroups, + Extra: map[string][]string{ + "auth_provider_groups": authProviderGroups, + }, } outputPolicies, conversationHistory, inputPolicyReplacement, err = applyMessagePolicies( req.Context(), messagePolicyHelper, userInfo, req.GatewayClient, body, token.ProjectID, token.ThreadID,