pylcm solves and simulates models by evaluating scalar functions on structured
spaces — Cartesian products of grids during solution, aligned arrays of agent
states during simulation. JAX’s vmap is the primitive for this, but raw vmap
loses function signatures and doesn’t distinguish between product-mapping and
aligned-mapping.
The three dispatchers in lcm.dispatchers solve this:
- productmap: evaluate on the Cartesian product of variables (solution phase)
- vmap_1d: evaluate on aligned 1D arrays (simulation phase)
- simulation_spacemap: hybrid, a product over actions, aligned over states (simulation phase)
import inspect
import jax.numpy as jnp
import plotly.graph_objects as go
from jax import vmap
from lcm.dispatchers import productmap, simulation_spacemap, vmap_1d
blue, orange, green = "#4C78A8", "#F58518", "#54A24B"
A scalar function¶
All dispatchers start from a scalar function — one that takes scalar inputs and returns a scalar output. We’ll use a simple production function throughout:
def f(x, y):
    """Cobb-Douglas-style function."""
    return x**0.4 * y**0.6
# Scalar inputs, scalar output
f(2.0, 3.0)2.5508490012515814productmap¶
During the solution phase, pylcm needs to evaluate utility and constraint functions on every combination of state and action grid points — the Cartesian product of the grids.
productmap does this by iteratively applying vmap. Given variables ("x", "y"),
it vmaps over y first (innermost), then over x (outermost), producing an output
array with shape (n_x, n_y) — dimensions matching the variable order.
f_product = productmap(func=f, variables=("x", "y"))
x = jnp.array([1.0, 2.0, 3.0, 4.0])
y = jnp.array([10.0, 20.0, 30.0])
result = f_product(x=x, y=y)
print(f"x has {len(x)} points, y has {len(y)} points")
print(f"Output shape: {result.shape} (n_x, n_y)")
print(f"\nresult[2, 1] = f(x[2], y[1]) = f({x[2]}, {y[1]}) = {result[2, 1]:.4f}")
print(f"f(3.0, 20.0) = {f(3.0, 20.0):.4f}")
x has 4 points, y has 3 points
Output shape: (4, 3) (n_x, n_y)
result[2, 1] = f(x[2], y[1]) = f(3.0, 20.0) = 9.3641
f(3.0, 20.0) = 9.3641
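The iterated-vmap construction generalizes to any number of variables, with one output axis per variable, in order. A minimal sketch with plain jax.vmap (the function h and its grids are made up for illustration):

```python
import jax.numpy as jnp
from jax import vmap


def h(x, y, z):
    return x * y + z


# Apply vmap in reverse variable order: z innermost, x outermost,
# so the output axes come out as (n_x, n_y, n_z)
h_over_z = vmap(h, in_axes=(None, None, 0))
h_over_yz = vmap(h_over_z, in_axes=(None, 0, None))
h_over_xyz = vmap(h_over_yz, in_axes=(0, None, None))

result_3d = h_over_xyz(
    jnp.array([1.0, 2.0]),  # n_x = 2
    jnp.array([10.0, 20.0, 30.0]),  # n_y = 3
    jnp.array([0.0, 100.0, 200.0, 300.0]),  # n_z = 4
)
print(result_3d.shape)  # (2, 3, 4)
print(result_3d[1, 2, 3])  # h(2.0, 30.0, 300.0) = 360.0
```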
fig = go.Figure(
    data=go.Heatmap(
        z=result.T,
        x=x,
        y=y,
        colorscale="Blues",
        colorbar={"title": "f(x, y)"},
    )
)
fig.update_layout(
    title="productmap output: f evaluated on the Cartesian product of x and y",
    xaxis_title="x",
    yaxis_title="y",
    width=600,
    height=400,
)
fig.show()
How it works: iterative vmap¶
Under the hood, productmap applies vmap one variable at a time, in reverse
order. For variables=("x", "y"):
1. vmap over y (position 1) → the function now takes scalar x, array y → returns a 1D array
2. vmap over x (position 0) → the function now takes array x, array y → returns a 2D array
The reverse iteration ensures output dimensions match variable order.
Let’s verify this manually:
# Step 1: vmap over y (axis 1) — x is not mapped (None), y is mapped (0)
f_over_y = vmap(f, in_axes=(None, 0))
# Step 2: vmap over x (axis 0) — x is mapped (0), y is not mapped (None)
f_over_xy = vmap(f_over_y, in_axes=(0, None))
result_manual = f_over_xy(x, y)
print(f"productmap result matches manual vmap: {jnp.allclose(result, result_manual)}")
productmap result matches manual vmap: True
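Since output dimensions follow variable order, passing variables=("y", "x") instead would transpose the result. A manual vmap sketch of the two orderings:

```python
import jax.numpy as jnp
from jax import vmap


def f(x, y):
    return x**0.4 * y**0.6


x = jnp.array([1.0, 2.0, 3.0, 4.0])
y = jnp.array([10.0, 20.0, 30.0])

# Order ("x", "y"): y innermost, x outermost → shape (n_x, n_y)
f_xy = vmap(vmap(f, in_axes=(None, 0)), in_axes=(0, None))
# Order ("y", "x"): x innermost, y outermost → shape (n_y, n_x)
f_yx = vmap(vmap(f, in_axes=(0, None)), in_axes=(None, 0))

print(f_xy(x, y).shape)  # (4, 3)
print(f_yx(x, y).shape)  # (3, 4)
print(jnp.allclose(f_xy(x, y), f_yx(x, y).T))  # True: same values, transposed
```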
vmap_1d¶
During simulation, states are not grids to form products over — they are arrays of realized values, one per simulated agent. All state arrays share the same leading axis (the agent axis).
vmap_1d maps the function over all specified variables simultaneously along
their leading axis, producing a 1D output (one value per agent).
f_aligned = vmap_1d(func=f, variables=("x", "y"))
# 5 agents, each with their own (x, y) pair
x_agents = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])
y_agents = jnp.array([10.0, 20.0, 30.0, 40.0, 50.0])
result_1d = f_aligned(x=x_agents, y=y_agents)
print(f"5 agents, aligned arrays → output shape: {result_1d.shape}")
print(
    f"\nresult[2] = f(x[2], y[2])"
    f" = f({x_agents[2]}, {y_agents[2]})"
    f" = {result_1d[2]:.4f}"
)
print(f"f(3.0, 30.0) = {f(3.0, 30.0):.4f}")
5 agents, aligned arrays → output shape: (5,)
result[2] = f(x[2], y[2]) = f(3.0, 30.0) = 11.9432
f(3.0, 30.0) = 11.9432
fig = go.Figure()
fig.add_trace(
    go.Bar(
        x=[f"Agent {i}" for i in range(len(result_1d))],
        y=result_1d,
        marker_color=blue,
        text=[
            f"f({x_agents[i]:.0f}, {y_agents[i]:.0f})" for i in range(len(result_1d))
        ],
        textposition="outside",
    )
)
fig.update_layout(
    title="vmap_1d output: f evaluated on aligned agent arrays",
    yaxis_title="f(x, y)",
    width=600,
    height=400,
)
fig.show()
productmap vs vmap_1d¶
The key difference is how the inputs are combined:
| | productmap | vmap_1d |
|---|---|---|
| Combination | Cartesian product | Elementwise (aligned) |
| Output shape | (n_x, n_y) | (n,) where n = len(x) = len(y) |
| Number of evaluations | n_x * n_y | n |
| Used in | Solution (grid evaluation) | Simulation (per-agent evaluation) |
| vmap strategy | One vmap per variable (iterative) | Single vmap over all variables |
How vmap_1d works: in_axes¶
vmap_1d uses a single vmap call. For each function parameter, the in_axes
entry is 0 (map over the leading axis) if the parameter is in variables, or
None (broadcast) otherwise.
# Equivalent manual construction for f(x, y) with variables=("x", "y")
f_manual_1d = vmap(f, in_axes=(0, 0)) # both x and y mapped
result_manual_1d = f_manual_1d(x_agents, y_agents)
print(f"vmap_1d matches manual: {jnp.allclose(result_1d, result_manual_1d)}")
vmap_1d matches manual: True
When only a subset of parameters is in variables, the rest are broadcast. This is
useful for functions that take both per-agent states and shared parameters:
def g(x, y, scale):
    return scale * x**0.4 * y**0.6
# Only map over x and y; scale is broadcast to all agents
g_aligned = vmap_1d(func=g, variables=("x", "y"))
result_g = g_aligned(x=x_agents, y=y_agents, scale=2.0)
print(f"With broadcast scale=2.0: {result_g}")
print(f"Equals 2 * vmap_1d result: {jnp.allclose(result_g, 2.0 * result_1d)}")
With broadcast scale=2.0: [ 7.962144 15.924289 23.886433 31.848577 39.81072 ]
Equals 2 * vmap_1d result: True
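The manual equivalent of this broadcast case follows the in_axes rule directly: parameters in variables get entry 0, the rest get None. A sketch with plain jax.vmap:

```python
import jax.numpy as jnp
from jax import vmap


def g(x, y, scale):
    return scale * x**0.4 * y**0.6


# x and y are mapped over their leading axis; scale gets None,
# so the single scalar is broadcast to every agent
g_manual = vmap(g, in_axes=(0, 0, None))

x_agents = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])
y_agents = jnp.array([10.0, 20.0, 30.0, 40.0, 50.0])
result_broadcast = g_manual(x_agents, y_agents, 2.0)
print(result_broadcast.shape)  # (5,)
```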
simulation_spacemap¶
During simulation, pylcm needs to find the optimal action for each agent. This requires evaluating the Q-function (action-value function) on:
1. The Cartesian product of action grids (to search over all possible actions)
2. The aligned state arrays (each agent has fixed states)
simulation_spacemap combines both: productmap over actions, wrapped in vmap_1d over
states. Because the aligned map over states is applied outermost, the agent axis
comes first: the output shape is (n_agents, n_action1, n_action2, ...).
def q(action, state):
    """A simple Q-function: utility depends on action choice and current state."""
    return -((action - state) ** 2)
q_mapped = simulation_spacemap(
    func=q,
    action_names=("action",),
    state_names=("state",),
)
# 6 possible actions, 4 agents with different states
actions = jnp.array([0.0, 2.0, 4.0, 6.0, 8.0, 10.0])
states = jnp.array([1.0, 4.0, 7.0, 9.0])
Q_values = q_mapped(action=actions, state=states)
print(f"Actions: {len(actions)} grid points")
print(f"States: {len(states)} agents")
print(f"Output shape: {Q_values.shape} (n_agents, n_actions)")
print(
    f"\nQ[2, 1] = q(action[1], state[2])"
    f" = q({actions[1]}, {states[2]})"
    f" = {Q_values[2, 1]:.1f}"
)
Actions: 6 grid points
States: 4 agents
Output shape: (4, 6) (n_agents, n_actions)
Q[2, 1] = q(action[1], state[2]) = q(2.0, 7.0) = -25.0
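This combination can be reproduced with plain jax.vmap: an inner vmap over the action grid, wrapped in an outer vmap over agent states. The outer map contributes the leading axis, consistent with the (4, 6) shape printed above. This is a sketch of the idea, not the simulation_spacemap implementation:

```python
import jax.numpy as jnp
from jax import vmap


def q(action, state):
    return -((action - state) ** 2)


actions = jnp.array([0.0, 2.0, 4.0, 6.0, 8.0, 10.0])
states = jnp.array([1.0, 4.0, 7.0, 9.0])

# Inner: map over the action grid while the state is held fixed
q_over_actions = vmap(q, in_axes=(0, None))
# Outer: aligned map over agent states → leading agent axis
q_spacemap_manual = vmap(q_over_actions, in_axes=(None, 0))

Q_manual = q_spacemap_manual(actions, states)
print(Q_manual.shape)  # (4, 6): (n_agents, n_actions)
print(Q_manual[2, 1])  # q(2.0, 7.0) = -25.0
```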
fig = go.Figure()
for i in range(len(states)):
    fig.add_trace(
        go.Scatter(
            x=actions,
            y=Q_values[i, :],
            mode="lines+markers",
            name=f"Agent {i} (state={states[i]:.0f})",
        )
    )
fig.update_layout(
    title="simulation_spacemap: Q-values across actions for each agent",
    xaxis_title="Action",
    yaxis_title="Q(action, state)",
    width=700,
    height=400,
)
fig.show()
Each agent’s optimal action is the one closest to their state (the peak of each curve). pylcm finds this by taking the argmax over the action dimension.
best_action_idx = jnp.argmax(Q_values, axis=1)
for i in range(len(states)):
    print(
        f"Agent {i}: state={states[i]:.0f}, "
        f"best action={actions[best_action_idx[i]]:.0f}"
    )
Agent 0: state=1, best action=0
Agent 1: state=4, best action=4
Agent 2: state=7, best action=6
Agent 3: state=9, best action=8
Where dispatchers fit in the pipeline¶
| Phase | Dispatcher | What it evaluates | Output shape |
|---|---|---|---|
| Solution | productmap | Utility, constraints, helper functions on state-action grids | (n_s1, ..., n_a1, ...) |
| Solution | productmap | Next-period value function over shock transitions | (n_shock1, ...) |
| Simulation | simulation_spacemap | Q-function: product over actions, aligned over agent states | (n_agents, n_a1, ...) |
| Simulation | vmap_1d | State transitions, regime transitions, additional targets | (n_agents,) |
Signature preservation¶
Raw vmap strips the function signature — the result only accepts positional
arguments. The dispatchers restore signatures using allow_args and
allow_only_kwargs from lcm.functools, so dispatched functions work with keyword
arguments just like the originals.
print("Original signature: ", inspect.signature(f))
print("productmap signature:", inspect.signature(f_product))
print("vmap_1d signature: ", inspect.signature(f_aligned))
print("spacemap signature: ", inspect.signature(q_mapped))
Original signature:  (x, y)
productmap signature: (*, x, y)
vmap_1d signature: (*, x, y)
spacemap signature: (*, action, state)
# Raw vmap: inspect still reports the wrapped signature (via functools.wraps),
# but with a tuple in_axes the mapped function only accepts positional arguments
f_raw_vmap = vmap(f, in_axes=(0, 0))
print("Raw vmap signature:", inspect.signature(f_raw_vmap))
Raw vmap signature: (x, y)
This matters because pylcm uses dags.concatenate_functions to compose functions
by matching argument names. Without preserved signatures, this composition would
break.
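To see why name-based composition needs intact signatures, here is a toy resolver that wires functions together by parameter name using inspect.signature. This is a hypothetical helper for illustration only, not the dags implementation, and income/utility are made-up example functions:

```python
import inspect


def compose_by_names(functions, target, inputs):
    """Toy name-based composition: resolve each parameter either from
    `inputs` or by calling the function registered under that name."""

    def resolve(name):
        if name in inputs:
            return inputs[name]
        func = functions[name]
        # Read the parameter names from the signature — this is the
        # step that breaks if the signature has been stripped
        kwargs = {p: resolve(p) for p in inspect.signature(func).parameters}
        return func(**kwargs)

    return resolve(target)


def income(wage, hours):
    return wage * hours


def utility(income, leisure):
    return income**0.5 + leisure


result = compose_by_names(
    {"income": income, "utility": utility},
    target="utility",
    inputs={"wage": 25.0, "hours": 4.0, "leisure": 2.0},
)
print(result)  # 12.0, i.e. (25 * 4) ** 0.5 + 2
```

If inspect.signature(utility) reported only (*args, **kwargs), the resolver could not discover that utility needs income and leisure, which is the failure mode the dispatchers avoid by restoring signatures.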