Skip to content

Commit 9348771

Browse files
authored
feat: Pipelines - support connection and conversion between ChatMessage and str (#10507)
* draft * optimizations * more * simplify * better types * better types * license + improvs * test + improvements + types * revert unneeded changes * fix lint * remove 3.10 guards * relnote * fix * improvements from feedback * revert wrong change + more tests
1 parent de2da07 commit 9348771

File tree

5 files changed

+455
-158
lines changed

5 files changed

+455
-158
lines changed

haystack/core/pipeline/base.py

Lines changed: 65 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
PipelineDrawingError,
2222
PipelineError,
2323
PipelineMaxComponentRuns,
24+
PipelineRuntimeError,
2425
PipelineUnmarshalError,
2526
PipelineValidationError,
2627
)
@@ -39,7 +40,7 @@
3940
component_to_dict,
4041
generate_qualified_class_name,
4142
)
42-
from haystack.core.type_utils import _safe_get_origin, _type_name, _types_are_compatible
43+
from haystack.core.type_utils import _convert_value, _safe_get_origin, _type_name, _types_are_compatible
4344
from haystack.marshal import Marshaller, YamlMarshaller
4445
from haystack.utils import is_in_jupyter, type_serialization
4546

@@ -499,11 +500,25 @@ def connect(self, sender: str, receiver: str) -> "PipelineBase": # noqa: PLR091
499500
[receiver_socket] if receiver_socket else list(receiver_sockets.values())
500501
)
501502

503+
is_strict_match = True
504+
502505
# Find all possible connections between these two components
503-
possible_connections = []
506+
possible_connections: list[tuple[OutputSocket, InputSocket, bool]] = []
504507
for sender_sock, receiver_sock in itertools.product(sender_socket_candidates, receiver_socket_candidates):
505-
if _types_are_compatible(sender_sock.type, receiver_sock.type, self._connection_type_validation):
506-
possible_connections.append((sender_sock, receiver_sock))
508+
is_compat, is_strict = _types_are_compatible(
509+
sender_sock.type, receiver_sock.type, self._connection_type_validation
510+
)
511+
if is_compat:
512+
possible_connections.append((sender_sock, receiver_sock, is_strict))
513+
514+
# If there are multiple possibilities, prioritize strict matches over convertible ones.
515+
# This ensures backward compatibility: previously, pipelines did not allow type conversion.
516+
if len(possible_connections) > 1 and self._connection_type_validation:
517+
strict_matches = [
518+
(out_sock, in_sock, is_strict) for out_sock, in_sock, is_strict in possible_connections if is_strict
519+
]
520+
if strict_matches:
521+
possible_connections[:] = strict_matches
507522

508523
# We need this status for error messages, since we might need it in multiple places we calculate it here
509524
status = _connections_status(
@@ -532,11 +547,14 @@ def connect(self, sender: str, receiver: str) -> "PipelineBase": # noqa: PLR091
532547
# There's only one possible connection, use it
533548
sender_socket = possible_connections[0][0]
534549
receiver_socket = possible_connections[0][1]
550+
is_strict_match = possible_connections[0][2]
535551

536552
if len(possible_connections) > 1:
537553
# There are multiple possible connection, let's try to match them by name
538554
name_matches = [
539-
(out_sock, in_sock) for out_sock, in_sock in possible_connections if in_sock.name == out_sock.name
555+
(out_sock, in_sock, is_strict)
556+
for out_sock, in_sock, is_strict in possible_connections
557+
if in_sock.name == out_sock.name
540558
]
541559
if len(name_matches) != 1:
542560
# There's are either no matches or more than one, we can't pick one reliably
@@ -552,6 +570,7 @@ def connect(self, sender: str, receiver: str) -> "PipelineBase": # noqa: PLR091
552570
# Get the only possible match
553571
sender_socket = name_matches[0][0]
554572
receiver_socket = name_matches[0][1]
573+
is_strict_match = name_matches[0][2]
555574

556575
# Connection must be valid on both sender/receiver sides
557576
if not sender_socket or not receiver_socket or not sender_component_name or not receiver_component_name:
@@ -613,6 +632,7 @@ def connect(self, sender: str, receiver: str) -> "PipelineBase": # noqa: PLR091
613632
from_socket=sender_socket,
614633
to_socket=receiver_socket,
615634
mandatory=receiver_socket.is_mandatory,
635+
convert=not is_strict_match,
616636
)
617637
return self
618638

