Skip to content

Commit d9b1625

Browse files
sjrlbogdankostic
authored andcommitted
fix: Fix auto-variadic assignment (#10688)
* Try out idea * Undo some changes * update function to handle Optional[list] as well * Add unit tests for auto-variadic + conversion * add test for optional list * add reno * remove unnecessary tests for now * add another unit test
1 parent df89c7e commit d9b1625

File tree

5 files changed

+140
-36
lines changed

5 files changed

+140
-36
lines changed

haystack/core/component/types.py

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -85,21 +85,18 @@ def __post_init__(self):
8585
except AttributeError:
8686
self.is_lazy_variadic = False
8787
self.is_greedy = False
88+
89+
# We need to "unpack" the type inside the Variadic annotation, otherwise the pipeline connection api will try
90+
# to match `Annotated[type, HAYSTACK_VARIADIC_ANNOTATION]`.
91+
#
92+
# Note1: Variadic is expressed as an annotation of one single type, so the return value of get_args will
93+
# always be a one-item tuple.
94+
#
95+
# Note2: a pipeline always passes a list of items when a component input is declared as Variadic, so the
96+
# type itself always wraps an iterable of the declared type. For example, Variadic[int] is eventually an
97+
# alias for Iterable[int]. Since we're interested in getting the inner type `int`, we call `get_args`
98+
# twice: the first time to get `list[int]` out of `Variadic`, the second time to get `int` out of `list[int]`.
8899
if self.is_lazy_variadic or self.is_greedy:
89-
# We need to "unpack" the type inside the Variadic annotation,
90-
# otherwise the pipeline connection api will try to match
91-
# `Annotated[type, HAYSTACK_VARIADIC_ANNOTATION]`.
92-
#
93-
# Note1: Variadic is expressed as an annotation of one single type,
94-
# so the return value of get_args will always be a one-item tuple.
95-
#
96-
# Note2: a pipeline always passes a list of items when a component
97-
# input is declared as Variadic, so the type itself always wraps
98-
# an iterable of the declared type. For example, Variadic[int]
99-
# is eventually an alias for Iterable[int]. Since we're interested
100-
# in getting the inner type `int`, we call `get_args` twice: the
101-
# first time to get `list[int]` out of `Variadic`, the second time
102-
# to get `int` out of `list[int]`.
103100
self.type = get_args(get_args(self.type)[0])[0]
104101

105102

haystack/core/pipeline/base.py

Lines changed: 37 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from datetime import datetime
1111
from enum import IntEnum
1212
from pathlib import Path
13-
from typing import Any, TextIO, TypeVar
13+
from typing import Any, TextIO, TypeVar, Union, get_args
1414

1515
import networkx
1616

@@ -615,16 +615,7 @@ def connect(self, sender: str, receiver: str) -> "PipelineBase": # noqa: PLR091
615615
return self
616616

617617
if receiver_socket.senders:
618-
# We automatically set the receiver socket as variadic if:
619-
# - it has at least one sender already connected
620-
# - it's not already variadic
621-
# - its origin type is list
622-
if not receiver_socket.is_variadic and _safe_get_origin(receiver_socket.type) == list:
623-
receiver_socket.is_lazy_variadic = True
624-
# We also disable wrapping inputs into list so the sender outputs matches the type of the receiver
625-
# socket.
626-
receiver_socket.wrap_input_in_list = False
627-
618+
receiver_socket = self._make_socket_auto_variadic(receiver_socket=receiver_socket)
628619
if not receiver_socket.is_variadic:
629620
# Only variadic input sockets can receive from multiple senders
630621
msg = (
@@ -949,23 +940,48 @@ def validate_input(self, data: dict[str, Any]) -> None:
949940

950941
# Check if an input is provided more than once for non-variadic sockets
951942
if socket.senders and socket_name in component_inputs:
952-
# We automatically set the receiver socket as lazy variadic if:
953-
# - it has at least one sender already connected
954-
# - it's not already variadic
955-
# - its origin type is list
956-
if not socket.is_variadic and _safe_get_origin(socket.type) == list:
957-
socket.is_lazy_variadic = True
958-
# We also disable wrapping inputs into list so the sender outputs matches the type of the
959-
# receiver socket.
960-
socket.wrap_input_in_list = False
961-
943+
socket = self._make_socket_auto_variadic(receiver_socket=socket)
962944
if not socket.is_variadic:
963945
raise ValueError(
964946
f"Component '{component_name}' cannot accept multiple inputs to '{socket_name}'. "
965947
f"It is already connected to component '{socket.senders[0]}' so it cannot accept "
966948
"additional inputs."
967949
)
968950

951+
def _make_socket_auto_variadic(self, receiver_socket: InputSocket) -> InputSocket:
952+
"""
953+
Checks if the receiver socket can be made lazy variadic and modifies it in-place if that's the case.
954+
955+
We automatically set the receiver socket as lazy variadic if:
956+
- it has at least one sender already connected
957+
- it's not already variadic
958+
- its type is list or Optional[list]
959+
960+
NOTE: We also disable wrapping inputs into list for these auto-variadic sockets, so the sender outputs match the
961+
type of the receiver socket.
962+
963+
:returns:
964+
The potentially modified receiver socket.
965+
"""
966+
# If it's already variadic, we don't change anything
967+
if receiver_socket.is_variadic:
968+
return receiver_socket
969+
970+
origin = _safe_get_origin(receiver_socket.type)
971+
972+
# Unwrap Optional types
973+
if origin == Union:
974+
non_none_args = [a for a in get_args(receiver_socket.type) if a is not type(None)]
975+
if len(non_none_args) == 1:
976+
origin = _safe_get_origin(non_none_args[0])
977+
978+
# If the origin is list, we can make the socket lazy variadic
979+
if origin == list:
980+
receiver_socket.is_lazy_variadic = True
981+
receiver_socket.wrap_input_in_list = False
982+
983+
return receiver_socket
984+
969985
def _prepare_component_input_data(self, data: dict[str, Any]) -> dict[str, dict[str, Any]]:
970986
"""
971987
Prepares input data for pipeline components.
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
---
2+
enhancements:
3+
- |
4+
Auto variadic sockets now also support ``Optional[list[...]]`` input types, in addition to plain ``list[...]``.
5+
fixes:
6+
- |
7+
Fixed smart connection logic to support connecting multiple outputs to a socket whose type is ``Optional[list[...]]`` (e.g. ``list[ChatMessage] | None``).
8+
Previously, connecting two ``list[ChatMessage]`` outputs to ``Agent.messages`` would fail after its type was updated from ``list[ChatMessage]`` to ``list[ChatMessage] | None``.

test/core/pipeline/test_pipeline.py

Lines changed: 59 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,19 @@
77

88
import pytest
99

10+
from haystack.components.agents import Agent
1011
from haystack.components.joiners import BranchJoiner
1112
from haystack.core.component import component
1213
from haystack.core.errors import PipelineRuntimeError
1314
from haystack.core.pipeline import Pipeline
14-
from haystack.dataclasses.document import Document
15+
from haystack.dataclasses import ChatMessage, Document
16+
17+
18+
@component
19+
class MockChatGenerator:
20+
@component.output_types(replies=list[ChatMessage])
21+
def run(self, messages: list[ChatMessage]) -> dict[str, list[ChatMessage]]:
22+
return {"replies": [ChatMessage.from_assistant("Hello, world!")]}
1523

1624

1725
@component
@@ -251,3 +259,53 @@ def run(self, document: Document) -> dict[str, Document]:
251259
# Without deep copying the inputs, the second component would also see the modified document and produce
252260
# "modified" instead of "original"
253261
assert result["second"]["output"].content == "original"
262+
263+
def test_pipeline_does_not_corrupt_outputs(self):
264+
"""
265+
Test that a component's output collected via include_outputs_from is not corrupted when a downstream
266+
component receives and mutates the same data in-place.
267+
"""
268+
269+
@component
270+
class Producer:
271+
@component.output_types(doc=Document)
272+
def run(self) -> dict:
273+
return {"doc": Document(content="original")}
274+
275+
@component
276+
class Mutator:
277+
@component.output_types(doc=Document)
278+
def run(self, doc: Document) -> dict:
279+
doc.content = "mutated"
280+
return {"doc": doc}
281+
282+
pipe = Pipeline()
283+
pipe.add_component("producer", Producer())
284+
pipe.add_component("mutator", Mutator())
285+
pipe.connect("producer.doc", "mutator.doc")
286+
287+
result = pipe.run({}, include_outputs_from={"producer"})
288+
289+
assert result["producer"]["doc"].content == "original"
290+
assert result["mutator"]["doc"].content == "mutated"
291+
292+
def test_auto_variadic_connection_to_agent(self):
293+
@component
294+
class MessageProducer:
295+
@component.output_types(messages=list[ChatMessage])
296+
def run(self) -> dict[str, list[ChatMessage]]:
297+
return {"messages": [ChatMessage.from_user("Hello, world!")]}
298+
299+
p = Pipeline()
300+
p.add_component("message_producer", MessageProducer())
301+
p.add_component("message_producer2", MessageProducer())
302+
p.add_component("agent", Agent(chat_generator=MockChatGenerator()))
303+
p.connect("message_producer", "agent.messages")
304+
p.connect("message_producer2", "agent.messages")
305+
306+
result = p.run({})
307+
assert result["agent"]["messages"] == [
308+
ChatMessage.from_user("Hello, world!"),
309+
ChatMessage.from_user("Hello, world!"),
310+
ChatMessage.from_assistant("Hello, world!"),
311+
]

test/core/pipeline/test_pipeline_base.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1949,6 +1949,31 @@ def run(self, numbers: list[int]) -> dict[str, list[int]]:
19491949
assert receiver.__haystack_input__._sockets_dict == {"numbers": inp_socket} # type: ignore[attr-defined]
19501950
assert receiver.__haystack_input__._sockets_dict["numbers"].senders == ["sender1", "sender2"] # type: ignore[attr-defined]
19511951

1952+
def test_connect_auto_variadic_optional_list(self):
1953+
@component
1954+
class ListAcceptor:
1955+
@component.output_types(result=list[int])
1956+
def run(self, numbers: list[int] | None = None) -> dict[str, list[int]]:
1957+
return {"result": numbers or []}
1958+
1959+
pipeline = PipelineBase()
1960+
receiver = ListAcceptor()
1961+
pipeline.add_component("sender1", ListAcceptor())
1962+
pipeline.add_component("sender2", ListAcceptor())
1963+
pipeline.add_component("receiver", receiver)
1964+
1965+
pipeline.connect("sender1.result", "receiver.numbers")
1966+
pipeline.connect("sender2.result", "receiver.numbers")
1967+
1968+
# Check that the receiver's input socket is correctly set to lazy variadic with wrap_input_in_list=False
1969+
inp_socket = InputSocket(
1970+
name="numbers", type=list[int] | None, senders=["sender1", "sender2"], default_value=None
1971+
)
1972+
inp_socket.is_lazy_variadic = True
1973+
inp_socket.wrap_input_in_list = False
1974+
assert receiver.__haystack_input__._sockets_dict == {"numbers": inp_socket} # type: ignore[attr-defined]
1975+
assert receiver.__haystack_input__._sockets_dict["numbers"].senders == ["sender1", "sender2"] # type: ignore[attr-defined]
1976+
19521977
def test_connect_with_conversion_chat_message_to_str(self):
19531978
@component
19541979
class ChatMessageOutput:

0 commit comments

Comments
 (0)