In dynamic programming, the value function is computed on a discrete grid but must be evaluated at arbitrary points when solving earlier periods. The function representation turns a pre-computed array into a callable function that:
Accepts named arguments (e.g., wealth=150.0)
Returns exact values at grid points
Linearly interpolates between grid points
This notebook explains how it works, using the terminal regime of a minimal consumption-savings model.
The three steps
Converting an array into a callable function requires three things:
Label translation: Map variable labels to array indices. For discrete variables, pylcm uses integer codes that directly serve as indices (identity mapping).
Coordinate finding: For continuous variables, convert physical values (e.g., wealth = 150) to generalized coordinates (fractional indices into the grid). See the interpolation notebook for details.
Interpolation: Use the generalized coordinates with map_coordinates to linearly interpolate between grid points.
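For reference, the last two steps can be written out for a single continuous state. The notation here is ours, not pylcm's: grid points $g_0 < g_1 < \dots < g_{n-1}$ hold stored values $V_0, \dots, V_{n-1}$. For $x$ in the segment $[g_i, g_{i+1}]$, the generalized coordinate and the linear interpolant are

$$
c(x) = i + \lambda, \qquad \lambda = \frac{x - g_i}{g_{i+1} - g_i} \in [0, 1],
$$

$$
\hat{V}(x) = (1 - \lambda)\, V_i + \lambda\, V_{i+1}.
$$

At a grid point $x = g_i$ we have $\lambda = 0$, so $\hat{V}(g_i) = V_i$ exactly. For a linearly spaced grid from $a$ to $b$ with $n$ points, $g_i = a + i\,(b - a)/(n - 1)$, and the coordinate simplifies to $c(x) = (n - 1)(x - a)/(b - a)$.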
Worked example
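We set up a minimal model with two regimes, a working-life regime followed by a terminal retirement regime. The rest of this notebook works with the terminal regime: a retiree choosing consumption given wealth, with CRRA utility. The wealth grid is intentionally coarse (10 points) so that the interpolation behavior is clearly visible.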
import jax.numpy as jnp
import plotly.graph_objects as go
from lcm import AgeGrid, LinSpacedGrid, Model, Regime, categorical
from lcm.typing import ContinuousAction, ContinuousState, FloatND
blue, orange, green = "#4C78A8", "#F58518", "#54A24B"
def utility(consumption: ContinuousAction, risk_aversion: float) -> FloatND:
    return consumption ** (1 - risk_aversion) / (1 - risk_aversion)
def next_wealth(
    wealth: ContinuousState,
    consumption: ContinuousAction,
    interest_rate: float,
) -> ContinuousState:
    return (1 + interest_rate) * (wealth - consumption)
def borrowing_constraint(
    consumption: ContinuousAction, wealth: ContinuousState
) -> FloatND:
    return consumption <= wealth
@categorical(ordered=False)
class RegimeId:
    working_life: int
    retirement: int
retirement_regime = Regime(
    transition=None,
    functions={"utility": utility},
    constraints={"borrowing_constraint": borrowing_constraint},
    actions={"consumption": LinSpacedGrid(start=1, stop=400, n_points=50)},
    states={"wealth": LinSpacedGrid(start=1, stop=400, n_points=10)},
)
working_life_regime = Regime(
    transition=lambda: RegimeId.retirement,
    functions={"utility": utility},
    constraints={"borrowing_constraint": borrowing_constraint},
    actions={"consumption": LinSpacedGrid(start=1, stop=400, n_points=50)},
    states={
        "wealth": LinSpacedGrid(start=1, stop=400, n_points=10),
    },
    state_transitions={
        "wealth": next_wealth,
    },
)
model = Model(
    description="Minimal consumption-savings model",
    ages=AgeGrid(start=25, stop=65, step="20Y"),
    regimes={"working_life": working_life_regime, "retirement": retirement_regime},
    regime_id_class=RegimeId,
)
params = {
    "discount_factor": 0.95,
    "risk_aversion": 1.5,
    "interest_rate": 0.04,
}
Computing the last-period value function array
In the terminal period, the value function is the maximum of utility over feasible actions. We use the internal regime representation to access the compiled functions and grids.
from lcm.Q_and_F import _get_U_and_F
internal_regime = model.internal_regimes["retirement"]
u_and_f = _get_U_and_F(internal_regime.internal_functions)
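u_and_f.__signature__
<Signature (consumption: 'ContinuousAction', utility__risk_aversion: 'float', wealth: 'ContinuousState') -> ('FloatND', 'FloatND')>
Parameters are qualified by the name of the function that uses them, so risk_aversion appears as utility__risk_aversion. For scalar inputs, the function returns the pair (utility, feasibility):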
_u, _f = u_and_f(consumption=100.0, wealth=50.0, utility__risk_aversion=1.5)
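print(f"Utility: {_u}, feasible: {_f}")
Utility: -0.2, feasible: False
The combination is infeasible because consumption (100) exceeds wealth (50), which violates the borrowing constraint.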
To evaluate on the full state-action grid, we use productmap:
from lcm.dispatchers import productmap
u_and_f_mapped = productmap(func=u_and_f, variables=("wealth", "consumption"))
u, f = u_and_f_mapped(**internal_regime.grids, utility__risk_aversion=1.5)
V_arr = jnp.max(u, axis=1, where=f, initial=-jnp.inf)
wealth_grid = internal_regime.grids["wealth"]
print(f"V_arr shape: {V_arr.shape} ({len(wealth_grid)} wealth grid points)")
V_arr shape: (10,) (10 wealth grid points)
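As a cross-check, the same array can be rebuilt with plain NumPy broadcasting: evaluate utility on the wealth × consumption product, mask infeasible choices with -inf, and take the maximum over consumption. This sketch assumes that LinSpacedGrid(start, stop, n_points) places its points like np.linspace(start, stop, n_points); it illustrates what the productmap-plus-masked-max computation does and is not pylcm code.

```python
import numpy as np

# Stand-alone re-computation of the terminal-period value function (not pylcm code).
# Assumption: the pylcm grids coincide with np.linspace over the same specs.
risk_aversion = 1.5
wealth = np.linspace(1, 400, 10)       # "wealth" state grid
consumption = np.linspace(1, 400, 50)  # "consumption" action grid

# Cartesian product via broadcasting: rows index wealth, columns index consumption.
w = wealth[:, np.newaxis]
c = consumption[np.newaxis, :]
u = c ** (1 - risk_aversion) / (1 - risk_aversion)  # CRRA utility, shape (1, 50)
feasible = c <= w                                    # borrowing constraint, shape (10, 50)

# Maximize over feasible consumption only; infeasible entries become -inf.
V_check = np.max(np.where(feasible, u, -np.inf), axis=1)
print(V_check)
```

Because utility is increasing in consumption, each entry equals the utility of the largest consumption grid point that does not exceed wealth.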
fig = go.Figure()
fig.add_trace(
go.Scatter(
x=wealth_grid,
y=V_arr,
mode="markers",
marker={"color": blue, "size": 8},
name="Pre-calculated values",
)
)
fig.update_layout(
xaxis_title="Wealth (x)",
yaxis_title="V(x)",
width=600,
height=400,
)
fig.show()
Creating the function representation
The function representation turns V_arr into a callable that can be evaluated at
any wealth value. The name_of_values_on_grid argument sets the name of the array
parameter in the resulting function.
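To illustrate what it means for a caller to choose an argument name, here is a hypothetical helper (make_named_evaluator is our name, not part of pylcm) that builds a callable whose signature advertises the requested names. It uses np.interp as a stand-in for pylcm's coordinate finder plus map_coordinates; note that np.interp clamps outside the grid instead of extrapolating.

```python
import inspect

import numpy as np


def make_named_evaluator(name_of_values_on_grid, grid, state_name):
    """Hypothetical helper, not pylcm's implementation.

    Returns a function called as f(<name_of_values_on_grid>=array, <state_name>=x).
    """

    def evaluate(**kwargs):
        values = kwargs[name_of_values_on_grid]
        x = kwargs[state_name]
        # Piecewise linear interpolation of the stored values at x.
        return np.interp(x, grid, values)

    # Expose the chosen names so that tools reading inspect.signature can see them.
    evaluate.__signature__ = inspect.Signature(
        [
            inspect.Parameter(name_of_values_on_grid, inspect.Parameter.KEYWORD_ONLY),
            inspect.Parameter(state_name, inspect.Parameter.KEYWORD_ONLY),
        ]
    )
    return evaluate


grid = np.linspace(1, 400, 10)
f = make_named_evaluator("V_arr", grid, "next_wealth")
print(inspect.signature(f))
print(f(V_arr=2 * grid, next_wealth=150.0))
```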
from lcm.function_representation import get_value_function_representation
scalar_value_function = get_value_function_representation(
state_space_info=internal_regime.state_space_info,
name_of_values_on_grid="V_arr",
)
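scalar_value_function.__signature__
<Signature (V_arr: 'Array', next_wealth: 'Array') -> 'Array'>
The state argument is named next_wealth because, when solving an earlier period, the function is evaluated at next-period states. This scalar function is then wrapped with productmap so that it can be evaluated on arrays of wealth values: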
value_function = productmap(func=scalar_value_function, variables=("next_wealth",))
Visualizing interpolation
We evaluate the function representation on the original grid points (which should match exactly) and on additional points between grid points (which are interpolated).
wealth_points_new = jnp.array([10.0, 25.0, 75.0, 210.0, 300.0])
wealth_all = jnp.concatenate([wealth_grid, wealth_points_new])
V_via_func = value_function(next_wealth=wealth_all, V_arr=V_arr)
fig = go.Figure()
fig.add_trace(
go.Scatter(
x=wealth_grid,
y=V_arr,
mode="lines+markers",
marker={"color": blue, "size": 8},
line={"color": blue},
name="Pre-calculated values (linear interpolation)",
)
)
fig.add_trace(
go.Scatter(
x=wealth_all,
y=V_via_func,
mode="markers",
marker={"color": orange, "size": 6},
name="Function representation output",
)
)
fig.update_layout(
xaxis_title="Wealth (x)",
yaxis_title="V(x)",
width=700,
height=400,
)
fig.show()
The orange points from the function representation lie exactly on the blue line connecting the grid points. In other words, the function representation evaluates the piecewise linear interpolant of V_arr: it reproduces the stored values at grid points and interpolates linearly between them.
Technical details
The function representation is assembled from four building blocks, each
implemented as a small function with a carefully chosen signature. These functions
are composed using dags.concatenate_functions.
Label translator
Maps discrete variable labels to array indices. pylcm uses integer codes internally, so the translator is the identity function.
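The identity works only because the integer codes 0, ..., n-1 already are positions along the array axis. The sketch below contrasts it with a hypothetical case (not how pylcm stores discrete states) in which labels are arbitrary sorted integers, so a translator would have to search for each label's position.

```python
import numpy as np


def identity_translator(health):
    # Codes 0, ..., n-1 coincide with array indices, so pass them through.
    return health


# Hypothetical counter-example: labels that are not 0, ..., n-1.
labels = np.array([10, 20, 30])


def searching_translator(health):
    # Position of each label within the sorted label array.
    return np.searchsorted(labels, health)


category_values = np.array([-1.0, -0.5, -0.25])  # one value per health category
print(category_values[identity_translator(2)])   # code 2 selects the third entry
print(category_values[searching_translator(np.array([20, 30]))])
```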
from lcm.function_representation import _get_label_translator
translator = _get_label_translator(in_name="health")
print(f"Signature: {translator.__signature__}")
print(f"translator(health=3) = {translator(health=3)}")
Signature: (health: 'Array') -> 'Array'
translator(health=3) = 3
Lookup function
Indexes into the value function array using named axes. This is important because
dags.concatenate_functions matches functions by argument names.
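A minimal sketch of a lookup with named axes, written with NumPy and a hypothetical factory make_lookup (our name, not pylcm's). Each index arrives as a keyword argument named after its axis, and the function reorders them by axis position, so keyword order does not matter and name-based wiring can supply each index by name.

```python
import numpy as np


def make_lookup(array_name, axis_names):
    """Hypothetical sketch of a named-axis lookup (not pylcm's code)."""

    def lookup(**kwargs):
        arr = kwargs[array_name]
        # Collect one index per axis, in axis order, and index all axes at once.
        return arr[tuple(kwargs[name] for name in axis_names)]

    return lookup


V = np.arange(12.0).reshape(3, 4)  # axes: (health, wealth)
named_lookup = make_lookup("V", ["health_index", "wealth_index"])
# Picks V[2, 0] and V[1, 3].
print(named_lookup(wealth_index=np.array([0, 3]), health_index=np.array([2, 1]), V=V))
```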
from lcm.function_representation import _get_lookup_function
lookup = _get_lookup_function(array_name="V_arr", axis_names=["wealth_index"])
print(f"Signature: {lookup.__signature__}")
# Look up values at indices 0, 2, 5
lookup(wealth_index=jnp.array([0, 2, 5]), V_arr=V_arr)
Signature: (wealth_index: 'Array', V_arr: 'Array') -> 'Array'
Array([-2. , -0.22028813, -0.13457806], dtype=float32)
Coordinate finder
Converts physical values to generalized coordinates, i.e., fractional indices into the
grid. For the linearly spaced wealth grid [1, 45.33, 89.67, ...], the value 23.17, which lies
halfway between the first two grid points, corresponds to coordinate 0.5.
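For a linearly spaced grid the coordinate has a closed form: subtract the start and divide by the step. The function below is a sketch of that formula (linspace_coordinate is our name); pylcm's own helper may differ in details such as dtype handling. Values outside the grid map to coordinates below 0 or above n_points - 1.

```python
import numpy as np


def linspace_coordinate(value, start, stop, n_points):
    """Generalized coordinate of `value` on np.linspace(start, stop, n_points)."""
    step = (stop - start) / (n_points - 1)
    return (value - start) / step


sample_wealth = np.array([1.0, (1 + 45.333336) / 2, 390.0])
print(linspace_coordinate(sample_wealth, start=1, stop=400, n_points=10))
```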
from lcm.function_representation import _get_coordinate_finder
wealth_gridspec = LinSpacedGrid(start=1, stop=400, n_points=10)
wealth_coordinate_finder = _get_coordinate_finder(
in_name="wealth",
grid=wealth_gridspec,
)
print(f"Signature: {wealth_coordinate_finder.__signature__}")
wealth_values = jnp.array([1.0, (1 + 45.333336) / 2, 390.0])
coords = wealth_coordinate_finder(wealth=wealth_values)
for w, c in zip(wealth_values, coords, strict=True):
    print(f"  wealth = {w:8.2f} → coordinate = {float(c):.4f}")
Signature: (wealth: 'Array') -> 'Array'
wealth = 1.00 → coordinate = 0.0000
wealth = 23.17 → coordinate = 0.5000
wealth = 390.00 → coordinate = 8.7744
Interpolator
Uses the generalized coordinates to linearly interpolate on the value function
array via map_coordinates.
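In one dimension, linear interpolation from a generalized coordinate splits it into an integer part (the left grid index) and a fractional part (the weight on the right neighbor). The sketch below does this in NumPy. Inside the grid it matches order-1 interpolation as in scipy.ndimage.map_coordinates; clipping the index to the edge segments, which makes out-of-range coordinates extrapolate linearly, is our own boundary choice and not a statement about lcm.ndimage.

```python
import numpy as np


def linear_interpolate_1d(values, coordinate):
    """Piecewise linear interpolant of `values` evaluated at fractional indices."""
    values = np.asarray(values, dtype=float)
    coordinate = np.asarray(coordinate, dtype=float)
    # Left grid index of the bracketing segment, kept within [0, n - 2].
    i = np.clip(np.floor(coordinate).astype(int), 0, len(values) - 2)
    weight = coordinate - i  # fractional distance from the left grid point
    return (1 - weight) * values[i] + weight * values[i + 1]


grid_values = np.array([-2.0, -1.0, -0.5, -0.25])
print(linear_interpolate_1d(grid_values, [0.0, 0.5, 2.25, 3.0]))
```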
from lcm.function_representation import _get_interpolator
value_function_interpolator = _get_interpolator(
name_of_values_on_grid="V_arr",
axis_names=["wealth_index"],
)
print(f"Signature: {value_function_interpolator.__signature__}")
wealth_indices = wealth_coordinate_finder(wealth=wealth_values)
V_interpolations = value_function_interpolator(wealth_index=wealth_indices, V_arr=V_arr)
Signature: (V_arr: 'Array', wealth_index: 'Array') -> 'Array'
fig = go.Figure()
fig.add_trace(
go.Scatter(
x=wealth_gridspec.to_jax(),
y=V_arr,
mode="markers",
marker={"color": blue, "size": 8},
name="Pre-calculated values",
)
)
fig.add_trace(
go.Scatter(
x=wealth_values,
y=V_interpolations,
mode="markers",
marker={"color": orange, "size": 6},
name="Interpolated values",
)
)
fig.update_layout(
xaxis_title="Wealth (x)",
yaxis_title="V(x)",
width=600,
height=400,
)
fig.show()
Re-implementation from scratch
To understand how the pieces fit together, let’s re-implement the function
representation manually using dags.concatenate_functions.
The general idea: create functions for array lookup, coordinate finding, and
interpolation, each with signatures that declare their dependencies. Then let
dags wire them together.
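The wiring idea can be illustrated with a toy composer that resolves each argument name either to another function's output or to an external input. This is a simplified illustration, not how dags is implemented: unlike dags, it does not build a signature for the composed function, detect cycles, or report missing inputs beyond a KeyError.

```python
import inspect


def compose_by_name(functions, target):
    """Toy name-based composition (not the dags implementation)."""

    def evaluate(**inputs):
        cache = dict(inputs)  # external inputs are available by name

        def compute(name):
            if name not in cache:
                func = functions[name]
                arg_names = inspect.signature(func).parameters
                # Each argument is either an input or another function's output.
                cache[name] = func(**{arg: compute(arg) for arg in arg_names})
            return cache[name]

        return compute(target)

    return evaluate


toy_funcs = {
    "doubled": lambda x: 2 * x,
    "total": lambda doubled, y: doubled + y,
}
total = compose_by_name(toy_funcs, target="total")
print(total(x=3, y=1))
```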
Steps
Label translators for discrete states — identity functions (our model has no discrete states, so we skip this)
Discrete lookup — index into the array using discrete labels. With no discrete states, this is the identity (returns the array unchanged)
Coordinate finder for each continuous state — maps values to fractional indices
Interpolator — uses coordinates to interpolate on the array
Implementation
space_info = internal_regime.state_space_info
funcs = {}
# Step 1: No discrete state variables
print(f"Discrete states: {space_info.discrete_states}")
Discrete states: {}
# Step 2: Discrete lookup is the identity (no discrete states to index by)
def discrete_lookup(V_arr):
    return V_arr
funcs["__interpolation_data__"] = discrete_lookup
# Step 3: Coordinate finder for wealth
from lcm.grid_helpers import get_linspace_coordinate
def wealth_coordinate_finder(wealth):
    return get_linspace_coordinate(value=wealth, start=1, stop=400, n_points=10)
funcs["__wealth_coord__"] = wealth_coordinate_finder
# Step 4: Interpolator using map_coordinates
from lcm.ndimage import map_coordinates
def interpolator(__interpolation_data__, __wealth_coord__):
    # map_coordinates expects one row of coordinates per array dimension
    coordinates = jnp.array([__wealth_coord__])
    return map_coordinates(input=__interpolation_data__, coordinates=coordinates)
funcs["__fval__"] = interpolator
# Compose with dags
from dags import concatenate_functions
value_function = concatenate_functions(functions=funcs, targets="__fval__")
print(f"Composed signature: {value_function.__signature__}")
V_evaluated = value_function(wealth=wealth_gridspec.to_jax(), V_arr=V_arr)
Composed signature: (V_arr, wealth)
fig = go.Figure()
fig.add_trace(
go.Scatter(
x=wealth_gridspec.to_jax(),
y=V_arr,
mode="markers",
marker={"color": blue, "size": 8},
name="Pre-calculated values",
)
)
fig.add_trace(
go.Scatter(
x=wealth_gridspec.to_jax(),
y=V_evaluated,
mode="markers",
marker={"color": orange, "size": 6},
name="Re-implemented function representation",
)
)
fig.update_layout(
xaxis_title="Wealth (x)",
yaxis_title="V(x)",
width=600,
height=400,
)
fig.show()
The orange points coincide with the blue grid points: evaluated at the grid points, the manual re-implementation reproduces the pre-calculated values, just as pylcm’s built-in function representation does.