@@ -1017,22 +1037,23 @@ def from_template(
10171037
msg += f"Source:\n{rendered}"
10181038
raise PipelineUnmarshalError(msg)
10191039

1020-
def _find_receivers_from(self, component_name: str) -> list[tuple[str, OutputSocket, InputSocket]]:
1040+
def _find_receivers_from(self, component_name: str) -> list[tuple[str, OutputSocket, InputSocket, bool]]:
10211041
"""
10221042
Utility function to find all Components that receive input from `component_name`.
10231043
10241044
:param component_name:
10251045
Name of the sender Component
10261046
10271047
:returns:
1028-
List of tuples containing name of the receiver Component and sender OutputSocket
1029-
and receiver InputSocket instances
1048+
List of tuples containing name of the receiver Component, sender OutputSocket,
1049+
receiver InputSocket instances, and a boolean indicating if conversion is needed.
10301050
"""
10311051
res = []
10321052
for _, receiver_name, connection in self.graph.edges(nbunch=component_name, data=True):
10331053
sender_socket: OutputSocket = connection["from_socket"]
10341054
receiver_socket: InputSocket = connection["to_socket"]
1035-
res.append((receiver_name, sender_socket, receiver_socket))
1055+
convert: bool = connection.get("convert", False)
1056+
res.append((receiver_name, sender_socket, receiver_socket, convert))
10361057
return res
10371058

10381059
@staticmethod
@@ -1277,12 +1298,13 @@ def _tiebreak_waiting_components(
12771298

12781299
return component_name, topological_sort
12791300

1280-
@staticmethod
12811301
def _write_component_outputs(
1302+
self,
1303+
*,
12821304
component_name: str,
12831305
component_outputs: Mapping[str, Any],
12841306
inputs: dict[str, Any],
1285-
receivers: list[tuple[str, OutputSocket, InputSocket]],
1307+
receivers: list[tuple[str, OutputSocket, InputSocket, bool]],
12861308
include_outputs_from: set[str],
12871309
) -> Mapping[str, Any]:
12881310
"""
@@ -1291,16 +1313,40 @@ def _write_component_outputs(
12911313
:param component_name: The name of the component.
12921314
:param component_outputs: The outputs of the component.
12931315
:param inputs: The current global input state.
1294-
:param receivers: List of tuples containing name of the receiver Component and sender OutputSocket
1295-
and receiver InputSocket instances.
1316+
:param receivers: List of tuples containing name of the receiver Component, sender OutputSocket,
1317+
receiver InputSocket instances, and a boolean indicating if conversion is needed.
12961318
:param include_outputs_from: Set of component names that should always return an output from the pipeline.
12971319
"""
1298-
for receiver_name, sender_socket, receiver_socket in receivers:
1320+
for receiver_name, sender_socket, receiver_socket, convert in receivers:
12991321
# We either get the value that was produced by the actor or we use the _NO_OUTPUT_PRODUCED class to indicate
13001322
# that the sender did not produce an output for this socket.
13011323
# This allows us to track if a predecessor already ran but did not produce an output.
13021324
value = component_outputs.get(sender_socket.name, _NO_OUTPUT_PRODUCED)
13031325

1326+
if value is not _NO_OUTPUT_PRODUCED and convert:
1327+
try:
1328+
value = _convert_value(
1329+
value=value, sender_type=sender_socket.type, receiver_type=receiver_socket.type
1330+
)
1331+
except Exception as e:
1332+
sender_node = self.graph.nodes.get(component_name)
1333+
sender_instance = sender_node.get("instance") if sender_node else None
1334+
sender_type_name = type(sender_instance).__name__ if sender_instance else "unknown"
1335+
1336+
receiver_node = self.graph.nodes.get(receiver_name)
1337+
receiver_instance = receiver_node.get("instance") if receiver_node else None
1338+
receiver_type_name = type(receiver_instance).__name__ if receiver_instance else "unknown"
1339+
1340+
msg = (
1341+
f"Failed to perform conversion between components:\n"
1342+
f"Sender component: '{component_name}' (type: '{sender_type_name}')\n"
1343+
f"Sender socket: '{sender_socket.name}'\n"
1344+
f"Receiver component: '{receiver_name}' (type: '{receiver_type_name}')\n"
1345+
f"Receiver socket: '{receiver_socket.name}'\n"
1346+
f"Error: {e}"
1347+
)
1348+
raise PipelineRuntimeError(component_name=None, component_type=None, message=msg) from e
1349+
13041350
if receiver_name not in inputs:
13051351
inputs[receiver_name] = {}
13061352

@@ -1332,7 +1378,7 @@ def _write_component_outputs(
13321378

13331379
# We prune outputs that were consumed by any receiving sockets.
13341380
# All remaining outputs will be added to the final outputs of the pipeline.
1335-
consumed_outputs = {sender_socket.name for _, sender_socket, __ in receivers}
1381+
consumed_outputs = {sender_socket.name for _, sender_socket, __, ___ in receivers}
13361382
pruned_outputs = {key: value for key, value in component_outputs.items() if key not in consumed_outputs}
13371383

13381384
return pruned_outputs
@@ -1431,7 +1477,7 @@ def _merge_super_component_pipelines(self) -> tuple[networkx.MultiDiGraph, dict[
14311477
# find a matching input socket in the entry point
14321478
entry_point_sockets = internal_graph.nodes[entry_point]["input_sockets"]
14331479
for socket_name, socket in entry_point_sockets.items():
1434-
if _types_are_compatible(sender_socket.type, socket.type, self._connection_type_validation):
1480+
if _types_are_compatible(sender_socket.type, socket.type, self._connection_type_validation)[0]:
14351481
merged_graph.add_edge(
14361482
sender,
14371483
entry_point,
@@ -1449,7 +1495,9 @@ def _merge_super_component_pipelines(self) -> tuple[networkx.MultiDiGraph, dict[
14491495
# find a matching output socket in the exit point
14501496
exit_point_sockets = internal_graph.nodes[exit_point]["output_sockets"]
14511497
for socket_name, socket in exit_point_sockets.items():
1452-
if _types_are_compatible(socket.type, receiver_socket.type, self._connection_type_validation):
1498+
if _types_are_compatible(socket.type, receiver_socket.type, self._connection_type_validation)[
1499+
0
1500+
]:
14531501
merged_graph.add_edge(
14541502
exit_point,
14551503
receiver,

haystack/core/type_utils.py

Lines changed: 68 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,24 +4,33 @@
44

55
import collections.abc
66
from types import NoneType, UnionType
7-
from typing import Any, TypeVar, Union, get_args, get_origin
7+
from typing import Any, Union, get_args, get_origin
88

9-
T = TypeVar("T")
9+
from haystack.dataclasses import ChatMessage
1010

1111

12-
def _types_are_compatible(sender: type | UnionType, receiver: type | UnionType, type_validation: bool = True) -> bool:
12+
def _types_are_compatible(
13+
sender: type | UnionType, receiver: type | UnionType, type_validation: bool = True
14+
) -> tuple[bool, bool]:
1315
"""
1416
Determines if two types are compatible based on the specified validation mode.
1517
1618
:param sender: The sender type.
1719
:param receiver: The receiver type.
1820
:param type_validation: Whether to perform strict type validation.
19-
:return: True if the types are compatible, False otherwise.
21+
:return: A tuple where the first element is True if the types are compatible, and the second
22+
element is True if they are strictly compatible.
2023
"""
21-
if type_validation:
22-
return _strict_types_are_compatible(sender, receiver)
23-
else:
24-
return True
24+
if not type_validation:
25+
return True, True
26+
27+
if _strict_types_are_compatible(sender, receiver):
28+
return True, True
29+
30+
if _types_are_convertible(sender, receiver):
31+
return True, False
32+
33+
return False, False
2534

2635

2736
def _safe_get_origin(_type: type | UnionType) -> type | None:
@@ -43,13 +52,60 @@ def _safe_get_origin(_type: type | UnionType) -> type | None:
4352
return origin
4453

4554

46-
def _strict_types_are_compatible(sender, receiver): # pylint: disable=too-many-return-statements
55+
def _contains_type(container: Any, target: Any) -> bool:
56+
"""Checks if the container type includes the target type"""
57+
if container == target:
58+
return True
59+
return _safe_get_origin(container) is Union and target in get_args(container)
60+
61+
62+
def _types_are_convertible(sender: Any, receiver: Any) -> bool:
63+
"""
64+
Checks whether the sender type is convertible to the receiver type.
65+
66+
ChatMessage is convertible to str and vice versa.
67+
"""
68+
# Optional[T] must not connect to T
69+
if _contains_type(sender, NoneType) and not _contains_type(receiver, NoneType):
70+
return False
71+
72+
# if sender is a single type and receiver is a Union/Optional containing that type, they are convertible
73+
# e.g. str is convertible to Optional[str] or Union[str, int] etc.
74+
if _contains_type(receiver, sender):
75+
return True
76+
77+
if _contains_type(sender, ChatMessage) and _contains_type(receiver, str):
78+
return True
79+
80+
return _contains_type(sender, str) and _contains_type(receiver, ChatMessage)
81+
82+
83+
def _convert_value(value: Any, sender_type: Any, receiver_type: Any) -> Any:
84+
"""
85+
Converts a value from the sender type to the receiver type.
86+
87+
:param value: The value to convert.
88+
:param sender_type: The sender type.
89+
:param receiver_type: The receiver type.
90+
:return: The converted value.
91+
"""
92+
if _contains_type(sender_type, ChatMessage) and _contains_type(receiver_type, str):
93+
if value.text is None:
94+
msg = "Cannot convert `ChatMessage` to `str` because it has no text. "
95+
raise ValueError(msg)
96+
return value.text
97+
98+
if _contains_type(sender_type, str) and _contains_type(receiver_type, ChatMessage):
99+
return ChatMessage.from_user(value)
100+
101+
return value
102+
103+
104+
def _strict_types_are_compatible(sender: Any, receiver: Any) -> bool: # pylint: disable=too-many-return-statements
47105
"""
48106
Checks whether the sender type is equal to or a subtype of the receiver type under strict validation.
49107
50-
Note: this method has no pretense to perform proper type matching. It especially does not deal with aliasing of
51-
typing classes such as `List` or `Dict` to their runtime counterparts `list` and `dict`. It also does not deal well
52-
with "bare" types, so `List` is treated differently from `List[Any]`, even though they should be the same.
108+
Note: this method has no pretense to perform complete type matching.
53109
Consider simplifying the typing of your components if you observe unexpected errors during component connection.
54110
55111
:param sender: The sender type.
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
---
2+
features:
3+
- |
4+
Pipelines now support connection and automatic conversion between ``ChatMessage`` and ``str`` types.
5+
- When a ``str`` output is connected to a ``ChatMessage`` input, it is automatically converted to a user
6+
``ChatMessage``.
7+
- When a ``ChatMessage`` output is connected to a ``str`` input, its ``text`` attribute is automatically
8+
extracted. If ``text`` is ``None``, an informative ``PipelineRuntimeError`` is raised.
9+
- To maintain backward compatibility, when multiple connections are available, strict type matching is prioritized
10+
over conversion.

0 commit comments

Comments
 (0)