Solution: replace the failing lines with the following. In recent versions of `transformers`, the `_make_causal_mask` and `_expand_mask` helpers were moved out of the individual modeling files into the `AttentionMaskConverter` class in `transformers.modeling_attn_mask_utils`:

```python
from transformers.modeling_attn_mask_utils import AttentionMaskConverter

mask_converter = AttentionMaskConverter(is_causal=True)
causal_attention_mask = mask_converter._make_causal_mask(
    input_shape, hidden_states.dtype, device=hidden_states.device
)
attention_mask = mask_converter._expand_mask(attention_mask, hidden_states.dtype)
```

Note: this fix applies only when the error is about `causal_attention_mask` and `attention_mask` (i.e. the old module-level `_make_causal_mask`/`_expand_mask` functions no longer being importable).
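For intuition, here is a dependency-free sketch of what the two helpers compute (the function names and shapes here are illustrative, not the library's internals): `_make_causal_mask` builds a lower-triangular additive mask so each token attends only to itself and earlier positions, and `_expand_mask` turns a 1/0 padding mask into an additive mask (0.0 for keep, `-inf` for masked):

```python
# Illustrative, simplified versions of the masks AttentionMaskConverter
# produces; the real library returns batched, broadcastable tensors.

NEG_INF = float("-inf")

def make_causal_mask(seq_len):
    """0.0 where attention is allowed (key pos <= query pos), -inf elsewhere."""
    return [[0.0 if j <= i else NEG_INF for j in range(seq_len)]
            for i in range(seq_len)]

def expand_mask(padding_mask):
    """Convert a 1 (real token) / 0 (padding) mask into an additive mask."""
    return [0.0 if m == 1 else NEG_INF for m in padding_mask]

if __name__ == "__main__":
    for row in make_causal_mask(4):
        print(row)
    print(expand_mask([1, 1, 1, 0]))
```

Adding these masks to the raw attention scores before the softmax drives the blocked positions' weights to zero, which is why both helpers return additive masks rather than boolean ones.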