
Dispatchers

pylcm solves and simulates models by evaluating scalar functions on structured spaces: Cartesian products of grids during solution, and aligned arrays of agent states during simulation. JAX’s vmap is the primitive for this, but raw vmap only applies its in_axes to positional arguments and doesn’t distinguish between product-mapping and aligned-mapping.

The three dispatchers in lcm.dispatchers solve this:

  • productmap — evaluate on the Cartesian product of variables (solution phase)

  • vmap_1d — evaluate on aligned 1D arrays (simulation phase)

  • simulation_spacemap — hybrid: product over actions, aligned over states (simulation phase)

import inspect

import jax.numpy as jnp
import plotly.graph_objects as go
from jax import vmap

from lcm.dispatchers import productmap, simulation_spacemap, vmap_1d

blue, orange, green = "#4C78A8", "#F58518", "#54A24B"

A scalar function

All dispatchers start from a scalar function — one that takes scalar inputs and returns a scalar output. We’ll use a simple production function throughout:

def f(x, y):
    """Cobb-Douglas-style function."""
    return x**0.4 * y**0.6


# Scalar inputs, scalar output
f(2.0, 3.0)
2.5508490012515814

productmap

During the solution phase, pylcm needs to evaluate utility and constraint functions on every combination of state and action grid points — the Cartesian product of the grids.

productmap does this by iteratively applying vmap. Given variables ("x", "y"), it vmaps over y first (innermost), then over x (outermost), producing an output array with shape (n_x, n_y) — dimensions matching the variable order.

f_product = productmap(func=f, variables=("x", "y"))

x = jnp.array([1.0, 2.0, 3.0, 4.0])
y = jnp.array([10.0, 20.0, 30.0])

result = f_product(x=x, y=y)

print(f"x has {len(x)} points, y has {len(y)} points")
print(f"Output shape: {result.shape}  (n_x, n_y)")
print(f"\nresult[2, 1] = f(x[2], y[1]) = f({x[2]}, {y[1]}) = {result[2, 1]:.4f}")
print(f"f(3.0, 20.0) = {f(3.0, 20.0):.4f}")
x has 4 points, y has 3 points
Output shape: (4, 3)  (n_x, n_y)

result[2, 1] = f(x[2], y[1]) = f(3.0, 20.0) = 9.3641
f(3.0, 20.0) = 9.3641
fig = go.Figure(
    data=go.Heatmap(
        z=result.T,
        x=x,
        y=y,
        colorscale="Blues",
        colorbar={"title": "f(x, y)"},
    )
)
fig.update_layout(
    title="productmap output: f evaluated on the Cartesian product of x and y",
    xaxis_title="x",
    yaxis_title="y",
    width=600,
    height=400,
)
fig.show()
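f is built only from elementwise jnp operations. For a function like that, the same Cartesian-product grid can also be produced with NumPy-style broadcasting, by giving x and y orthogonal axes. The plain-JAX sketch below checks this equivalence. It is an illustration only, not how productmap is implemented.

```python
import jax.numpy as jnp
from jax import vmap


def f(x, y):
    """Cobb-Douglas-style function, as defined above."""
    return x**0.4 * y**0.6


x = jnp.array([1.0, 2.0, 3.0, 4.0])
y = jnp.array([10.0, 20.0, 30.0])

# Cartesian product via nested vmap: inner over y, outer over x -> (n_x, n_y)
via_vmap = vmap(vmap(f, in_axes=(None, 0)), in_axes=(0, None))(x, y)

# Same grid via broadcasting: x as a column (n_x, 1), y as a row (1, n_y)
via_broadcast = f(x[:, None], y[None, :])

print(via_vmap.shape, via_broadcast.shape)  # (4, 3) (4, 3)
print(jnp.allclose(via_vmap, via_broadcast))  # True
```

Broadcasting only works here because every operation in f is elementwise. A scalar function that, for example, stacks its inputs into a vector would not broadcast this way. vmap handles such functions because it traces them with per-element (here scalar) inputs.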

How it works: iterative vmap

Under the hood, productmap applies vmap one variable at a time, in reverse order. For variables=("x", "y"):

  1. vmap over y (position 1) → function now takes scalar x, array y → returns 1D

  2. vmap over x (position 0) → function now takes array x, array y → returns 2D

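Each vmap adds a new leading axis to its output. The variable wrapped last (x) therefore lands on axis 0, and the one wrapped first (y) lands on axis 1. Iterating in reverse is what makes the output dimensions match the variable order.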

Let’s verify this manually:

# Step 1: vmap over y, which becomes output axis 1. x is not mapped (None), y is mapped (0)
f_over_y = vmap(f, in_axes=(None, 0))

# Step 2: vmap over x, which becomes output axis 0. x is mapped (0), y is passed through whole (None)
f_over_xy = vmap(f_over_y, in_axes=(0, None))

result_manual = f_over_xy(x, y)

print(f"productmap result matches manual vmap: {jnp.allclose(result, result_manual)}")
productmap result matches manual vmap: True
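Here is a minimal, hypothetical generalization of the two manual steps that shows why the iteration order matters. product_vmap_sketch is not part of pylcm, and it only supports positional calls. It nests one vmap per variable. Wrapping the variables in forward order instead of reverse order produces the transposed array.

```python
import inspect

import jax.numpy as jnp
from jax import vmap


def product_vmap_sketch(func, variables, reverse=True):
    """Hypothetical helper (not pylcm's productmap): nest one vmap per variable.

    Each layer maps exactly one parameter over axis 0 and passes the others
    through unchanged (in_axes=None). The outermost layer contributes output
    axis 0, so the wrapping order fixes the order of output dimensions.
    The result must be called positionally.
    """
    params = list(inspect.signature(func).parameters)
    order = tuple(reversed(variables)) if reverse else tuple(variables)
    mapped = func
    for name in order:  # first wrapped = innermost layer
        in_axes = tuple(0 if p == name else None for p in params)
        mapped = vmap(mapped, in_axes=in_axes)
    return mapped


def f(x, y):
    return x**0.4 * y**0.6


x = jnp.array([1.0, 2.0, 3.0, 4.0])
y = jnp.array([10.0, 20.0, 30.0])

reverse_order = product_vmap_sketch(f, ("x", "y"))(x, y)
forward_order = product_vmap_sketch(f, ("x", "y"), reverse=False)(x, y)

print(reverse_order.shape)  # (4, 3): matches variable order
print(forward_order.shape)  # (3, 4): transposed
print(jnp.allclose(reverse_order, forward_order.T))  # True
```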

vmap_1d

During simulation, states are not grids to form products over — they are arrays of realized values, one per simulated agent. All state arrays share the same leading axis (the agent axis).

vmap_1d maps the function over all specified variables simultaneously along their leading axis, producing a 1D output (one value per agent).

f_aligned = vmap_1d(func=f, variables=("x", "y"))

# 5 agents, each with their own (x, y) pair
x_agents = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])
y_agents = jnp.array([10.0, 20.0, 30.0, 40.0, 50.0])

result_1d = f_aligned(x=x_agents, y=y_agents)

print(f"5 agents, aligned arrays → output shape: {result_1d.shape}")
print(
    f"\nresult[2] = f(x[2], y[2])"
    f" = f({x_agents[2]}, {y_agents[2]})"
    f" = {result_1d[2]:.4f}"
)
print(f"f(3.0, 30.0) = {f(3.0, 30.0):.4f}")
5 agents, aligned arrays → output shape: (5,)

result[2] = f(x[2], y[2]) = f(3.0, 30.0) = 11.9432
f(3.0, 30.0) = 11.9432
fig = go.Figure()
fig.add_trace(
    go.Bar(
        x=[f"Agent {i}" for i in range(len(result_1d))],
        y=result_1d,
        marker_color=blue,
        text=[
            f"f({x_agents[i]:.0f}, {y_agents[i]:.0f})" for i in range(len(result_1d))
        ],
        textposition="outside",
    )
)
fig.update_layout(
    title="vmap_1d output: f evaluated on aligned agent arrays",
    yaxis_title="f(x, y)",
    width=600,
    height=400,
)
fig.show()
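Aligned mapping pairs element i of every mapped array, so all mapped arrays must have the same leading-axis length. The sketch below uses plain jax.vmap, the primitive vmap_1d builds on. It shows that a length mismatch is rejected with a ValueError rather than silently broadcast. Whether vmap_1d surfaces the identical error is an assumption not shown on this page.

```python
import jax.numpy as jnp
from jax import vmap


def f(x, y):
    return x**0.4 * y**0.6


x_agents = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])  # 5 agents
y_short = jnp.array([10.0, 20.0, 30.0])  # only 3 values

# Aligned mapping pairs element i of x with element i of y,
# so arrays with different leading-axis lengths cannot be paired up.
try:
    vmap(f, in_axes=(0, 0))(x_agents, y_short)
    mismatch_error = None
except ValueError as err:
    mismatch_error = err

print(type(mismatch_error).__name__)  # ValueError
```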

productmap vs vmap_1d

The key difference is how the inputs are combined:

|  | productmap | vmap_1d |
|---|---|---|
| Combination | Cartesian product | Elementwise (aligned) |
| Output shape | (n_x, n_y) | (n,) where n = len(x) = len(y) |
| Number of evaluations | n_x * n_y | n |
| Used in | Solution (grid evaluation) | Simulation (per-agent evaluation) |
| vmap strategy | One vmap per variable (iterative) | Single vmap over all variables |
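The two modes are related. A Cartesian product is just an aligned evaluation over every (x, y) pair, once those pairs are spelled out explicitly. The plain-JAX sketch below makes the evaluation counts in the table concrete. It builds the 12 pairs with jnp.meshgrid, maps over them in aligned fashion, and reshapes the result back to the product grid.

```python
import jax.numpy as jnp
from jax import vmap


def f(x, y):
    return x**0.4 * y**0.6


x = jnp.array([1.0, 2.0, 3.0, 4.0])
y = jnp.array([10.0, 20.0, 30.0])

# Cartesian product via nested vmap: n_x * n_y = 12 evaluations, shape (4, 3)
product = vmap(vmap(f, in_axes=(None, 0)), in_axes=(0, None))(x, y)

# Spell the product out as aligned arrays: "ij" indexing keeps x on axis 0,
# then flatten to 12 aligned (x, y) pairs and map over them in one vmap
xx, yy = jnp.meshgrid(x, y, indexing="ij")
aligned = vmap(f, in_axes=(0, 0))(xx.ravel(), yy.ravel())

print(aligned.shape)  # (12,)
print(jnp.allclose(aligned.reshape(len(x), len(y)), product))  # True
```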

How vmap_1d works: in_axes

vmap_1d uses a single vmap call. For each function parameter, the in_axes entry is 0 (map over the leading axis) if the parameter is in variables, or None (broadcast) otherwise.

# Equivalent manual construction for f(x, y) with variables=("x", "y")
f_manual_1d = vmap(f, in_axes=(0, 0))  # both x and y mapped

result_manual_1d = f_manual_1d(x_agents, y_agents)
print(f"vmap_1d matches manual: {jnp.allclose(result_1d, result_manual_1d)}")
vmap_1d matches manual: True

When only a subset of parameters is in variables, the rest are broadcast. This is useful for functions that take both per-agent states and shared parameters:

def g(x, y, scale):
    return scale * x**0.4 * y**0.6


# Only map over x and y; scale is broadcast to all agents
g_aligned = vmap_1d(func=g, variables=("x", "y"))

result_g = g_aligned(x=x_agents, y=y_agents, scale=2.0)
print(f"With broadcast scale=2.0: {result_g}")
print(f"Equals 2 * vmap_1d result: {jnp.allclose(result_g, 2.0 * result_1d)}")
With broadcast scale=2.0: [ 7.962144 15.924289 23.886433 31.848577 39.81072 ]
Equals 2 * vmap_1d result: True
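The in_axes rule can be read directly off the function signature. The helper in_axes_for below is a hypothetical plain-JAX sketch of that rule, not pylcm’s implementation. Unlike vmap_1d, the function it produces should be called positionally, in signature order.

```python
import inspect

import jax.numpy as jnp
from jax import vmap


def in_axes_for(func, variables):
    """Hypothetical helper: 0 for parameters in `variables`, None otherwise."""
    params = inspect.signature(func).parameters
    return tuple(0 if name in variables else None for name in params)


def g(x, y, scale):
    return scale * x**0.4 * y**0.6


in_axes = in_axes_for(g, ("x", "y"))
print(in_axes)  # (0, 0, None)

# A single vmap call; scale is broadcast to every agent
g_mapped = vmap(g, in_axes=in_axes)

x_agents = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])
y_agents = jnp.array([10.0, 20.0, 30.0, 40.0, 50.0])
result = g_mapped(x_agents, y_agents, 2.0)  # positional, in signature order
print(result.shape)  # (5,)
```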

simulation_spacemap

During simulation, pylcm needs to find the optimal action for each agent. This requires evaluating the Q-function (action-value function) on:

  • The Cartesian product of action grids (to search over all possible actions)

  • The aligned state arrays (each agent has fixed states)

simulation_spacemap combines both: productmap over actions, then vmap_1d over states. The state vmap is applied last (outermost), so the agent axis comes first. The output shape is therefore (n_agents, n_action1, n_action2, ...).

def q(action, state):
    """A simple Q-function: utility depends on action choice and current state."""
    return -((action - state) ** 2)


q_mapped = simulation_spacemap(
    func=q,
    action_names=("action",),
    state_names=("state",),
)

# 6 possible actions, 4 agents with different states
actions = jnp.array([0.0, 2.0, 4.0, 6.0, 8.0, 10.0])
states = jnp.array([1.0, 4.0, 7.0, 9.0])

Q_values = q_mapped(action=actions, state=states)

print(f"Actions: {len(actions)} grid points")
print(f"States:  {len(states)} agents")
print(f"Output shape: {Q_values.shape}  (n_agents, n_actions)")
print(
    f"\nQ[2, 1] = q(action[1], state[2])"
    f" = q({actions[1]}, {states[2]})"
    f" = {Q_values[2, 1]:.1f}"
)
Actions: 6 grid points
States:  4 agents
Output shape: (4, 6)  (n_agents, n_actions)

Q[2, 1] = q(action[1], state[2]) = q(2.0, 7.0) = -25.0
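The agents-first layout follows from the order of composition. The plain-JAX sketch below mirrors "productmap over actions, then vmap_1d over states" with raw vmap calls. It is not pylcm’s code, and with a single action variable the product step reduces to one vmap. The state vmap is the outermost one, so it contributes axis 0 of the result.

```python
import jax.numpy as jnp
from jax import vmap


def q(action, state):
    return -((action - state) ** 2)


actions = jnp.array([0.0, 2.0, 4.0, 6.0, 8.0, 10.0])
states = jnp.array([1.0, 4.0, 7.0, 9.0])

# Inner step: product over the action grid for one fixed state -> (n_actions,)
q_over_actions = vmap(q, in_axes=(0, None))

# Outer step: aligned over agents; the outermost vmap adds the leading axis
q_sketch = vmap(q_over_actions, in_axes=(None, 0))

Q = q_sketch(actions, states)
print(Q.shape)  # (4, 6): (n_agents, n_actions)
print(Q[2, 1])  # q(actions[1], states[2]) = q(2.0, 7.0) = -25.0
```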
fig = go.Figure()
for i in range(len(states)):
    fig.add_trace(
        go.Scatter(
            x=actions,
            y=Q_values[i, :],  # row i: agent i's Q-values across all actions
            mode="lines+markers",
            name=f"Agent {i} (state={states[i]:.0f})",
        )
    )
fig.update_layout(
    title="simulation_spacemap: Q-values across actions for each agent",
    xaxis_title="Action",
    yaxis_title="Q(action, state)",
    width=700,
    height=400,
)
fig.show()

Each agent’s optimal action is the grid point closest to their state (the peak of each curve). When two grid points are equally close, as happens here for every agent except the one with state 4, jnp.argmax returns the first one. pylcm finds the optimum by taking the argmax over the action dimension, which is axis 1 in this layout.

best_action_idx = jnp.argmax(Q_values, axis=1)  # axis 1 = actions
for i in range(len(states)):
    print(
        f"Agent {i}: state={states[i]:.0f}, "
        f"best action={actions[best_action_idx[i]]:.0f}"
    )
Agent 0: state=1, best action=0
Agent 1: state=4, best action=4
Agent 2: state=7, best action=6
Agent 3: state=9, best action=8
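With more than one action, the action dimensions can be searched jointly. One way is to flatten them, take one argmax per agent, and convert the flat index back to one index per grid with jnp.unravel_index. The sketch below assumes the agents-first layout shown above and uses a hypothetical two-action function q2 with made-up grids. It illustrates the idea in plain JAX and is not pylcm’s simulation code.

```python
import jax.numpy as jnp
from jax import vmap


def q2(consumption, work, state):
    """Hypothetical two-action Q-function, for illustration only."""
    return -((consumption - state) ** 2) - 0.5 * work


consumption = jnp.array([0.0, 2.0, 4.0, 6.0])
work = jnp.array([0.0, 1.0])
states = jnp.array([1.2, 4.0, 5.8])

# Product over both action grids (inner), aligned over agents (outer):
# result has shape (n_agents, n_consumption, n_work)
over_work = vmap(q2, in_axes=(None, 0, None))
over_actions = vmap(over_work, in_axes=(0, None, None))
Q = vmap(over_actions, in_axes=(None, None, 0))(consumption, work, states)

# Flatten the action axes, take one argmax per agent, recover grid indices
flat_idx = jnp.argmax(Q.reshape(Q.shape[0], -1), axis=1)
c_idx, w_idx = jnp.unravel_index(flat_idx, Q.shape[1:])

print(Q.shape)             # (3, 4, 2)
print(consumption[c_idx])  # [2. 4. 6.]
print(work[w_idx])         # [0. 0. 0.]
```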

Where dispatchers fit in the pipeline

| Phase | Dispatcher | What it evaluates | Output shape |
|---|---|---|---|
| Solution | productmap | Utility, constraints, helper functions on state-action grids | (n_s1, ..., n_a1, ...) |
| Solution | productmap | Next-period value function over shock transitions | (n_shock1, ...) |
| Simulation | simulation_spacemap | Q-function: product over actions, aligned over agent states | (n_agents, n_a1, ...) |
| Simulation | vmap_1d | State transitions, regime transitions, additional targets | (n_agents,) |

Signature preservation

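Raw vmap reports the wrapped function’s signature, but its in_axes only describe positional arguments. According to the JAX documentation, arguments passed by keyword are always mapped over their leading axis. A vmapped function therefore can’t be reliably called by argument name. The dispatchers restore usable signatures with allow_args and allow_only_kwargs from lcm.functools. As the output below shows, the dispatched functions have keyword-only signatures and are called with keyword arguments, as in every example on this page.

print("Original signature: ", inspect.signature(f))
print("productmap signature:", inspect.signature(f_product))
print("vmap_1d signature:   ", inspect.signature(f_aligned))
print("spacemap signature:  ", inspect.signature(q_mapped))
Original signature:  (x, y)
productmap signature: (*, x, y)
vmap_1d signature:    (*, x, y)
spacemap signature:   (*, action, state)
# Raw vmap reports the original signature, but its in_axes still only cover positional arguments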
f_raw_vmap = vmap(f, in_axes=(0, 0))
print("Raw vmap signature:", inspect.signature(f_raw_vmap))
Raw vmap signature: (x, y)

This matters because pylcm uses dags.concatenate_functions to compose functions by matching argument names. Without preserved signatures, this composition would break.
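The plain-JAX snippet below shows the keyword problem directly and sketches one way around it: a wrapper that accepts keywords and forwards them positionally. keyword_callable is hypothetical and only illustrates the idea. It is not lcm.functools.allow_only_kwargs, whose implementation may differ.

```python
import inspect

import jax.numpy as jnp
from jax import vmap


def f(x, y):
    return x**0.4 * y**0.6


y = jnp.array([10.0, 20.0, 30.0])
f_over_y = vmap(f, in_axes=(None, 0))  # map y only, broadcast x

# Positional call: in_axes is respected
print(f_over_y(2.0, y).shape)  # (3,)

# Keyword call: in_axes does not cover keyword arguments, so this fails
try:
    f_over_y(x=2.0, y=y)
    keyword_error = None
except ValueError as err:
    keyword_error = err


def keyword_callable(vmapped, func):
    """Hypothetical wrapper: accept keywords, forward them positionally."""
    params = list(inspect.signature(func).parameters)

    def wrapper(**kwargs):
        return vmapped(*(kwargs[name] for name in params))

    return wrapper


f_over_y_kw = keyword_callable(f_over_y, f)
print(jnp.allclose(f_over_y_kw(x=2.0, y=y), f_over_y(2.0, y)))  # True
```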