
Non trivial slice assignments and tensor manipulation #935

@vladimirmujagic


Hello,

I am trying to port RetinaFace (face and landmark detection in PyTorch) to Rust, and I was wondering whether you support operations similar to the following:

```python
def decode(loc, priors, variances):
    """Decode locations from predictions using priors to undo
    the encoding we did for offset regression at train time.
    Args:
        loc (tensor): location predictions for loc layers,
            Shape: [num_priors,4]
        priors (tensor): Prior boxes in center-offset form.
            Shape: [num_priors,4].
        variances: (list[float]) Variances of priorboxes
    Return:
        decoded bounding box predictions
    """

    boxes = torch.cat((
        priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:],
        priors[:, 2:] * torch.exp(loc[:, 2:] * variances[1])), 1)
    boxes[:, :2] -= boxes[:, 2:] / 2
    boxes[:, 2:] += boxes[:, :2]
    return boxes
```
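Written out per prior, this is my reading of what `decode` computes. Here $(c_x, c_y, w, h)$ is the prior box, $\ell_0..\ell_3$ are the four regression outputs and $\sigma_0, \sigma_1$ are `variances[0]` and `variances[1]`:

$$
\begin{aligned}
\hat c_x &= c_x + \sigma_0\,\ell_0\,w, & \hat c_y &= c_y + \sigma_0\,\ell_1\,h,\\
\hat w &= w\,e^{\sigma_1 \ell_2}, & \hat h &= h\,e^{\sigma_1 \ell_3},\\
x_1 &= \hat c_x - \hat w/2, & y_1 &= \hat c_y - \hat h/2,\\
x_2 &= x_1 + \hat w = \hat c_x + \hat w/2, & y_2 &= y_1 + \hat h = \hat c_y + \hat h/2.
\end{aligned}
$$

The last line shows why the second in-place update has to read the first two columns *after* they were shifted by the first update.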

I couldn't find similar functionality in your library for in-place slice updates like `boxes[:, :2] -= boxes[:, 2:] / 2`.

This is my current implementation. It compiles, but I have not tested it for correctness yet, and it is not optimized.

```rust
use ndarray::{s, stack, Array2, Axis};

pub fn decode(
    loc: Array2<f32>,
    priors: Array2<f32>,
    variances: &[f32],
) -> Result<Array2<f32>, Error>
{
    // Split priors into centers (cx, cy) and sizes (w, h).
    let priors_to = priors.clone().slice_move(s![.., ..2]);
    let priors_from = priors.slice_move(s![.., 2..]);

    let mut loc_to = loc.clone().slice_move(s![.., ..2]);
    let mut loc_from = loc.slice_move(s![.., 2..]);
    loc_to = loc_to.mapv(|v: f32| v * variances[0]);
    loc_from = loc_from.mapv(|v: f32| (v * variances[1]).exp());

    let a = priors_to + loc_to * priors_from.clone();
    let b = priors_from * loc_from;
    // Join along columns -> [num_priors, 4] (`stack!` concatenates in ndarray <= 0.14).
    let mut boxes = stack![Axis(1), a, b];

    let boxes_from = boxes.clone().slice_move(s![.., 2..]).mapv(|v: f32| v / 2.0);

    // boxes[:, :2] -= boxes[:, 2:] / 2
    for i in 0..boxes.shape()[0] {
        for j in 0..2 {
            boxes[[i, j]] -= boxes_from[[i, j]];
        }
    }

    // boxes[:, 2:] += boxes[:, :2], using the already-updated first two columns
    for i in 0..boxes.shape()[0] {
        for j in 0..2 {
            let shifted = boxes[[i, j]];
            boxes[[i, j + 2]] += shifted;
        }
    }

    Ok(boxes)
}
```

Also, is it possible to concatenate more than two tensors in one call, like this:

```python
landms = torch.cat((priors[:, :2] + pre[:, :2] * variances[0] * priors[:, 2:],
                    priors[:, :2] + pre[:, 2:4] * variances[0] * priors[:, 2:],
                    priors[:, :2] + pre[:, 4:6] * variances[0] * priors[:, 2:],
                    priors[:, :2] + pre[:, 6:8] * variances[0] * priors[:, 2:],
                    priors[:, :2] + pre[:, 8:10] * variances[0] * priors[:, 2:],
                    ), dim=1)
```

Or do I have to join them two at a time to get the same result?
