Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions adserver/api/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,11 @@ class AdDecisionSerializer(serializers.Serializer):
user_ip = serializers.CharField(required=False)
user_ua = serializers.CharField(required=False)

# Chat/AI prompt text used for embedding-based ad targeting
# When provided, the ad server generates an embedding from this text
# to match against advertiser content for niche targeting
prompt = serializers.CharField(max_length=8000, required=False)

# Used to specify a specific ad or campaign to show (used for debugging mostly)
force_ad = serializers.CharField(required=False) # slug
force_campaign = serializers.CharField(required=False) # slug
Expand Down
2 changes: 2 additions & 0 deletions adserver/api/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,8 @@ def decision(self, request, data):
campaign_types=campaign_types,
url=url,
placement_index=serializer.validated_data.get("placement_index"),
# Prompt text for embedding-based targeting
prompt=serializer.validated_data.get("prompt"),
# Debugging parameters
ad_slug=serializer.validated_data.get("force_ad"),
campaign_slug=serializer.validated_data.get("force_campaign"),
Expand Down
Empty file added adserver/chatdemo/__init__.py
Empty file.
139 changes: 139 additions & 0 deletions adserver/chatdemo/embedding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
"""Embedding utilities for generating ad-targeting embeddings from chat prompts."""

import hashlib
import logging

from django.conf import settings
from django.core.cache import cache


log = logging.getLogger(__name__)

# Cache embeddings for 1 hour to avoid redundant API calls
EMBEDDING_CACHE_TIMEOUT = 60 * 60

# The OpenAI embedding model to use - text-embedding-3-small is cheap and fast
EMBEDDING_MODEL = "text-embedding-3-small"


def get_prompt_embedding(prompt_text):
    """
    Return an OpenAI embedding vector (list of floats) for *prompt_text*.

    Results are cached for EMBEDDING_CACHE_TIMEOUT seconds, keyed on a
    normalized hash of the text. Returns ``None`` when the text is blank,
    when OPENAI_API_KEY is not configured, or when the API call fails.
    """
    if not prompt_text or not prompt_text.strip():
        return None

    api_key = getattr(settings, "OPENAI_API_KEY", None)
    if not api_key:
        log.warning("OPENAI_API_KEY not configured, cannot generate prompt embedding")
        return None

    # Serve from cache when this exact (normalized) text was embedded recently
    cache_key = _embedding_cache_key(prompt_text)
    cached_vector = cache.get(cache_key)
    if cached_vector is not None:
        return cached_vector

    try:
        import openai

        client = openai.OpenAI(api_key=api_key)
        response = client.embeddings.create(
            model=EMBEDDING_MODEL,
            # Truncate so the request stays within the model's input limits
            input=prompt_text.strip()[:8000],
        )
        vector = response.data[0].embedding
        cache.set(cache_key, vector, EMBEDDING_CACHE_TIMEOUT)
        return vector
    except Exception:
        # Best-effort: embedding failures must never break ad serving
        log.exception("Failed to generate embedding for prompt")
        return None


def cosine_similarity(vec_a, vec_b):
    """
    Compute the cosine similarity between two equal-length numeric vectors.

    Returns 0.0 for empty or mismatched-length inputs, or when either
    vector has zero magnitude, so callers never hit a division by zero.
    """
    if not vec_a or not vec_b or len(vec_a) != len(vec_b):
        return 0.0

    # Single pass: accumulate the dot product and both squared magnitudes
    # together instead of iterating the vectors three separate times.
    dot_product = 0.0
    sq_a = 0.0
    sq_b = 0.0
    for a, b in zip(vec_a, vec_b):
        dot_product += a * b
        sq_a += a * a
        sq_b += b * b

    if sq_a == 0 or sq_b == 0:
        return 0.0

    return dot_product / ((sq_a ** 0.5) * (sq_b ** 0.5))


