@@ -7,8 +7,8 @@ use parking_lot::RwLock;
77use std:: collections:: HashMap ;
88use std:: sync:: Arc ;
99use std:: time:: { Duration , SystemTime } ;
10- use tokio:: sync:: { mpsc, Notify } ;
11- use tokio:: time:: interval;
10+ use tokio:: sync:: { mpsc, oneshot , Notify } ;
11+ use tokio:: time:: { interval, timeout } ;
1212
1313use crate :: types:: * ;
1414
@@ -50,6 +50,14 @@ pub trait CommunicationBus {
5050 /// Unregister an agent
5151 async fn unregister_agent ( & self , agent_id : AgentId ) -> Result < ( ) , CommunicationError > ;
5252
53+ /// Send a request and wait for response with timeout
54+ async fn request (
55+ & self ,
56+ target_agent : AgentId ,
57+ request_payload : bytes:: Bytes ,
58+ timeout_duration : Duration ,
59+ ) -> Result < bytes:: Bytes , CommunicationError > ;
60+
5361 /// Shutdown the communication bus
5462 async fn shutdown ( & self ) -> Result < ( ) , CommunicationError > ;
5563}
@@ -89,6 +97,7 @@ pub struct DefaultCommunicationBus {
8997 subscriptions : Arc < RwLock < HashMap < String , Vec < AgentId > > > > ,
9098 message_tracker : Arc < RwLock < HashMap < MessageId , MessageTracker > > > ,
9199 dead_letter_queue : Arc < RwLock < DeadLetterQueue > > ,
100+ pending_requests : Arc < RwLock < HashMap < RequestId , oneshot:: Sender < bytes:: Bytes > > > > ,
92101 event_sender : mpsc:: UnboundedSender < CommunicationEvent > ,
93102 shutdown_notify : Arc < Notify > ,
94103 is_running : Arc < RwLock < bool > > ,
@@ -103,6 +112,7 @@ impl DefaultCommunicationBus {
103112 let dead_letter_queue = Arc :: new ( RwLock :: new ( DeadLetterQueue :: new (
104113 config. dead_letter_queue_size ,
105114 ) ) ) ;
115+ let pending_requests = Arc :: new ( RwLock :: new ( HashMap :: new ( ) ) ) ;
106116 let ( event_sender, event_receiver) = mpsc:: unbounded_channel ( ) ;
107117 let shutdown_notify = Arc :: new ( Notify :: new ( ) ) ;
108118 let is_running = Arc :: new ( RwLock :: new ( true ) ) ;
@@ -113,6 +123,7 @@ impl DefaultCommunicationBus {
113123 subscriptions,
114124 message_tracker,
115125 dead_letter_queue,
126+ pending_requests,
116127 event_sender,
117128 shutdown_notify,
118129 is_running,
@@ -134,6 +145,7 @@ impl DefaultCommunicationBus {
134145 let subscriptions = self . subscriptions . clone ( ) ;
135146 let message_tracker = self . message_tracker . clone ( ) ;
136147 let dead_letter_queue = self . dead_letter_queue . clone ( ) ;
148+ let pending_requests = self . pending_requests . clone ( ) ;
137149 let shutdown_notify = self . shutdown_notify . clone ( ) ;
138150 let config = self . config . clone ( ) ;
139151
@@ -148,6 +160,7 @@ impl DefaultCommunicationBus {
148160 & subscriptions,
149161 & message_tracker,
150162 & dead_letter_queue,
163+ & pending_requests,
151164 & config,
152165 ) . await ;
153166 } else {
@@ -198,13 +211,24 @@ impl DefaultCommunicationBus {
198211 subscriptions : & Arc < RwLock < HashMap < String , Vec < AgentId > > > > ,
199212 message_tracker : & Arc < RwLock < HashMap < MessageId , MessageTracker > > > ,
200213 dead_letter_queue : & Arc < RwLock < DeadLetterQueue > > ,
214+ pending_requests : & Arc < RwLock < HashMap < RequestId , oneshot:: Sender < bytes:: Bytes > > > > ,
201215 config : & CommunicationConfig ,
202216 ) {
203217 match event {
204218 CommunicationEvent :: MessageSent { message } => {
205219 let recipient = message. recipient ;
206220 let message_id = message. id ;
207221
222+ // Check if this is a response to a pending request
223+ if let MessageType :: Response ( request_id) = & message. message_type {
224+ if let Some ( sender) = pending_requests. write ( ) . remove ( request_id) {
225+ // Send response payload to waiting request
226+ let _ = sender. send ( message. payload . data . clone ( ) ) ;
227+ tracing:: debug!( "Response {} sent for request {:?}" , message_id, request_id) ;
228+ return ;
229+ }
230+ }
231+
208232 // Add to message tracker
209233 message_tracker
210234 . write ( )
@@ -302,6 +326,7 @@ impl DefaultCommunicationBus {
302326 subscriptions,
303327 message_tracker,
304328 dead_letter_queue,
329+ pending_requests,
305330 config,
306331 ) )
307332 . await ;
@@ -518,6 +543,66 @@ impl CommunicationBus for DefaultCommunicationBus {
518543 Ok ( ( ) )
519544 }
520545
546+ async fn request (
547+ & self ,
548+ target_agent : AgentId ,
549+ request_payload : bytes:: Bytes ,
550+ timeout_duration : Duration ,
551+ ) -> Result < bytes:: Bytes , CommunicationError > {
552+ if !* self . is_running . read ( ) {
553+ return Err ( CommunicationError :: ShuttingDown ) ;
554+ }
555+
556+ // Create request ID and oneshot channel for response
557+ let request_id = RequestId :: new ( ) ;
558+ let ( response_sender, response_receiver) = oneshot:: channel ( ) ;
559+
560+ // Store the response sender
561+ self . pending_requests . write ( ) . insert ( request_id, response_sender) ;
562+
563+ // Create request message
564+ let request_message = SecureMessage {
565+ id : MessageId :: new ( ) ,
566+ sender : AgentId :: new ( ) , // TODO: Should be the actual sender agent ID
567+ recipient : Some ( target_agent) ,
568+ topic : None ,
569+ message_type : MessageType :: Request ( request_id) ,
570+ payload : EncryptedPayload {
571+ data : request_payload,
572+ nonce : vec ! [ 0u8 ; 12 ] , // TODO: Generate proper nonce
573+ encryption_algorithm : EncryptionAlgorithm :: Aes256Gcm ,
574+ } ,
575+ signature : MessageSignature {
576+ signature : vec ! [ 0u8 ; 64 ] , // TODO: Generate proper signature
577+ algorithm : SignatureAlgorithm :: Ed25519 ,
578+ public_key : vec ! [ 0u8 ; 32 ] , // TODO: Use proper public key
579+ } ,
580+ ttl : timeout_duration,
581+ timestamp : SystemTime :: now ( ) ,
582+ } ;
583+
584+ // Send the request
585+ self . send_message ( request_message) . await ?;
586+
587+ // Wait for response with timeout
588+ match timeout ( timeout_duration, response_receiver) . await {
589+ Ok ( Ok ( response_payload) ) => Ok ( response_payload) ,
590+ Ok ( Err ( _) ) => {
591+ // Remove from pending requests if channel was dropped
592+ self . pending_requests . write ( ) . remove ( & request_id) ;
593+ Err ( CommunicationError :: RequestCancelled { request_id } )
594+ }
595+ Err ( _) => {
596+ // Timeout occurred
597+ self . pending_requests . write ( ) . remove ( & request_id) ;
598+ Err ( CommunicationError :: RequestTimeout {
599+ request_id,
600+ timeout : timeout_duration,
601+ } )
602+ }
603+ }
604+ }
605+
521606 async fn shutdown ( & self ) -> Result < ( ) , CommunicationError > {
522607 tracing:: info!( "Shutting down communication bus" ) ;
523608
@@ -826,4 +911,91 @@ mod tests {
826911 let result = bus. receive_messages ( agent_id) . await ;
827912 assert ! ( result. is_err( ) ) ;
828913 }
914+
915+ #[ tokio:: test]
916+ async fn test_request_response_timeout ( ) {
917+ let bus = DefaultCommunicationBus :: new ( CommunicationConfig :: default ( ) )
918+ . await
919+ . unwrap ( ) ;
920+ let target_agent = AgentId :: new ( ) ;
921+
922+ // Register target agent (but it won't respond)
923+ bus. register_agent ( target_agent) . await . unwrap ( ) ;
924+ tokio:: time:: sleep ( Duration :: from_millis ( 50 ) ) . await ;
925+
926+ // Make request with short timeout
927+ let request_payload = bytes:: Bytes :: from ( "test request" ) ;
928+ let timeout = Duration :: from_millis ( 100 ) ;
929+
930+ let result = bus. request ( target_agent, request_payload, timeout) . await ;
931+ assert ! ( result. is_err( ) ) ;
932+
933+ if let Err ( CommunicationError :: RequestTimeout { request_id : _, timeout : actual_timeout } ) = result {
934+ assert_eq ! ( actual_timeout, timeout) ;
935+ } else {
936+ panic ! ( "Expected RequestTimeout error" ) ;
937+ }
938+ }
939+
940+ #[ tokio:: test]
941+ async fn test_request_response_success ( ) {
942+ let bus = DefaultCommunicationBus :: new ( CommunicationConfig :: default ( ) )
943+ . await
944+ . unwrap ( ) ;
945+ let requester = AgentId :: new ( ) ;
946+ let responder = AgentId :: new ( ) ;
947+
948+ // Register both agents
949+ bus. register_agent ( requester) . await . unwrap ( ) ;
950+ bus. register_agent ( responder) . await . unwrap ( ) ;
951+ tokio:: time:: sleep ( Duration :: from_millis ( 50 ) ) . await ;
952+
953+ let request_payload = bytes:: Bytes :: from ( "test request" ) ;
954+ let response_payload = bytes:: Bytes :: from ( "test response" ) ;
955+
956+ // Start request in background
957+ let bus_clone = Arc :: new ( bus) ;
958+ let request_bus = bus_clone. clone ( ) ;
959+ let request_handle = tokio:: spawn ( async move {
960+ request_bus. request ( responder, request_payload, Duration :: from_secs ( 5 ) ) . await
961+ } ) ;
962+
963+ // Give request time to be sent
964+ tokio:: time:: sleep ( Duration :: from_millis ( 100 ) ) . await ;
965+
966+ // Responder receives the request and sends response
967+ let messages = bus_clone. receive_messages ( responder) . await . unwrap ( ) ;
968+ assert_eq ! ( messages. len( ) , 1 ) ;
969+ assert ! ( matches!( messages[ 0 ] . message_type, MessageType :: Request ( _) ) ) ;
970+
971+ // Extract request ID and send response
972+ if let MessageType :: Request ( request_id) = & messages[ 0 ] . message_type {
973+ let response_message = SecureMessage {
974+ id : MessageId :: new ( ) ,
975+ sender : responder,
976+ recipient : Some ( requester) ,
977+ topic : None ,
978+ message_type : MessageType :: Response ( * request_id) ,
979+ payload : EncryptedPayload {
980+ data : response_payload. clone ( ) ,
981+ nonce : vec ! [ 0u8 ; 12 ] ,
982+ encryption_algorithm : EncryptionAlgorithm :: Aes256Gcm ,
983+ } ,
984+ signature : MessageSignature {
985+ signature : vec ! [ 0u8 ; 64 ] ,
986+ algorithm : SignatureAlgorithm :: Ed25519 ,
987+ public_key : vec ! [ 0u8 ; 32 ] ,
988+ } ,
989+ ttl : Duration :: from_secs ( 3600 ) ,
990+ timestamp : SystemTime :: now ( ) ,
991+ } ;
992+
993+ bus_clone. send_message ( response_message) . await . unwrap ( ) ;
994+ }
995+
996+ // Wait for request to complete
997+ let result = request_handle. await . unwrap ( ) ;
998+ assert ! ( result. is_ok( ) ) ;
999+ assert_eq ! ( result. unwrap( ) , response_payload) ;
1000+ }
8291001}
0 commit comments