2121 PipelineDrawingError ,
2222 PipelineError ,
2323 PipelineMaxComponentRuns ,
24+ PipelineRuntimeError ,
2425 PipelineUnmarshalError ,
2526 PipelineValidationError ,
2627)
3940 component_to_dict ,
4041 generate_qualified_class_name ,
4142)
42- from haystack .core .type_utils import _safe_get_origin , _type_name , _types_are_compatible
43+ from haystack .core .type_utils import _convert_value , _safe_get_origin , _type_name , _types_are_compatible
4344from haystack .marshal import Marshaller , YamlMarshaller
4445from haystack .utils import is_in_jupyter , type_serialization
4546
@@ -499,11 +500,25 @@ def connect(self, sender: str, receiver: str) -> "PipelineBase": # noqa: PLR091
499500 [receiver_socket ] if receiver_socket else list (receiver_sockets .values ())
500501 )
501502
503+ is_strict_match = True
504+
502505 # Find all possible connections between these two components
503- possible_connections = []
506+ possible_connections : list [ tuple [ OutputSocket , InputSocket , bool ]] = []
504507 for sender_sock , receiver_sock in itertools .product (sender_socket_candidates , receiver_socket_candidates ):
505- if _types_are_compatible (sender_sock .type , receiver_sock .type , self ._connection_type_validation ):
506- possible_connections .append ((sender_sock , receiver_sock ))
508+ is_compat , is_strict = _types_are_compatible (
509+ sender_sock .type , receiver_sock .type , self ._connection_type_validation
510+ )
511+ if is_compat :
512+ possible_connections .append ((sender_sock , receiver_sock , is_strict ))
513+
514+ # If there are multiple possibilities, prioritize strict matches over convertible ones.
515+ # This ensures backward compatibility: previously, pipelines did not allow type conversion.
516+ if len (possible_connections ) > 1 and self ._connection_type_validation :
517+ strict_matches = [
518+ (out_sock , in_sock , is_strict ) for out_sock , in_sock , is_strict in possible_connections if is_strict
519+ ]
520+ if strict_matches :
521+ possible_connections [:] = strict_matches
507522
508523 # We need this status for error messages, since we might need it in multiple places we calculate it here
509524 status = _connections_status (
@@ -532,11 +547,14 @@ def connect(self, sender: str, receiver: str) -> "PipelineBase": # noqa: PLR091
532547 # There's only one possible connection, use it
533548 sender_socket = possible_connections [0 ][0 ]
534549 receiver_socket = possible_connections [0 ][1 ]
550+ is_strict_match = possible_connections [0 ][2 ]
535551
536552 if len (possible_connections ) > 1 :
537553 # There are multiple possible connection, let's try to match them by name
538554 name_matches = [
539- (out_sock , in_sock ) for out_sock , in_sock in possible_connections if in_sock .name == out_sock .name
555+ (out_sock , in_sock , is_strict )
556+ for out_sock , in_sock , is_strict in possible_connections
557+ if in_sock .name == out_sock .name
540558 ]
541559 if len (name_matches ) != 1 :
542560 # There's are either no matches or more than one, we can't pick one reliably
@@ -552,6 +570,7 @@ def connect(self, sender: str, receiver: str) -> "PipelineBase": # noqa: PLR091
552570 # Get the only possible match
553571 sender_socket = name_matches [0 ][0 ]
554572 receiver_socket = name_matches [0 ][1 ]
573+ is_strict_match = name_matches [0 ][2 ]
555574
556575 # Connection must be valid on both sender/receiver sides
557576 if not sender_socket or not receiver_socket or not sender_component_name or not receiver_component_name :
@@ -613,6 +632,7 @@ def connect(self, sender: str, receiver: str) -> "PipelineBase": # noqa: PLR091
613632 from_socket = sender_socket ,
614633 to_socket = receiver_socket ,
615634 mandatory = receiver_socket .is_mandatory ,
635+ convert = not is_strict_match ,
616636 )
617637 return self
618638
@@ -1017,22 +1037,23 @@ def from_template(
10171037 msg += f"Source:\n { rendered } "
10181038 raise PipelineUnmarshalError (msg )
10191039
1020- def _find_receivers_from (self , component_name : str ) -> list [tuple [str , OutputSocket , InputSocket ]]:
1040+ def _find_receivers_from (self , component_name : str ) -> list [tuple [str , OutputSocket , InputSocket , bool ]]:
10211041 """
10221042 Utility function to find all Components that receive input from `component_name`.
10231043
10241044 :param component_name:
10251045 Name of the sender Component
10261046
10271047 :returns:
1028- List of tuples containing name of the receiver Component and sender OutputSocket
1029- and receiver InputSocket instances
1048+ List of tuples containing name of the receiver Component, sender OutputSocket,
1049+ receiver InputSocket instances, and a boolean indicating if conversion is needed.
10301050 """
10311051 res = []
10321052 for _ , receiver_name , connection in self .graph .edges (nbunch = component_name , data = True ):
10331053 sender_socket : OutputSocket = connection ["from_socket" ]
10341054 receiver_socket : InputSocket = connection ["to_socket" ]
1035- res .append ((receiver_name , sender_socket , receiver_socket ))
1055+ convert : bool = connection .get ("convert" , False )
1056+ res .append ((receiver_name , sender_socket , receiver_socket , convert ))
10361057 return res
10371058
10381059 @staticmethod
@@ -1277,12 +1298,13 @@ def _tiebreak_waiting_components(
12771298
12781299 return component_name , topological_sort
12791300
1280- @staticmethod
12811301 def _write_component_outputs (
1302+ self ,
1303+ * ,
12821304 component_name : str ,
12831305 component_outputs : Mapping [str , Any ],
12841306 inputs : dict [str , Any ],
1285- receivers : list [tuple [str , OutputSocket , InputSocket ]],
1307+ receivers : list [tuple [str , OutputSocket , InputSocket , bool ]],
12861308 include_outputs_from : set [str ],
12871309 ) -> Mapping [str , Any ]:
12881310 """
@@ -1291,16 +1313,40 @@ def _write_component_outputs(
12911313 :param component_name: The name of the component.
12921314 :param component_outputs: The outputs of the component.
12931315 :param inputs: The current global input state.
1294- :param receivers: List of tuples containing name of the receiver Component and sender OutputSocket
1295- and receiver InputSocket instances.
1316+ :param receivers: List of tuples containing name of the receiver Component, sender OutputSocket,
1317+ receiver InputSocket instances, and a boolean indicating if conversion is needed .
12961318 :param include_outputs_from: Set of component names that should always return an output from the pipeline.
12971319 """
1298- for receiver_name , sender_socket , receiver_socket in receivers :
1320+ for receiver_name , sender_socket , receiver_socket , convert in receivers :
12991321 # We either get the value that was produced by the actor or we use the _NO_OUTPUT_PRODUCED class to indicate
13001322 # that the sender did not produce an output for this socket.
13011323 # This allows us to track if a predecessor already ran but did not produce an output.
13021324 value = component_outputs .get (sender_socket .name , _NO_OUTPUT_PRODUCED )
13031325
1326+ if value is not _NO_OUTPUT_PRODUCED and convert :
1327+ try :
1328+ value = _convert_value (
1329+ value = value , sender_type = sender_socket .type , receiver_type = receiver_socket .type
1330+ )
1331+ except Exception as e :
1332+ sender_node = self .graph .nodes .get (component_name )
1333+ sender_instance = sender_node .get ("instance" ) if sender_node else None
1334+ sender_type_name = type (sender_instance ).__name__ if sender_instance else "unknown"
1335+
1336+ receiver_node = self .graph .nodes .get (receiver_name )
1337+ receiver_instance = receiver_node .get ("instance" ) if receiver_node else None
1338+ receiver_type_name = type (receiver_instance ).__name__ if receiver_instance else "unknown"
1339+
1340+ msg = (
1341+ f"Failed to perform conversion between components:\n "
1342+ f"Sender component: '{ component_name } ' (type: '{ sender_type_name } ')\n "
1343+ f"Sender socket: '{ sender_socket .name } '\n "
1344+ f"Receiver component: '{ receiver_name } ' (type: '{ receiver_type_name } ')\n "
1345+ f"Receiver socket: '{ receiver_socket .name } '\n "
1346+ f"Error: { e } "
1347+ )
1348+ raise PipelineRuntimeError (component_name = None , component_type = None , message = msg ) from e
1349+
13041350 if receiver_name not in inputs :
13051351 inputs [receiver_name ] = {}
13061352
@@ -1332,7 +1378,7 @@ def _write_component_outputs(
13321378
13331379 # We prune outputs that were consumed by any receiving sockets.
13341380 # All remaining outputs will be added to the final outputs of the pipeline.
1335- consumed_outputs = {sender_socket .name for _ , sender_socket , __ in receivers }
1381+ consumed_outputs = {sender_socket .name for _ , sender_socket , __ , ___ in receivers }
13361382 pruned_outputs = {key : value for key , value in component_outputs .items () if key not in consumed_outputs }
13371383
13381384 return pruned_outputs
@@ -1431,7 +1477,7 @@ def _merge_super_component_pipelines(self) -> tuple[networkx.MultiDiGraph, dict[
14311477 # find a matching input socket in the entry point
14321478 entry_point_sockets = internal_graph .nodes [entry_point ]["input_sockets" ]
14331479 for socket_name , socket in entry_point_sockets .items ():
1434- if _types_are_compatible (sender_socket .type , socket .type , self ._connection_type_validation ):
1480+ if _types_are_compatible (sender_socket .type , socket .type , self ._connection_type_validation )[ 0 ] :
14351481 merged_graph .add_edge (
14361482 sender ,
14371483 entry_point ,
@@ -1449,7 +1495,9 @@ def _merge_super_component_pipelines(self) -> tuple[networkx.MultiDiGraph, dict[
14491495 # find a matching output socket in the exit point
14501496 exit_point_sockets = internal_graph .nodes [exit_point ]["output_sockets" ]
14511497 for socket_name , socket in exit_point_sockets .items ():
1452- if _types_are_compatible (socket .type , receiver_socket .type , self ._connection_type_validation ):
1498+ if _types_are_compatible (socket .type , receiver_socket .type , self ._connection_type_validation )[
1499+ 0
1500+ ]:
14531501 merged_graph .add_edge (
14541502 exit_point ,
14551503 receiver ,
0 commit comments