Skip to content

otoomey/taxpr

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

33 Commits
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 

Repository files navigation

Taxpr

Run tests PyPI Version

Taxpr is a collection of utilities for performing manipulation of Jaxprs. This is achieved by tag-ing specific arrays at trace time, then extracting and manipulating those tags in the final Jaxpr.

⚠️ This package is still very experimental, so expect broken code and breaking changes.

The provided routines are designed to work seamlessly with jit, vmap, custom_jvp and cousins.

Example

The following example shows how you can use taxpr to emulate functions with side effects without violating Jax's pure function rules.

import itertools as it

import jax
import jax.numpy as jnp
from jax._src.core import eval_jaxpr
import taxpr as tx

_state_counter = it.count()

def get_state(shape, dtype):
    count = next(_state_counter)

    def set_state(value):
        return tx.tag(value, op="set", id=count)

    value = jax.numpy.zeros(shape, dtype=dtype)
    return tx.tag(value, op="get", id=count), set_state


def uncurry(fn, *args, **kwargs):
    jaxpr = jax.make_jaxpr(fn)(args, kwargs)
    states = {}

    # iterate through all tags in the jaxpr
    # this recurses all child Jaxprs too

    for params, shape in tx.iter_tags(jaxpr.jaxpr):
        if params["op"] == "get":
            states[params["id"]] = shape

    initial_states = jax.tree.map(
        lambda x: jax.numpy.full_like(x, 0), states
    )

    def injector(states, token, params):
        if params["op"] == "get":
            state = states[params["id"]]
            return state, states
        elif params["op"] == "set": 
            states[params["id"]] = token
            return token, states
        raise ValueError(f"Unknown tag op: {params['op']}")

    # replace the token with a function that performs the state manipulation
    # here we can pass our own context (`initial_states`)

    jaxpr = tx.inject(jaxpr, injector, initial_states)

    def wrapper(states, *args, **kwargs):
        return eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, states, args, kwargs)

    return wrapper, initial_states

################################################

# Usage

def running_sum(x):
    a, set_state = get_state(x.shape, x.dtype)
    sum = set_state(a + x)
    return sum

rsum, state = uncurry(running_sum, jnp.zeros(0))

_, state = rsum(state, jnp.ones(1))
_, state = rsum(state, jnp.ones(1))
_, state = rsum(state, jnp.ones(1))

assert jnp.allclose(next(iter(state.values())), 3)

About

Create, read and modify tags in Jaxprs

Resources

License

Stars

Watchers

Forks

Packages

 
 
 

Contributors

Languages