Generalize prep to allow for differentiating w.r.t. any single argument #967

@cgeoga

Hey @gdalle, I hope you're well! I'm using DI in a new project and absolutely loving it. Thank you so much for your hard work and the immense amount of thought and planning that this must have required.

I have seen #910 and understand why it was closed as not planned. But in this new project, I cooked up a little workaround that looks all green in @code_warntype and is pretty low-overhead (after a somewhat expensive first compilation). I'm wondering whether this is a pattern that you considered, and found issues with, as a way of generalizing what can be prepped?

The idea is really just to make an ArgSwap object that swaps the first and j-th arguments (in the case where you want everything besides the j-th argument to be Constant), make the prep for that object, and then at evaluation time swap the first and j-th arguments again so that the two swaps cancel out. Some of the structure below is specific to my problem, and if that makes the point or pattern hard to see please let me know and I'll streamline it.

using DifferentiationInterface # provides prepare_derivative, derivative, Constant

# Represents a function f(x, param_1, param_2, ..., param_k), and I'll want to
# differentiate w.r.t. the params.
struct ParametricFunction{F,P}
  fn::F
  params::NTuple{P,Float64}
end

(psdf::ParametricFunction{F,P})(w::Float64) where{F,P} = psdf.fn(w, psdf.params...)

struct ArgSwap{F,P,J}
  fn::ParametricFunction{F,P} 
end

function (as::ArgSwap{F,P,J})(args...) where{F,P,J}
  swapped_args = ntuple(j->(j==1 ? args[J] : (j==J ? args[1] : args[j])), P+1)
  as.fn.fn(swapped_args...)
end

struct ParametricDerivative{F,P,J,A,B}
  swap_sdf::ArgSwap{F,P,J}
  prep::A
  backend::B
end

function ParametricDerivative(psdf::ParametricFunction{F,P}, ::Val{J}, 
                              backend::B) where{F,P,J,B}
  swap = ArgSwap{F,P,J+1}(psdf) # J+1 since args[1] is the frequency.
  prep = prepare_derivative(swap, backend, 1.0, Constant.(psdf.params)...)
  ParametricDerivative(swap, prep, backend)
end

# Computes, with full prep,
#
# ∂/(∂ param_{J-1}) fn(w, param_1, ..., param_k),
#
# where the type parameter J here is the argument index (parameter index + 1).
function (pd::ParametricDerivative{F,P,J,A,B})(w) where{F,P,J,A,B}
  args          = (w, pd.swap_sdf.fn.params...)
  permuted_args = ntuple(j->(j==1 ? args[J] : Constant((j==J ? w : args[j]))), P+1)
  derivative(pd.swap_sdf, pd.prep, pd.backend, permuted_args...)
end
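
As a quick sanity check of the swap-cancellation idea on its own, here is a minimal sketch that does not use DI at all. The names `swap1`, `Swapped`, and the test function `f` are hypothetical and only for illustration; a central finite difference stands in for the AD backend. It shows that differentiating the swapped function w.r.t. its first argument gives the derivative of the original function w.r.t. its J-th argument.

```julia
# Hypothetical helper: exchange positions 1 and J of a tuple.
swap1(args::NTuple{N,Any}, ::Val{J}) where {N,J} =
    ntuple(j -> j == 1 ? args[J] : (j == J ? args[1] : args[j]), Val(N))

# Hypothetical wrapper: calls fn with arguments 1 and J exchanged.
struct Swapped{F,J}
    fn::F
end
(s::Swapped{F,J})(args...) where {F,J} = s.fn(swap1(args, Val(J))...)

# Illustrative test function f(w, a, b); we want the derivative w.r.t. b (argument 3).
f(w, a, b) = a * exp(-b * w)
g = Swapped{typeof(f),3}(f)

w, a, b = 0.5, 2.0, 3.0

# Swapping twice cancels: g applied to the swapped arguments equals f.
same_value = g(swap1((w, a, b), Val(3))...) == f(w, a, b)

# Derivative of g w.r.t. its FIRST argument (central finite difference),
# compared against the analytic ∂f/∂b = -w * a * exp(-b * w).
h = 1e-6
fd = (g(b + h, a, w) - g(b - h, a, w)) / (2h)
exact = -w * a * exp(-b * w)
matches = isapprox(fd, exact; rtol = 1e-6)
```

Since a backend then only ever sees "differentiate w.r.t. the first argument", the existing prep machinery applies unchanged, which is the point of the pattern.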

I probably don't have a ton of time to commit to a PR here in the next few months at least. But if this is appealing, maybe it would still be worth discussing and then either I will eventually get to it (probably this summer) or somebody else who needs it sooner may end up seeing this and taking a whack at it?

Please excuse me if this has been discussed somewhere already besides #910 and I missed it.

Cheers,
CJG
