@@ -457,6 +457,232 @@ impl QuantumEncoder for AmplitudeEncoder {
457457 Ok ( batch_state_vector)
458458 }
459459
460+ /// Encode multiple samples in a single GPU allocation and kernel launch for f32 inputs
461+ #[ cfg( target_os = "linux" ) ]
462+ fn encode_batch_f32 (
463+ & self ,
464+ device : & Arc < CudaDevice > ,
465+ batch_data : & [ f32 ] ,
466+ num_samples : usize ,
467+ sample_size : usize ,
468+ num_qubits : usize ,
469+ ) -> Result < GpuStateVector > {
470+ crate :: profile_scope!( "AmplitudeEncoder::encode_batch_f32" ) ;
471+
472+ let state_len = 1 << num_qubits;
473+
474+ if sample_size == 0 {
475+ return Err ( MahoutError :: InvalidInput (
476+ "sample_size cannot be zero" . into ( ) ,
477+ ) ) ;
478+ }
479+ if sample_size > state_len {
480+ return Err ( MahoutError :: InvalidInput ( format ! (
481+ "sample_size {} exceeds state vector length {} (2^{} qubits)" ,
482+ sample_size, state_len, num_qubits
483+ ) ) ) ;
484+ }
485+ if batch_data. len ( ) != num_samples * sample_size {
486+ return Err ( MahoutError :: InvalidInput ( format ! (
487+ "batch_data length mismatch (expected {} * {} = {}, got {})" ,
488+ num_samples,
489+ sample_size,
490+ num_samples * sample_size,
491+ batch_data. len( )
492+ ) ) ) ;
493+ }
494+
495+ let batch_state_vector = {
496+ crate :: profile_scope!( "GPU::AllocBatch_f32" ) ;
497+ GpuStateVector :: new_batch ( device, num_samples, num_qubits, Precision :: Float32 ) ?
498+ } ;
499+
500+ // Upload input data to GPU
501+ let input_batch_gpu = {
502+ crate :: profile_scope!( "GPU::H2D_InputBatch_f32" ) ;
503+ device. htod_sync_copy ( batch_data) . map_err ( |e| {
504+ MahoutError :: MemoryAllocation ( format ! ( "Failed to upload batch input: {:?}" , e) )
505+ } ) ?
506+ } ;
507+
508+ // Compute inverse norms on GPU using warp-reduced kernel
509+ let inv_norms_gpu = {
510+ crate :: profile_scope!( "GPU::BatchNormKernel_f32" ) ;
511+ use cudarc:: driver:: DevicePtrMut ;
512+ let mut buffer = device. alloc_zeros :: < f32 > ( num_samples) . map_err ( |e| {
513+ MahoutError :: MemoryAllocation ( format ! ( "Failed to allocate norm buffer: {:?}" , e) )
514+ } ) ?;
515+
516+ let ret = unsafe {
517+ launch_l2_norm_batch_f32 (
518+ * input_batch_gpu. device_ptr ( ) as * const f32 ,
519+ num_samples,
520+ sample_size,
521+ * buffer. device_ptr_mut ( ) as * mut f32 ,
522+ std:: ptr:: null_mut ( ) , // default stream
523+ )
524+ } ;
525+
526+ if ret != 0 {
527+ return Err ( MahoutError :: KernelLaunch ( format ! (
528+ "Norm reduction kernel failed: {} ({})" ,
529+ ret,
530+ cuda_error_to_string( ret)
531+ ) ) ) ;
532+ }
533+ buffer
534+ } ;
535+
536+ // Validate norms on host
537+ {
538+ crate :: profile_scope!( "GPU::NormValidation_f32" ) ;
539+ let host_inv_norms = device
540+ . dtoh_sync_copy ( & inv_norms_gpu)
541+ . map_err ( |e| MahoutError :: Cuda ( format ! ( "Failed to copy norms to host: {:?}" , e) ) ) ?;
542+
543+ if host_inv_norms. iter ( ) . any ( |v| !v. is_finite ( ) || * v == 0.0 ) {
544+ return Err ( MahoutError :: InvalidInput (
545+ "One or more samples have zero or invalid norm" . to_string ( ) ,
546+ ) ) ;
547+ }
548+ }
549+
550+ // Launch batch kernel
551+ {
552+ crate :: profile_scope!( "GPU::BatchKernelLaunch_f32" ) ;
553+ use cudarc:: driver:: DevicePtr ;
554+ let state_ptr = batch_state_vector. ptr_f32 ( ) . ok_or_else ( || {
555+ MahoutError :: InvalidInput (
556+ "Batch state vector precision mismatch (expected float32 buffer)" . to_string ( ) ,
557+ )
558+ } ) ?;
559+ let ret = unsafe {
560+ launch_amplitude_encode_batch_f32 (
561+ * input_batch_gpu. device_ptr ( ) as * const f32 ,
562+ state_ptr as * mut c_void ,
563+ * inv_norms_gpu. device_ptr ( ) as * const f32 ,
564+ num_samples,
565+ sample_size,
566+ state_len,
567+ std:: ptr:: null_mut ( ) , // default stream
568+ )
569+ } ;
570+
571+ if ret != 0 {
572+ return Err ( MahoutError :: KernelLaunch ( format ! (
573+ "Batch kernel launch failed: {} ({})" ,
574+ ret,
575+ cuda_error_to_string( ret)
576+ ) ) ) ;
577+ }
578+ }
579+
580+ {
581+ crate :: profile_scope!( "GPU::Synchronize" ) ;
582+ device
583+ . synchronize ( )
584+ . map_err ( |e| MahoutError :: Cuda ( format ! ( "Sync failed: {:?}" , e) ) ) ?;
585+ }
586+
587+ Ok ( batch_state_vector)
588+ }
589+
590+ #[ cfg( target_os = "linux" ) ]
591+ unsafe fn encode_batch_from_gpu_ptr_f32 (
592+ & self ,
593+ device : & Arc < CudaDevice > ,
594+ input_batch_d : * const c_void ,
595+ num_samples : usize ,
596+ sample_size : usize ,
597+ num_qubits : usize ,
598+ stream : * mut c_void ,
599+ ) -> Result < GpuStateVector > {
600+ let state_len = 1 << num_qubits;
601+ if sample_size == 0 {
602+ return Err ( MahoutError :: InvalidInput (
603+ "Sample size cannot be zero" . into ( ) ,
604+ ) ) ;
605+ }
606+ if sample_size > state_len {
607+ return Err ( MahoutError :: InvalidInput ( format ! (
608+ "Sample size {} exceeds state vector size {} (2^{} qubits)" ,
609+ sample_size, state_len, num_qubits
610+ ) ) ) ;
611+ }
612+ let input_batch_d = input_batch_d as * const f32 ;
613+ let batch_state_vector = {
614+ crate :: profile_scope!( "GPU::AllocBatch_f32" ) ;
615+ GpuStateVector :: new_batch ( device, num_samples, num_qubits, Precision :: Float32 ) ?
616+ } ;
617+ let inv_norms_gpu = {
618+ crate :: profile_scope!( "GPU::BatchNormKernel_f32" ) ;
619+ use cudarc:: driver:: DevicePtrMut ;
620+ let mut buffer = device. alloc_zeros :: < f32 > ( num_samples) . map_err ( |e| {
621+ MahoutError :: MemoryAllocation ( format ! ( "Failed to allocate norm buffer: {:?}" , e) )
622+ } ) ?;
623+ let ret = unsafe {
624+ launch_l2_norm_batch_f32 (
625+ input_batch_d,
626+ num_samples,
627+ sample_size,
628+ * buffer. device_ptr_mut ( ) as * mut f32 ,
629+ stream,
630+ )
631+ } ;
632+ if ret != 0 {
633+ return Err ( MahoutError :: KernelLaunch ( format ! (
634+ "Norm reduction kernel failed with CUDA error code: {} ({})" ,
635+ ret,
636+ cuda_error_to_string( ret)
637+ ) ) ) ;
638+ }
639+ buffer
640+ } ;
641+ {
642+ crate :: profile_scope!( "GPU::NormValidation_f32" ) ;
643+ let host_inv_norms = device
644+ . dtoh_sync_copy ( & inv_norms_gpu)
645+ . map_err ( |e| MahoutError :: Cuda ( format ! ( "Failed to copy norms to host: {:?}" , e) ) ) ?;
646+ if host_inv_norms. iter ( ) . any ( |v| !v. is_finite ( ) || * v == 0.0 ) {
647+ return Err ( MahoutError :: InvalidInput (
648+ "One or more samples have zero or invalid norm" . to_string ( ) ,
649+ ) ) ;
650+ }
651+ }
652+ {
653+ crate :: profile_scope!( "GPU::BatchKernelLaunch_f32" ) ;
654+ use cudarc:: driver:: DevicePtr ;
655+ let state_ptr = batch_state_vector. ptr_f32 ( ) . ok_or_else ( || {
656+ MahoutError :: InvalidInput (
657+ "Batch state vector precision mismatch (expected float32 buffer)" . to_string ( ) ,
658+ )
659+ } ) ?;
660+ let ret = unsafe {
661+ launch_amplitude_encode_batch_f32 (
662+ input_batch_d,
663+ state_ptr as * mut c_void ,
664+ * inv_norms_gpu. device_ptr ( ) as * const f32 ,
665+ num_samples,
666+ sample_size,
667+ state_len,
668+ stream,
669+ )
670+ } ;
671+ if ret != 0 {
672+ return Err ( MahoutError :: KernelLaunch ( format ! (
673+ "Batch kernel launch failed with CUDA error code: {} ({})" ,
674+ ret,
675+ cuda_error_to_string( ret)
676+ ) ) ) ;
677+ }
678+ }
679+ {
680+ crate :: profile_scope!( "GPU::Synchronize" ) ;
681+ sync_cuda_stream ( stream, "CUDA stream synchronize failed" ) ?;
682+ }
683+ Ok ( batch_state_vector)
684+ }
685+
460686 fn name ( & self ) -> & ' static str {
461687 "amplitude"
462688 }
0 commit comments