-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathutils.py
More file actions
49 lines (42 loc) · 1.97 KB
/
utils.py
File metadata and controls
49 lines (42 loc) · 1.97 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import json
import random
import re
import pandas as pd
from typing import List, Dict, Union
from torch import Tensor
import torch
def gumbel_sigmoid(logits: Tensor, tau: float = 1, hard: bool = False, threshold: float = 0.5) -> Tensor:
    """
    Original code from: https://github.com/AngelosNal/PyTorch-Gumbel-Sigmoid/blob/main/gumbel_sigmoid.py
    We modified the gumbel as the left-tail gumbel distribution.
    Samples from the Gumbel-Sigmoid distribution and optionally discretizes.
    The discretization converts the values greater than `threshold` to 1 and the rest to 0.
    The code is adapted from the official PyTorch implementation of gumbel_softmax:
    https://pytorch.org/docs/stable/_modules/torch/nn/functional.html#gumbel_softmax
    Args:
        logits: `[..., num_features]` unnormalized log probabilities
        tau: non-negative scalar temperature
        hard: if ``True``, the returned samples will be discretized,
              but will be differentiated as if it is the soft sample in autograd
        threshold: threshold for the discretization,
                   values greater than this will be set to 1 and the rest to 0
    Returns:
        Sampled tensor of same shape as `logits` from the Gumbel-Sigmoid distribution.
        If ``hard=True``, the returned samples are discretized according to `threshold`, otherwise they will
        be probability distributions.
    """
    # log(Exp(1)) is the *negative* of a standard Gumbel(0, 1) sample, i.e. a
    # left-tail Gumbel — the sign flip vs. the stock PyTorch gumbel_softmax
    # recipe (-exponential_().log()) is deliberate, per the docstring.
    gumbels = (
        torch.empty_like(logits, memory_format=torch.legacy_contiguous_format).exponential_().log()
    )  # ~Gumbel(0, 1), left-tail
    gumbels = (logits + gumbels) / tau  # ~Gumbel(logits, tau)
    y_soft = gumbels.sigmoid()

    if hard:
        # Straight-through estimator: the forward value is the hard 0/1 mask,
        # while gradients flow through y_soft (y_hard carries no grad, and the
        # detached y_soft cancels its value but not its gradient).
        # NOTE: the boolean-mask form works for logits of any rank; the
        # original indices[0]/indices[1] assignment assumed 2-D input.
        y_hard = (y_soft > threshold).to(dtype=logits.dtype)
        ret = y_hard - y_soft.detach() + y_soft
    else:
        # Reparametrization trick: return the differentiable soft sample.
        ret = y_soft
    return ret