Skip to content

Add interpolate_like function to oneflow.nn.functional#10644

Merged
ShawnXuan merged 4 commits intomasterfrom
interpolate_like_py
Jun 11, 2025
Merged

Add interpolate_like function to oneflow.nn.functional#10644
ShawnXuan merged 4 commits intomasterfrom
interpolate_like_py

Conversation

@ShawnXuan
Copy link
Collaborator

This PR adds a new utility function interpolate_like to the oneflow.nn.functional module. It allows resizing an input tensor to match the spatial dimensions of a reference tensor using interpolation.

Changes:

  • Added the interpolate_like function in interpolate.py.
  • Exposed interpolate_like via functional/init.py.

Copy link
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull Request Overview

This PR adds a new interpolate_like utility to resize an input tensor to match the spatial dimensions of a reference tensor and exposes it in the functional API.

  • Implements interpolate_like in interpolate.py, wrapping the existing Interpolate module.
  • Updates functional/__init__.py to export interpolate_like.

Reviewed Changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 1 comment.

File Description
python/oneflow/nn/modules/interpolate.py Added interpolate_like function leveraging Interpolate(...).forward
python/oneflow/nn/functional/init.py Imported and exposed interpolate_like in the functional namespace
Comments suppressed due to low confidence (1)

python/oneflow/nn/modules/interpolate.py:313

  • No unit tests were added for interpolate_like. Consider adding tests covering different input ranks (3D, 4D, 5D) and modes to ensure correct behavior across all supported cases.