def get_prompt_niche_weights(prompt_text, flights):
    """
    Compute niche targeting weights for flights based on prompt embedding similarity.

    This mirrors the ethicalads_ext.embedding.utils.get_niche_weights interface
    but works directly from prompt text instead of URL content.

    Returns a dict mapping Advertiser -> distance (lower = more similar).
    """
    prompt_embedding = get_prompt_embedding(prompt_text)
    if not prompt_embedding:
        return {}

    distances = {}

    for flight in flights:
        advertiser = flight.campaign.advertiser

        # Only compute one distance per advertiser, even across multiple flights
        if advertiser in distances:
            continue

        ad_embedding = _get_advertiser_embedding(flight)
        if not ad_embedding:
            continue

        # Convert similarity to distance (lower = better match) because the
        # niche_targeting threshold compares distance < goal
        similarity = cosine_similarity(prompt_embedding, ad_embedding)
        distances[advertiser] = 1.0 - similarity

    return distances


def _get_advertiser_embedding(flight):
    """
    Get or generate an embedding for a flight's advertiser content.

    Builds a text representation from up to 5 of the flight's live ads
    (headline/content/CTA, falling back to the plain text field) and
    embeds the combined string. Returns ``None`` when no ad text exists.
    """
    ad_texts = []

    for ad in flight.advertisements.filter(live=True)[:5]:
        pieces = [field for field in (ad.headline, ad.content, ad.cta) if field]
        if not pieces and ad.text:
            # Fall back to the plain text body when the structured fields are empty
            pieces = [ad.text]
        if pieces:
            ad_texts.append(" ".join(pieces))

    if not ad_texts:
        return None

    return get_prompt_embedding(" | ".join(ad_texts))


def _embedding_cache_key(text):
"""Generate a cache key for an embedding."""
text_hash = hashlib.md5(text.strip().lower().encode()).hexdigest()
return f"prompt-embedding-{text_hash}"
14 changes: 14 additions & 0 deletions adserver/chatdemo/urls.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
"""URL configuration for the chat demo."""

from django.urls import path

from .views import ChatCompletionProxyView
from .views import ChatDemoView


app_name = "chatdemo"

urlpatterns = [
path("", ChatDemoView.as_view(), name="chat-demo"),
path("completion/", ChatCompletionProxyView.as_view(), name="chat-completion"),
]
88 changes: 88 additions & 0 deletions adserver/chatdemo/views.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
"""Views for the AI chat demo with ethical ad targeting."""

import json
import logging

from django.conf import settings
from django.http import JsonResponse
from django.utils.decorators import method_decorator
from django.views import View
from django.views.decorators.csrf import csrf_exempt
from django.views.generic import TemplateView


log = logging.getLogger(__name__)


class ChatDemoView(TemplateView):
    """Render the chat demo HTML page."""

    template_name = "adserver/chatdemo/chat.html"

    def get_context_data(self, **kwargs):
        """Expose the configured demo publisher slug to the template."""
        context = super().get_context_data(**kwargs)
        publisher_slug = getattr(settings, "ADSERVER_CHAT_DEMO_PUBLISHER", "")
        context["publisher_slug"] = publisher_slug
        return context


@method_decorator(csrf_exempt, name="dispatch")
class ChatCompletionProxyView(View):
"""
Proxy chat completion requests to OpenAI.

This keeps the OpenAI API key on the server side
and uses a cheap model (gpt-4o-mini) for completions.
"""

OPENAI_MODEL = "gpt-4o-mini"

def post(self, request):
api_key = getattr(settings, "OPENAI_API_KEY", None)
if not api_key:
return JsonResponse(
{"error": "OpenAI API key not configured on the server"},
status=500,
)

try:
body = json.loads(request.body)
except (json.JSONDecodeError, ValueError):
return JsonResponse({"error": "Invalid JSON"}, status=400)

messages = body.get("messages", [])
if not messages:
return JsonResponse({"error": "No messages provided"}, status=400)

# Limit conversation length to prevent abuse
messages = messages[:50]

