25 changes: 22 additions & 3 deletions PhysicsTools/PyTorch/interface/Model.h
@@ -13,17 +13,34 @@ namespace cms::torch {
// - https://docs.pytorch.org/cppdocs/api/classtorch_1_1nn_1_1_module.html#class-module
class Model {
public:
explicit Model(const std::string &model_path) : model_(cms::torch::load(model_path)), device_(::torch::kCPU) {}
explicit Model(const std::string &model_path, bool auto_freeze = true)
: model_(cms::torch::load(model_path)), device_(::torch::kCPU), auto_freeze_(auto_freeze) {
model_.eval();
}

explicit Model(const std::string &model_path, ::torch::Device dev)
: model_(cms::torch::load(model_path, dev)), device_(dev) {}
explicit Model(const std::string &model_path, ::torch::Device dev, bool auto_freeze = true)
: model_(cms::torch::load(model_path, dev)), device_(dev), auto_freeze_(auto_freeze) {
model_.eval();
}

// Move the model to the specified device memory space. Asynchronous transfer can be requested via `non_blocking` (runs in the default stream unless overridden by the caller)
void to(::torch::Device dev, const bool non_blocking = false) {
if (dev == device_)
return;

assert(!is_frozen_ && "Model is frozen, cannot be moved to another device!");
model_.to(dev, non_blocking);
device_ = dev;
if (auto_freeze_) {
freeze();
}
}

void freeze() {
if (!is_frozen_) {
model_ = ::torch::jit::freeze(model_);
is_frozen_ = true;
}
}

// Forward pass (inference) of the model; returns torch::IValue (multi-output support), matching the native libtorch interface.
@@ -39,6 +56,8 @@ namespace cms::torch {
protected:
::torch::jit::script::Module model_; // underlying JIT model
::torch::Device device_; // device where the model is allocated (default CPU)
bool auto_freeze_; // flag to indicate if the model should be automatically frozen after loading or moving to device
bool is_frozen_ = false; // flag to indicate if the model is frozen
};

} // namespace cms::torch
39 changes: 35 additions & 4 deletions PhysicsTools/PyTorchAlpaka/README.md
@@ -17,6 +17,15 @@ The interface provides a converter to dynamically wrap SoA data into one or more

**Due to the lack of const correctness ensured by PyTorch, `const` data is currently being copied.**

## Model behavior

The `Model` wrapper automatically sets the loaded TorchScript module to evaluation mode (`eval()`).

Optionally, the model can be automatically frozen using `torch::jit::freeze` at construction time when a device is specified, or the first time it is moved.
**Important:** Once a model is frozen, it cannot be moved to another device. Attempting to do so will trigger a runtime assertion.

### TensorCollection
The structural information of the input/output SoAs is stored in a `TensorCollection`, which is a high-level object used to register column lists from which tensors are created.

@@ -41,10 +50,12 @@ GENERATE_SOA_LAYOUT(SoATemplate,
GENERATE_SOA_LAYOUT(SoAOutputTemplate,
SOA_COLUMN(int, cluster));
```

- **Get Metarecords from Portable Collections:**
If constructed with a single argument (`total_size`), the entire dataset is treated as a single batch.
```cpp
PortableCollection<SoA, Device> deviceCollection(batch_size, queue);
PortableCollection<SoA_Result, Device> deviceResultCollection(batch_size, queue);
PortableCollection<SoA, Device> deviceCollection(total_size, queue);
PortableCollection<SoA_Result, Device> deviceResultCollection(total_size, queue);
fill(queue, deviceCollection);
auto records = deviceCollection.view().records();
auto result_records = deviceResultCollection.view().records();
@@ -53,14 +64,14 @@ auto result_records = deviceResultCollection.view().records();

**IMPORTANT:** continuity of memory is a strict requirement!
```
TensorCollection input(batch_size);
TensorCollection input(total_size);
input.add<SoA>("eigen_vector", records.a(), records.b());
input.add<SoA>("eigen_matrix", records.c());
input.add<SoA>("column", records.x(), records.y(), records.z());
input.add<SoA>("scalar", records.type());
input.change_order({"column", "scalar", "eigen_matrix", "eigen_vector"});

TensorCollection output(batch_size);
TensorCollection output(total_size);
output.add<SoA>("result", result_view.cluster());
```

@@ -72,6 +83,26 @@ After adding all the blocks to the `TensorCollection`, the order of the blocks f

More examples about usage can be found in [PyTorchAlpakaTest](../PyTorchAlpakaTest).

### Batching semantics

When using batched inference, `TensorCollection` is constructed with `(batch_size, total_size)` and internally manages batch offsets.

**IMPORTANT:** the batch size must be chosen carefully to respect the SoA alignment; otherwise an assert will be triggered.

**Constraints:**
- `total_size` must be divisible by `batch_size`
- All batches are assumed to be of equal size
- Partial (last) batches are currently not supported

The `batch_id` passed to `add()` selects which batch slice is exposed to the model.

Runtime checks are performed to ensure:
- valid batch indices
- consistency between batch size and total size
- memory contiguity between columns

These checks rely on `assert`, so they are compiled out in builds that define `NDEBUG`.

## Limitations
- Current implementation supports `SerialSync` and `CudaAsync` backends only. `ROCmAsync` backend is supported via SerialSync fallback mechanism due to missing `pytorch-hip` library in CMSSW (see: https://github.com/pytorch/pytorch/blob/main/aten/CMakeLists.txt#L75), with explicit `alpaka::wait()` call to copy data to host and back to device.
- Const correctness and thread-safety rely on the `torch::from_blob()` mechanism, which currently does not ensure that data will not be modified internally. There is ongoing work to support COW tensors, but until this support is integrated into mainstream PyTorch, the provided solution materialises (copies) the tensors if the passed registry points to `const` memory. For more information please check [Const correctness and thread-safety of torch::from_blob with external memory](https://discuss.pytorch.org/t/const-correctness-and-thread-safety-of-torch-from-blob-with-external-memory/223521) and [pytorch:#97856](https://github.com/pytorch/pytorch/issues/97856)
9 changes: 7 additions & 2 deletions PhysicsTools/PyTorchAlpaka/interface/SoAConversion.h
@@ -24,10 +24,15 @@ namespace cms::torch::alpakatools::detail {
}

template <typename TQueue>
inline std::vector<::torch::IValue> convertInput(TensorCollection<TQueue>& inputs, ::torch::Device device) {
inline std::vector<::torch::IValue> convertInput(TensorCollection<TQueue>& inputs,
::torch::Device device,
bool to_half = false) {
std::vector<::torch::IValue> tensors(inputs.size());
for (size_t i = 0; i < inputs.size(); i++) {
tensors[i] = cms::torch::alpakatools::detail::arrayToTensor(device, inputs[i]);
if (to_half)
tensors[i] = cms::torch::alpakatools::detail::arrayToTensor(device, inputs[i]).to(::torch::kHalf);
else
tensors[i] = cms::torch::alpakatools::detail::arrayToTensor(device, inputs[i]);
}
return tensors;
}
72 changes: 45 additions & 27 deletions PhysicsTools/PyTorchAlpaka/interface/TensorCollection.h
@@ -5,6 +5,7 @@
#include <string>
#include <tuple>
#include <type_traits>
#include <unordered_map>
#include <vector>

#include <alpaka/alpaka.hpp>
@@ -69,11 +70,18 @@ namespace cms::torch::alpakatools {
// SOA_COLUMN(float, phi))
//
// can register the following:
// TensorCollection<Device> registry(batch_size);
//
// TensorCollection<Device> registry(batch_size, total_size);
// registry.add<ParticleLayout>("features", batch_id, records.pt(), records.eta(), records.phi());
//
// In the above example, the add function automatically computes the offset for the batch and ensures the provided columns are contiguous in memory.
// If the user wants to perform inference on the entire dataset without batching, they can simply register by passing just the total size:
//
// TensorCollection<Device> registry(total_size);
// registry.add<ParticleLayout>("features", records.pt(), records.eta(), records.phi());
//
// but if want to use only pt() and phi() then below will not work as pt() and phi() are not contiguous:
// TensorCollection<Device> registry(batch_size);
// If the user wants to use only pt() and phi() then below will not work as pt() and phi() are not contiguous:
// TensorCollection<Device> registry(batch_size, total_size);
// registry.add<ParticleLayout>("features", records.pt(), records.phi());
//
// potential solution would be to arrange layout dependent on model requirements
@@ -91,24 +99,29 @@ namespace cms::torch::alpakatools {
friend class alpaka_rocm_async::torch::AlpakaModel;
friend class alpaka_serial_sync::torch::AlpakaModel;

explicit TensorCollection(int batch_size) : batch_size_(batch_size) {}
explicit TensorCollection(int total_size) : batch_size_(total_size), total_size_(total_size) {}
explicit TensorCollection(int batch_size, int total_size) : batch_size_(batch_size), total_size_(total_size) {}

// SOA_EIGEN_COLUMN
template <typename SoALayout, typename TSoAParamsImpl, typename... Others>
requires(SameValueType<TSoAParamsImpl, Others...> && TSoAParamsImpl::columnType == cms::soa::SoAColumnType::eigen)
void add(const std::string& name,
int batch_size,
int batch_id,
std::tuple<TSoAParamsImpl, cms::soa::size_type> column,
std::tuple<Others, cms::soa::size_type>... others) {
using DataType = typename TSoAParamsImpl::ScalarType;
assert_batch_size(batch_id);
int offset = batch_id * batch_size_;
auto ptr = std::get<0>(column).data();
int n_elems =
cms::torch::alpakatools::detail::num_elements_per_column(batch_size, SoALayout::alignment, sizeof(DataType));
cms::torch::alpakatools::detail::num_elements_per_column(total_size_, SoALayout::alignment, sizeof(DataType));
assert_location(
n_elems * TSoAParamsImpl::ValueType::RowsAtCompileTime * TSoAParamsImpl::ValueType::ColsAtCompileTime,
ptr,
std::get<0>(others).data()...);

ptr += offset;

std::vector<int> tensor_dims;
if constexpr (TSoAParamsImpl::ValueType::ColsAtCompileTime > 1)
tensor_dims = {1 + sizeof...(Others),
@@ -117,60 +130,55 @@ namespace cms::torch::alpakatools {
else
tensor_dims = {1 + sizeof...(Others), TSoAParamsImpl::ValueType::RowsAtCompileTime};

emplace_tensor(name, SoALayout::alignment, ptr, batch_size, tensor_dims);
emplace_tensor(name, SoALayout::alignment, ptr, batch_size_, total_size_, tensor_dims);
}

// SOA_EIGEN_COLUMN with default batch size
// SOA_EIGEN_COLUMN with default batch id = 0
template <typename SoALayout, typename TSoAParamsImpl, typename... Others>
requires(SameValueType<TSoAParamsImpl, Others...> && TSoAParamsImpl::columnType == cms::soa::SoAColumnType::eigen)
void add(const std::string& name,
std::tuple<TSoAParamsImpl, cms::soa::size_type> column,
std::tuple<Others, cms::soa::size_type>... others) {
add<SoALayout, TSoAParamsImpl, Others...>(name, batch_size_, column, others...);
add<SoALayout, TSoAParamsImpl, Others...>(name, 0, column, others...);
}

// SOA_COLUMN
template <typename SoALayout, typename TSoAParamsImpl, typename... Others>
requires(SameScalarType<TSoAParamsImpl, Others...> &&
TSoAParamsImpl::columnType == cms::soa::SoAColumnType::column)
void add(const std::string& name,
int batch_size,
int batch_id,
std::tuple<TSoAParamsImpl, cms::soa::size_type> column,
std::tuple<Others, cms::soa::size_type>... others) {
using DataType = typename TSoAParamsImpl::ScalarType;
int n_elems =
cms::torch::alpakatools::detail::num_elements_per_column(batch_size, SoALayout::alignment, sizeof(DataType));
assert_location(n_elems, std::get<0>(column).data(), std::get<0>(others).data()...);
assert_batch_size(batch_id);
int offset = batch_id * batch_size_;
auto ptr = std::get<0>(column).data();
emplace_tensor(name, SoALayout::alignment, ptr, batch_size, {1 + sizeof...(Others)});
int n_elems =
cms::torch::alpakatools::detail::num_elements_per_column(total_size_, SoALayout::alignment, sizeof(DataType));
assert_location(n_elems, ptr, std::get<0>(others).data()...);

ptr += offset;
emplace_tensor(name, SoALayout::alignment, ptr, batch_size_, total_size_, {1 + sizeof...(Others)});
}

// SOA_COLUMN with default batch size
// SOA_COLUMN with default batch id = 0
template <typename SoALayout, typename TSoAParamsImpl, typename... Others>
requires(SameScalarType<TSoAParamsImpl, Others...> &&
TSoAParamsImpl::columnType == cms::soa::SoAColumnType::column)
void add(const std::string& name,
std::tuple<TSoAParamsImpl, cms::soa::size_type> column,
std::tuple<Others, cms::soa::size_type>... others) {
add<SoALayout, TSoAParamsImpl, Others...>(name, batch_size_, column, others...);
add<SoALayout, TSoAParamsImpl, Others...>(name, 0, column, others...);
}

// SOA_SCALAR
template <typename SoALayout, cms::soa::SoAColumnType column_t, typename T>
requires(std::is_arithmetic_v<T> && column_t == cms::soa::SoAColumnType::scalar)
void add(const std::string& name,
int batch_size,
std::tuple<cms::soa::SoAParametersImpl<column_t, T>, cms::soa::size_type> column) {
auto ptr = std::get<0>(column).data();
emplace_tensor(name, SoALayout::alignment, ptr, batch_size, {1}, true);
}

// SOA_SCALAR with default batch size
template <typename SoALayout, cms::soa::SoAColumnType column_t, typename T>
requires(std::is_arithmetic_v<T> && column_t == cms::soa::SoAColumnType::scalar)
void add(const std::string& name,
std::tuple<cms::soa::SoAParametersImpl<column_t, T>, cms::soa::size_type> column) {
add<SoALayout, column_t, T>(name, batch_size_, column);
emplace_tensor(name, SoALayout::alignment, ptr, batch_size_, total_size_, {1}, true);
}

// The order is defined by the order `add()` is called.
@@ -201,16 +209,26 @@
size_t alignment,
Tptr ptr,
int batch_size,
int total_size,
std::vector<int> dims = {1},
const bool is_scalar = false) {
using T = std::remove_pointer_t<Tptr>;
registry_.try_emplace(name,
std::make_unique<cms::torch::alpakatools::detail::TensorHandle<TQueue, T>>(
alignment, sizeof(T), ptr, batch_size, std::move(dims), is_scalar));
alignment, sizeof(T), ptr, batch_size, total_size, std::move(dims), is_scalar));
order_.push_back(name);
}

void assert_batch_size(int batch_id) {
assert(batch_size_ > 0 && "Batch size must be positive!");
assert(total_size_ > 0 && "Total size must be positive!");
assert(batch_id >= 0 && "Batch id must be non-negative!");
assert(total_size_ % batch_size_ == 0 && "Total size must be divisible by batch size!");
assert((batch_id * batch_size_ < total_size_) && "Batch id is out of bounds!");
}

int batch_size_;
int total_size_;
std::vector<std::string> order_;
std::unordered_map<std::string, std::unique_ptr<cms::torch::alpakatools::detail::ITensorHandle<TQueue>>> registry_;
};
7 changes: 5 additions & 2 deletions PhysicsTools/PyTorchAlpaka/interface/TensorHandle.h
@@ -86,13 +86,15 @@ namespace cms::torch::alpakatools::detail {
const size_t bytes,
T* data,
const int batch_size,
const int total_size,
const std::vector<int> dims,
const bool is_scalar = false)
: alignment_(alignment),
bytes_(bytes),
data_(data),
total_size_(total_size),
dims_(batch_size, dims, is_scalar),
policy_(data, dims_.volume() * num_elements_per_column(batch_size, alignment, bytes)) {
policy_(data, dims_.volume() * num_elements_per_column(total_size, alignment, bytes)) {
init_sizes();
init_strides();
}
@@ -130,7 +132,7 @@ namespace cms::torch::alpakatools::detail {
strides_ = std::vector<long int>(N);

int per_bunch = alignment_ / bytes_;
int bunches = std::ceil(1.0 * dims_.batch_size() / per_bunch);
int bunches = std::ceil(1.0 * total_size_ / per_bunch);

// base stride initialization
if (!dims_.is_scalar())
@@ -160,6 +162,7 @@
const size_t alignment_;
const size_t bytes_;
T* data_;
const int total_size_;
const Dims dims_;

std::vector<long int> strides_;
5 changes: 3 additions & 2 deletions PhysicsTools/PyTorchAlpaka/interface/alpaka/AlpakaModel.h
@@ -36,7 +36,8 @@ namespace ALPAKA_ACCELERATOR_NAMESPACE::torch {
// Refer: PhysicsTools/PyTorch/interface/SoAConversion.h for details about wrapping memory layouts.
void forward(Queue &queue,
cms::torch::alpakatools::TensorCollection<Queue> &inputs,
cms::torch::alpakatools::TensorCollection<Queue> &outputs) {
cms::torch::alpakatools::TensorCollection<Queue> &outputs,
bool to_half = false) {
#ifdef ALPAKA_ACC_GPU_HIP_ENABLED
inputs.copy(queue, cms::torch::alpakatools::detail::MemcpyKind::DeviceToHost);
outputs.copy(queue, cms::torch::alpakatools::detail::MemcpyKind::DeviceToHost);
Expand All @@ -48,7 +49,7 @@ namespace ALPAKA_ACCELERATOR_NAMESPACE::torch {
to(queue);
}

auto input_tensor = cms::torch::alpakatools::detail::convertInput(inputs, device_);
auto input_tensor = cms::torch::alpakatools::detail::convertInput(inputs, device_, to_half);
if (outputs.size() > 1) {
auto output_tensors = model_.forward(input_tensor);
cms::torch::alpakatools::detail::convertOutput(output_tensors, outputs, device_);