Defining Models

A Model ties together regimes, an age grid, and a regime ID class into a solvable lifecycle model.

The Model Constructor

from lcm import Model

model = Model(
    regimes=regimes,             # dict mapping names to Regime instances
    ages=ages,                   # AgeGrid defining the lifecycle timeline
    regime_id_class=RegimeId,    # @categorical dataclass mapping names to int indices
    enable_jit=True,             # controls JAX compilation (default: True)
    fixed_params={},             # optional params baked in at init time
    description="",              # optional description string
)

All arguments are keyword-only. The three required arguments are regimes, ages, and regime_id_class.
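Keyword-only arguments are a plain Python feature, so their effect can be shown without lcm. The sketch below uses a hypothetical stand-in function, `make_model`, whose signature only mirrors the argument names listed above. It is not the library's constructor, and the `None` default for `fixed_params` is a choice made for this sketch.

```python
def make_model(*, regimes, ages, regime_id_class,
               enable_jit=True, fixed_params=None, description=""):
    """Hypothetical stand-in that mimics a keyword-only signature."""
    return {
        "regimes": regimes,
        "ages": ages,
        "regime_id_class": regime_id_class,
        "enable_jit": enable_jit,
        "fixed_params": {} if fixed_params is None else fixed_params,
        "description": description,
    }


# Passing every argument by name works; omitted optional ones use defaults.
spec = make_model(regimes={}, ages=[25, 26], regime_id_class=object)

# Passing arguments by position is rejected by Python itself.
try:
    make_model({}, [25, 26], object)
except TypeError as err:
    positional_error = str(err)
    print(positional_error)
```

The bare `*` in the signature is what forces callers to name each argument, which keeps call sites readable when a constructor takes several similar-looking objects.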

Regime ID Classes

The regime_id_class maps regime names to integer indices. Use the @categorical decorator to create it:

from lcm import categorical

@categorical(ordered=False)
class RegimeId:
    retired: int
    working: int
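To make "maps regime names to integer indices" concrete, here is a minimal sketch in plain Python. It assumes that indices are assigned consecutively from 0 in field declaration order, which the text above does not state. `index_fields` and `RegimeIdSketch` are hypothetical names that are not part of lcm, and the real @categorical decorator may behave differently.

```python
class RegimeIdSketch:
    # Same field layout as the RegimeId example, without the lcm decorator.
    retired: int
    working: int


def index_fields(cls):
    """Assumed rule: number annotated fields 0, 1, 2, ... in declaration order."""
    return {name: i for i, name in enumerate(cls.__annotations__)}


mapping = index_fields(RegimeIdSketch)
print(mapping)  # {'retired': 0, 'working': 1}
```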

Rules:

Age Grids

The ages argument defines the lifecycle timeline. There are two construction modes:

Range-based

from lcm import AgeGrid

ages = AgeGrid(start=25, stop=75, step="Y")  # annual steps, ages 25 to 75

Step formats:

The stop value is inclusive if (stop - start) is exactly divisible by the step size.
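The inclusivity rule can be checked with integer arithmetic. The sketch below is not lcm's implementation. `grid_points` is a hypothetical helper for whole-year steps, and it assumes that when the span is not divisible by the step, the grid ends at the last point below stop.

```python
def grid_points(start, stop, step):
    """Hypothetical helper: ages from start in increments of step, never exceeding stop."""
    n_steps = (stop - start) // step
    return [start + i * step for i in range(n_steps + 1)]


# (75 - 25) is divisible by 1, so the stop value is part of the grid.
annual = grid_points(25, 75, 1)
print(len(annual), annual[-1])  # 51 75

# (75 - 25) is not divisible by 4, so under the assumption above 75 is left out.
uneven = grid_points(25, 75, 4)
print(uneven[-1])  # 73
```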

Exact values

ages = AgeGrid(exact_values=[25, 35, 45, 55, 65, 75])

Use this for irregular age spacing.

Key properties

Model Validation Rules

The Model constructor validates:

Inspecting a Model

After construction, the model exposes several useful attributes:

model.regimes             # immutable mapping of user Regime objects
model.internal_regimes    # processed internal representations
model.n_periods           # number of periods
model.regime_names_to_ids # name -> integer mapping
model.get_params_template()  # mutable copy of the parameter template

model.get_params_template() returns a mutable copy of the parameter template; see Parameters.

Complete Example

import jax.numpy as jnp
from lcm import AgeGrid, DiscreteGrid, LinSpacedGrid, Model, Regime, categorical


@categorical(ordered=False)
class RegimeId:
    retired: int
    working: int


@categorical(ordered=False)
class WorkChoice:
    no: int
    yes: int


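# Savings (wealth minus consumption) earn interest before the next period.
def next_wealth(wealth, consumption, interest_rate):
    return (wealth - consumption) * (1 + interest_rate)


# Choosing to work leads to the working regime; otherwise to the retired regime.
def next_regime(work):
    return jnp.where(work == WorkChoice.yes, RegimeId.working, RegimeId.retired)


# Log utility of consumption minus a disutility term scaled by the work choice.
def utility(consumption, work, disutility_of_work):
    return jnp.log(consumption) - disutility_of_work * work


# Utility in the retired regime depends only on wealth.
def terminal_utility(wealth):
    return jnp.log(wealth)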


working = Regime(
    transition=next_regime,
    states={
        "wealth": LinSpacedGrid(start=1, stop=100, n_points=50),
    },
    state_transitions={
        "wealth": next_wealth,
    },
    actions={
        "consumption": LinSpacedGrid(start=1, stop=50, n_points=30),
        "work": DiscreteGrid(WorkChoice),
    },
    functions={"utility": utility},
)

retired = Regime(
    transition=None,
    states={
        "wealth": LinSpacedGrid(start=1, stop=100, n_points=50),
    },
    functions={"utility": terminal_utility},
)

model = Model(
    regimes={"working": working, "retired": retired},
    ages=AgeGrid(start=25, stop=75, step="Y"),
    regime_id_class=RegimeId,
)
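To see what the functions in this example compute, the following sketch evaluates the wealth transition and the utility function for one hand-picked state and action, using only the standard library. The numeric inputs (wealth 50, consumption 10, interest rate 0.05, disutility 0.5) are illustrative assumptions. `math.log` stands in for `jnp.log`, and encoding work as 1 and 0 is an assumption about the integer values behind WorkChoice.

```python
import math


def next_wealth(wealth, consumption, interest_rate):
    # Same formula as in the example above.
    return (wealth - consumption) * (1 + interest_rate)


def utility(consumption, work, disutility_of_work):
    # math.log replaces jnp.log so the sketch runs without JAX.
    return math.log(consumption) - disutility_of_work * work


w_next = next_wealth(wealth=50.0, consumption=10.0, interest_rate=0.05)
u_work = utility(consumption=10.0, work=1, disutility_of_work=0.5)
u_rest = utility(consumption=10.0, work=0, disutility_of_work=0.5)

print(round(w_next, 6))           # 42.0
print(round(u_work - u_rest, 6))  # -0.5
```

At equal consumption, working lowers utility by exactly disutility_of_work, so within this sketch the choice to work only pays off through its effect on the next regime.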
