@@ -221,7 +221,8 @@ def __init__(self, spatial_sigma, color_sigma):
221221 self .len_spatial_sigma = 3
222222 else :
223223 raise ValueError (
224- f"len(spatial_sigma) { spatial_sigma } must match number of spatial dims { self .ken_spatial_sigma } ."
224+ f"Length of `spatial_sigma` must match number of spatial dims (1, 2 or 3)"
225+ f"or be a single float value ({ spatial_sigma = } )."
225226 )
226227
227228 # Register sigmas as trainable parameters.
@@ -231,6 +232,10 @@ def __init__(self, spatial_sigma, color_sigma):
231232 self .sigma_color = torch .nn .Parameter (torch .tensor (color_sigma ))
232233
def forward(self, input_tensor):
    """Apply the trainable bilateral filter to ``input_tensor``.

    Args:
        input_tensor: tensor shaped ``(batch, channel, *spatial_dims)`` with
            exactly one channel and 1, 2 or 3 spatial dimensions. The number of
            spatial dimensions must match the ``spatial_sigma`` the layer was
            initialized with (stored as ``self.len_spatial_sigma``).

    Returns:
        Filtered tensor of the same shape as ``input_tensor``.

    Raises:
        ValueError: if the input has fewer than 3 dimensions, more than one
            channel, or a spatial rank that does not match the initialized
            ``len(spatial_sigma)``.
    """
    if len(input_tensor.shape) < 3:
        raise ValueError(
            f"Input must have at least 3 dimensions (batch, channel, *spatial_dims), got {len(input_tensor.shape)}"
        )
    if input_tensor.shape[1] != 1:
        raise ValueError(
            f"Currently channel dimensions >1 ({input_tensor.shape[1]}) are not supported. "
            "Please use multiple parallel filter layers if you want "
            "to filter multiple channels."
        )

    len_input = len(input_tensor.shape)
    # Spatial rank = total rank minus batch and channel dimensions.
    spatial_dims = len_input - 2

    # C++ extension so far only supports 5-dim inputs, so pad 1D/2D inputs
    # with trailing singleton spatial dimensions.
    if spatial_dims == 1:
        input_tensor = input_tensor.unsqueeze(3).unsqueeze(4)
    elif spatial_dims == 2:
        input_tensor = input_tensor.unsqueeze(4)

    if self.len_spatial_sigma != spatial_dims:
        raise ValueError(
            f"Number of spatial dimensions ({spatial_dims}) must match initialized `len(spatial_sigma)`."
        )

    # Trainable sigmas registered in __init__ are passed through so gradients
    # can flow back into them.
    prediction = TrainableBilateralFilterFunction.apply(
        input_tensor, self.sigma_x, self.sigma_y, self.sigma_z, self.sigma_color
    )

    # Make sure to return tensor of the same shape as the input.
    if spatial_dims == 1:
        prediction = prediction.squeeze(4).squeeze(3)
    elif spatial_dims == 2:
        prediction = prediction.squeeze(4)

    return prediction
@@ -389,7 +397,8 @@ def __init__(self, spatial_sigma, color_sigma):
389397 self .len_spatial_sigma = 3
390398 else :
391399 raise ValueError (
392- f"len(spatial_sigma) { spatial_sigma } must match number of spatial dims { self .ken_spatial_sigma } ."
400+ f"Length of `spatial_sigma` must match number of spatial dims (1, 2 or 3)\n "
401+ f"or be a single float value ({ spatial_sigma = } )."
393402 )
394403
395404 # Register sigmas as trainable parameters.
@@ -399,39 +408,45 @@ def __init__(self, spatial_sigma, color_sigma):
399408 self .sigma_color = torch .nn .Parameter (torch .tensor (color_sigma ))
400409
def forward(self, input_tensor, guidance_tensor):
    """Apply the trainable joint bilateral filter to ``input_tensor``, guided
    by ``guidance_tensor``.

    Args:
        input_tensor: tensor shaped ``(batch, channel, *spatial_dims)`` with
            exactly one channel and 1, 2 or 3 spatial dimensions.
        guidance_tensor: guidance image; must have exactly the same shape as
            ``input_tensor``.

    Returns:
        Filtered tensor of the same shape as ``input_tensor``.

    Raises:
        ValueError: if the input has fewer than 3 dimensions, more than one
            channel, a shape differing from the guidance image, or a spatial
            rank that does not match the initialized ``len(spatial_sigma)``.
    """
    if len(input_tensor.shape) < 3:
        raise ValueError(
            f"Input must have at least 3 dimensions (batch, channel, *spatial_dims), got {len(input_tensor.shape)}"
        )
    if input_tensor.shape[1] != 1:
        raise ValueError(
            f"Currently channel dimensions > 1 ({input_tensor.shape[1]}) are not supported. "
            "Please use multiple parallel filter layers if you want "
            "to filter multiple channels."
        )
    if input_tensor.shape != guidance_tensor.shape:
        raise ValueError(
            f"Shape of input image must equal shape of guidance image, got {input_tensor.shape} and {guidance_tensor.shape}."
        )

    len_input = len(input_tensor.shape)
    # Spatial rank = total rank minus batch and channel dimensions.
    spatial_dims = len_input - 2

    # C++ extension so far only supports 5-dim inputs, so pad 1D/2D inputs
    # (and their guidance) with trailing singleton spatial dimensions.
    if spatial_dims == 1:
        input_tensor = input_tensor.unsqueeze(3).unsqueeze(4)
        guidance_tensor = guidance_tensor.unsqueeze(3).unsqueeze(4)
    elif spatial_dims == 2:
        input_tensor = input_tensor.unsqueeze(4)
        guidance_tensor = guidance_tensor.unsqueeze(4)

    if self.len_spatial_sigma != spatial_dims:
        raise ValueError(
            f"Number of spatial dimensions ({spatial_dims}) must match initialized `len(spatial_sigma)`."
        )

    # Trainable sigmas registered in __init__ are passed through so gradients
    # can flow back into them.
    prediction = TrainableJointBilateralFilterFunction.apply(
        input_tensor, guidance_tensor, self.sigma_x, self.sigma_y, self.sigma_z, self.sigma_color
    )

    # Make sure to return tensor of the same shape as the input.
    if spatial_dims == 1:
        prediction = prediction.squeeze(4).squeeze(3)
    elif spatial_dims == 2:
        prediction = prediction.squeeze(4)

    return prediction
0 commit comments