@@ -546,6 +546,7 @@ class PirateNetBlock(nn.Layer):
546546 $$
547547
548548 Args:
549+ input_dim (int): Input dimension.
549550 embed_dim (int): Embedding dimension.
550551 activation (str, optional): Name of activation function. Defaults to "tanh".
551552 random_weight (Optional[Dict[str, float]]): Mean and std of random weight
@@ -554,16 +555,17 @@ class PirateNetBlock(nn.Layer):
554555
555556 def __init__ (
556557 self ,
558+ input_dim : int ,
557559 embed_dim : int ,
558560 activation : str = "tanh" ,
559561 random_weight : Optional [Dict [str , float ]] = None ,
560562 ):
561563 super ().__init__ ()
562564 self .linear1 = (
563- nn .Linear (embed_dim , embed_dim )
565+ nn .Linear (input_dim , embed_dim )
564566 if random_weight is None
565567 else RandomWeightFactorization (
566- embed_dim ,
568+ input_dim ,
567569 embed_dim ,
568570 mean = random_weight ["mean" ],
569571 std = random_weight ["std" ],
@@ -721,6 +723,9 @@ def __init__(
721723 cur_size , fourier ["dim" ], fourier ["scale" ]
722724 )
723725 cur_size = fourier ["dim" ]
726+ else :
727+ self .linear_emb = nn .Linear (cur_size , hidden_size [0 ])
728+ cur_size = hidden_size [0 ]
724729
725730 self .embed_u = nn .Sequential (
726731 (
@@ -769,6 +774,7 @@ def __init__(
769774 self .blocks .append (
770775 PirateNetBlock (
771776 cur_size ,
777+ _size ,
772778 activation = activation ,
773779 random_weight = random_weight ,
774780 )
@@ -811,6 +817,8 @@ def forward(self, x):
811817
812818 if self .fourier :
813819 y = self .fourier_emb (y )
820+ else :
821+ y = self .linear_emb (y )
814822
815823 y = self .forward_tensor (y )
816824 y = self .split_to_dict (y , self .output_keys , axis = - 1 )
0 commit comments