def interpolate_like(

ShawnXuan and others added 2 commits June 10, 2025 11:21
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
@ShawnXuan ShawnXuan marked this pull request as ready for review June 10, 2025 03:23
@github-actions
Copy link
Contributor

Code got formatted by CI. Please request CI again if you still want to have this PR merged. If the PR is from a forked repo, please download the patch files from the GitHub Actions web page and apply them locally.

@github-actions
Copy link
Contributor

@github-actions
Copy link
Contributor

Speed stats:
GPU Name: NVIDIA GeForce RTX 3080 Ti 

❌ OneFlow resnet50 time: 43.8ms (= 4375.1ms / 100, input_shape=[16, 3, 224, 224])
PyTorch resnet50 time: 57.4ms (= 5742.6ms / 100, input_shape=[16, 3, 224, 224])
✔️ Relative speed: 1.31 (= 57.4ms / 43.8ms)

OneFlow resnet50 time: 26.3ms (= 2632.5ms / 100, input_shape=[8, 3, 224, 224])
PyTorch resnet50 time: 37.5ms (= 3749.5ms / 100, input_shape=[8, 3, 224, 224])
✔️ Relative speed: 1.42 (= 37.5ms / 26.3ms)

OneFlow resnet50 time: 18.6ms (= 3714.7ms / 200, input_shape=[4, 3, 224, 224])
PyTorch resnet50 time: 36.6ms (= 7318.1ms / 200, input_shape=[4, 3, 224, 224])
✔️ Relative speed: 1.97 (= 36.6ms / 18.6ms)

OneFlow resnet50 time: 17.4ms (= 3477.3ms / 200, input_shape=[2, 3, 224, 224])
PyTorch resnet50 time: 29.9ms (= 5978.5ms / 200, input_shape=[2, 3, 224, 224])
✔️ Relative speed: 1.72 (= 29.9ms / 17.4ms)

OneFlow resnet50 time: 17.5ms (= 3490.4ms / 200, input_shape=[1, 3, 224, 224])
PyTorch resnet50 time: 29.7ms (= 5944.2ms / 200, input_shape=[1, 3, 224, 224])
✔️ Relative speed: 1.70 (= 29.7ms / 17.5ms)

OneFlow swin dataloader time: 0.199s (= 39.753s / 200, num_workers=1)
PyTorch swin dataloader time: 0.128s (= 25.686s / 200, num_workers=1)
Relative speed: 0.646 (= 0.128s / 0.199s)

OneFlow swin dataloader time: 0.055s (= 10.976s / 200, num_workers=4)
PyTorch swin dataloader time: 0.033s (= 6.563s / 200, num_workers=4)
Relative speed: 0.598 (= 0.033s / 0.055s)

OneFlow swin dataloader time: 0.031s (= 6.125s / 200, num_workers=8)
PyTorch swin dataloader time: 0.016s (= 3.288s / 200, num_workers=8)
Relative speed: 0.537 (= 0.016s / 0.031s)

❌ OneFlow resnet50 time: 49.2ms (= 4919.7ms / 100, input_shape=[16, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 66.0ms (= 6595.4ms / 100, input_shape=[16, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.34 (= 66.0ms / 49.2ms)

OneFlow resnet50 time: 37.4ms (= 3742.7ms / 100, input_shape=[8, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 47.1ms (= 4713.0ms / 100, input_shape=[8, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.26 (= 47.1ms / 37.4ms)

OneFlow resnet50 time: 27.7ms (= 5547.1ms / 200, input_shape=[4, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 40.2ms (= 8035.0ms / 200, input_shape=[4, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.45 (= 40.2ms / 27.7ms)

OneFlow resnet50 time: 25.3ms (= 5051.9ms / 200, input_shape=[2, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 38.9ms (= 7783.4ms / 200, input_shape=[2, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.54 (= 38.9ms / 25.3ms)

OneFlow resnet50 time: 24.8ms (= 4951.2ms / 200, input_shape=[1, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 35.8ms (= 7164.6ms / 200, input_shape=[1, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.45 (= 35.8ms / 24.8ms)

@ShawnXuan ShawnXuan requested review from oneflow-ci-bot and removed request for oneflow-ci-bot June 10, 2025 12:33
@github-actions
Copy link
Contributor

@github-actions
Copy link
Contributor

Speed stats:
GPU Name: NVIDIA GeForce RTX 3080 Ti 

❌ OneFlow resnet50 time: 43.4ms (= 4342.8ms / 100, input_shape=[16, 3, 224, 224])
PyTorch resnet50 time: 58.0ms (= 5799.6ms / 100, input_shape=[16, 3, 224, 224])
✔️ Relative speed: 1.34 (= 58.0ms / 43.4ms)

OneFlow resnet50 time: 26.5ms (= 2649.4ms / 100, input_shape=[8, 3, 224, 224])
PyTorch resnet50 time: 37.8ms (= 3779.0ms / 100, input_shape=[8, 3, 224, 224])
✔️ Relative speed: 1.43 (= 37.8ms / 26.5ms)

OneFlow resnet50 time: 19.2ms (= 3835.5ms / 200, input_shape=[4, 3, 224, 224])
PyTorch resnet50 time: 36.1ms (= 7223.0ms / 200, input_shape=[4, 3, 224, 224])
✔️ Relative speed: 1.88 (= 36.1ms / 19.2ms)

OneFlow resnet50 time: 17.7ms (= 3534.3ms / 200, input_shape=[2, 3, 224, 224])
PyTorch resnet50 time: 30.6ms (= 6118.1ms / 200, input_shape=[2, 3, 224, 224])
✔️ Relative speed: 1.73 (= 30.6ms / 17.7ms)

OneFlow resnet50 time: 17.3ms (= 3466.2ms / 200, input_shape=[1, 3, 224, 224])
PyTorch resnet50 time: 31.5ms (= 6298.9ms / 200, input_shape=[1, 3, 224, 224])
✔️ Relative speed: 1.82 (= 31.5ms / 17.3ms)

OneFlow swin dataloader time: 0.201s (= 40.298s / 200, num_workers=1)
PyTorch swin dataloader time: 0.136s (= 27.172s / 200, num_workers=1)
Relative speed: 0.674 (= 0.136s / 0.201s)

OneFlow swin dataloader time: 0.055s (= 11.035s / 200, num_workers=4)
PyTorch swin dataloader time: 0.034s (= 6.766s / 200, num_workers=4)
Relative speed: 0.613 (= 0.034s / 0.055s)

OneFlow swin dataloader time: 0.031s (= 6.159s / 200, num_workers=8)
PyTorch swin dataloader time: 0.017s (= 3.415s / 200, num_workers=8)
Relative speed: 0.554 (= 0.017s / 0.031s)

❌ OneFlow resnet50 time: 49.2ms (= 4920.8ms / 100, input_shape=[16, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 68.0ms (= 6804.8ms / 100, input_shape=[16, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.38 (= 68.0ms / 49.2ms)

OneFlow resnet50 time: 37.2ms (= 3722.7ms / 100, input_shape=[8, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 45.8ms (= 4577.4ms / 100, input_shape=[8, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.23 (= 45.8ms / 37.2ms)

OneFlow resnet50 time: 28.5ms (= 5708.0ms / 200, input_shape=[4, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 42.3ms (= 8459.5ms / 200, input_shape=[4, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.48 (= 42.3ms / 28.5ms)

OneFlow resnet50 time: 25.4ms (= 5075.5ms / 200, input_shape=[2, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 40.2ms (= 8049.3ms / 200, input_shape=[2, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.59 (= 40.2ms / 25.4ms)

OneFlow resnet50 time: 25.6ms (= 5116.3ms / 200, input_shape=[1, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 38.1ms (= 7620.0ms / 200, input_shape=[1, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.49 (= 38.1ms / 25.6ms)

@ShawnXuan ShawnXuan merged commit 9fdc6d8 into master Jun 11, 2025
20 checks passed
@ShawnXuan ShawnXuan deleted the interpolate_like_py branch June 11, 2025 08:41
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants