@@ -40,8 +40,8 @@ use crate::gpu::memory::{ensure_device_memory_available, map_allocation_error};
4040use cudarc:: driver:: { DevicePtr , DevicePtrMut } ;
4141#[ cfg( target_os = "linux" ) ]
4242use qdp_kernels:: {
43- launch_amplitude_encode, launch_amplitude_encode_batch, launch_l2_norm , launch_l2_norm_batch ,
44- launch_l2_norm_f32,
43+ launch_amplitude_encode, launch_amplitude_encode_batch, launch_amplitude_encode_batch_f32 ,
44+ launch_l2_norm , launch_l2_norm_batch , launch_l2_norm_batch_f32 , launch_l2_norm_f32,
4545} ;
4646#[ cfg( target_os = "linux" ) ]
4747use std:: ffi:: c_void;
@@ -206,7 +206,7 @@ impl QuantumEncoder for AmplitudeEncoder {
206206 // Allocate single large GPU buffer for all states
207207 let batch_state_vector = {
208208 crate :: profile_scope!( "GPU::AllocBatch" ) ;
209- GpuStateVector :: new_batch ( device, num_samples, num_qubits) ?
209+ GpuStateVector :: new_batch ( device, num_samples, num_qubits, Precision :: Float64 ) ?
210210 } ;
211211
212212 // Upload input data to GPU
@@ -386,7 +386,7 @@ impl QuantumEncoder for AmplitudeEncoder {
386386 let input_batch_d = input_batch_d as * const f64 ;
387387 let batch_state_vector = {
388388 crate :: profile_scope!( "GPU::AllocBatch" ) ;
389- GpuStateVector :: new_batch ( device, num_samples, num_qubits) ?
389+ GpuStateVector :: new_batch ( device, num_samples, num_qubits, Precision :: Float64 ) ?
390390 } ;
391391 let inv_norms_gpu = {
392392 crate :: profile_scope!( "GPU::BatchNormKernel" ) ;
@@ -579,6 +579,119 @@ impl AmplitudeEncoder {
579579}
580580
581581impl AmplitudeEncoder {
582+ /// Encode a batch directly from a GPU float32 pointer.
583+ ///
584+ /// # Safety
585+ /// The caller must ensure `input_batch_d` points to valid GPU memory containing
586+ /// at least `num_samples * sample_size` f32 elements on the same device as `device`.
587+ #[ cfg( target_os = "linux" ) ]
588+ pub unsafe fn encode_batch_from_gpu_ptr_f32_with_stream (
589+ device : & Arc < CudaDevice > ,
590+ input_batch_d : * const f32 ,
591+ num_samples : usize ,
592+ sample_size : usize ,
593+ num_qubits : usize ,
594+ stream : * mut c_void ,
595+ ) -> Result < GpuStateVector > {
596+ let state_len = 1 << num_qubits;
597+ if num_samples == 0 {
598+ return Err ( MahoutError :: InvalidInput (
599+ "Number of samples cannot be zero" . into ( ) ,
600+ ) ) ;
601+ }
602+ if sample_size == 0 {
603+ return Err ( MahoutError :: InvalidInput (
604+ "Sample size cannot be zero" . into ( ) ,
605+ ) ) ;
606+ }
607+ if sample_size > state_len {
608+ return Err ( MahoutError :: InvalidInput ( format ! (
609+ "Sample size {} exceeds state vector size {} (2^{} qubits)" ,
610+ sample_size, state_len, num_qubits
611+ ) ) ) ;
612+ }
613+
614+ let batch_state_vector =
615+ GpuStateVector :: new_batch ( device, num_samples, num_qubits, Precision :: Float32 ) ?;
616+
617+ let inv_norms_gpu = {
618+ crate :: profile_scope!( "GPU::BatchNormKernelF32" ) ;
619+ use cudarc:: driver:: DevicePtrMut ;
620+
621+ let mut buffer = device. alloc_zeros :: < f32 > ( num_samples) . map_err ( |e| {
622+ MahoutError :: MemoryAllocation ( format ! (
623+ "Failed to allocate f32 norm buffer: {:?}" ,
624+ e
625+ ) )
626+ } ) ?;
627+ let ret = unsafe {
628+ launch_l2_norm_batch_f32 (
629+ input_batch_d,
630+ num_samples,
631+ sample_size,
632+ * buffer. device_ptr_mut ( ) as * mut f32 ,
633+ stream,
634+ )
635+ } ;
636+ if ret != 0 {
637+ return Err ( MahoutError :: KernelLaunch ( format ! (
638+ "Norm reduction kernel f32 failed with CUDA error code: {} ({})" ,
639+ ret,
640+ cuda_error_to_string( ret)
641+ ) ) ) ;
642+ }
643+ buffer
644+ } ;
645+
646+ {
647+ crate :: profile_scope!( "GPU::NormValidationF32" ) ;
648+ let host_inv_norms = device. dtoh_sync_copy ( & inv_norms_gpu) . map_err ( |e| {
649+ MahoutError :: Cuda ( format ! ( "Failed to copy f32 norms to host: {:?}" , e) )
650+ } ) ?;
651+ if host_inv_norms. iter ( ) . any ( |v| !v. is_finite ( ) || * v == 0.0 ) {
652+ return Err ( MahoutError :: InvalidInput (
653+ "One or more float32 samples have zero or invalid norm" . to_string ( ) ,
654+ ) ) ;
655+ }
656+ }
657+
658+ {
659+ crate :: profile_scope!( "GPU::BatchKernelLaunchF32" ) ;
660+ use cudarc:: driver:: DevicePtr ;
661+
662+ let state_ptr = batch_state_vector. ptr_f32 ( ) . ok_or_else ( || {
663+ MahoutError :: InvalidInput (
664+ "Batch state vector precision mismatch (expected float32 buffer)" . to_string ( ) ,
665+ )
666+ } ) ?;
667+ let ret = unsafe {
668+ launch_amplitude_encode_batch_f32 (
669+ input_batch_d,
670+ state_ptr as * mut c_void ,
671+ * inv_norms_gpu. device_ptr ( ) as * const f32 ,
672+ num_samples,
673+ sample_size,
674+ state_len,
675+ stream,
676+ )
677+ } ;
678+ if ret != 0 {
679+ return Err ( MahoutError :: KernelLaunch ( format ! (
680+ "Batch kernel f32 launch failed with CUDA error code: {} ({})" ,
681+ ret,
682+ cuda_error_to_string( ret)
683+ ) ) ) ;
684+ }
685+ }
686+
687+ {
688+ crate :: profile_scope!( "GPU::Synchronize" ) ;
689+ sync_cuda_stream ( stream, "CUDA stream synchronize failed" ) ?;
690+ }
691+
692+ Ok ( batch_state_vector)
693+ }
694+
582695 /// Compute inverse L2 norm on GPU using the reduction kernel.
583696 ///
584697 /// # Arguments
0 commit comments