try:
import openai

client = openai.OpenAI(api_key=api_key)
response = client.chat.completions.create(
model=self.OPENAI_MODEL,
messages=messages,
max_tokens=1024,
temperature=0.7,
)

return JsonResponse(
{
"content": response.choices[0].message.content,
"model": response.model,
"usage": {
"prompt_tokens": response.usage.prompt_tokens,
"completion_tokens": response.usage.completion_tokens,
},
}
)

except Exception:
log.exception("OpenAI chat completion failed")
return JsonResponse(
{"error": "Chat completion request failed"},
status=502,
)
37 changes: 25 additions & 12 deletions adserver/decisionengine/backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,9 @@ def __init__(self, request, placements, publisher, **kwargs):
self.ad_slug = kwargs.get("ad_slug")
self.campaign_slug = kwargs.get("campaign_slug")

# Chat/AI prompt text for embedding-based targeting
self.prompt = kwargs.get("prompt") or ""

self.niche_weights = None

def get_analyzer_keywords(self):
Expand Down Expand Up @@ -420,19 +423,29 @@ def select_flight(self):

# Apply niche targeting only when any flight has it.
# This is to track whether we should do expensive distance queries.
if (
flights_with_niche_targeting
and "ethicalads_ext.embedding" in settings.INSTALLED_APPS
):
# We have to do this here,
# so we can filter by the weight in the filter_flight call below
from ethicalads_ext.embedding.utils import get_niche_weights # noqa
if flights_with_niche_targeting:
if self.prompt:
# Use prompt-based embedding for niche targeting
# This enables AI chat contexts to target ads
from ..chatdemo.embedding import get_prompt_niche_weights

self.niche_weights = get_prompt_niche_weights(
self.prompt, flights_with_niche_targeting
)
if self.niche_weights:
log.debug(
"Prompt niche targeting weights: %s",
self.niche_weights,
)
elif "ethicalads_ext.embedding" in settings.INSTALLED_APPS:
# Fall back to URL-based niche targeting
from ethicalads_ext.embedding.utils import get_niche_weights # noqa

self.niche_weights = get_niche_weights(
url=self.url, flights=flights_with_niche_targeting
)
if self.niche_weights:
log.debug("Niche targeting weights: %s", self.niche_weights)
self.niche_weights = get_niche_weights(
url=self.url, flights=flights_with_niche_targeting
)
if self.niche_weights:
log.debug("Niche targeting weights: %s", self.niche_weights)

for flight in possible_flights:
# Handle excluding flights based on targeting
Expand Down
32 changes: 32 additions & 0 deletions adserver/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -1484,3 +1484,35 @@ def run_publisher_importers():
"""
# PSF is the only importer for now..
psf.run_import(sync=True)


@app.task()
def publish_celery_queue_depth():
    """Publish Celery queue depth to CloudWatch for autoscaling the celery ASG."""
    import boto3
    import redis

    broker_url = settings.CELERY_BROKER_URL
    monitored_queues = ("celery", "analyzer", "priority")

    # Read the length of every monitored Redis-backed queue; bail out
    # quietly if the broker is unreachable rather than failing the task.
    try:
        connection = redis.Redis.from_url(
            broker_url, socket_timeout=5, socket_connect_timeout=5
        )
        total = sum(connection.llen(queue) for queue in monitored_queues)
    except Exception:
        log.exception("Failed to read Celery queue depths from Redis")
        return

    # Emit one aggregate metric; the autoscaling policy consumes this value.
    try:
        cloudwatch = boto3.client("cloudwatch", region_name=settings.AWS_S3_REGION_NAME)
        cloudwatch.put_metric_data(
            Namespace="EthicalAds/Celery",
            MetricData=[
                {
                    "MetricName": "QueueDepth",
                    "Value": total,
                    "Unit": "Count",
                }
            ],
        )
    except Exception:
        log.exception("Failed to publish queue depth to CloudWatch")
Loading
Loading