A three-period consumption-savings model with two regimes:
Working life (ages 25 and 45): The agent chooses whether to work and how much to consume. A simple tax-and-transfer system guarantees a consumption floor. Savings earn interest.
Retirement (age 65): Terminal regime. The agent consumes out of remaining wealth.
Model¶
An agent lives for three periods (ages 25, 45, and 65). In the first two periods (working life), the agent chooses whether to work ($d_t \in \{0, 1\}$) and how much to consume ($c_t$). In the final period (retirement), the agent consumes out of remaining wealth.
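Working life (ages 25 and 45):

$$
V_t(w_t) = \max_{d_t \in \{0, 1\},\; c_t} \left\{ \frac{c_t^{1-\gamma}}{1-\gamma} - \phi\, d_t + \beta\, V_{t+1}(w_{t+1}) \right\}
$$

subject to

$$
\begin{aligned}
y_t &= \omega\, d_t, \\
T_t &= \begin{cases}
  \tau\,(y_t - \underline{c}) & \text{if } y_t \ge \underline{c}, \\
  \min\{0,\; w_t + y_t - \underline{c}\} & \text{otherwise},
\end{cases} \\
a_t &= w_t + y_t - T_t - c_t \ge 0, \\
w_{t+1} &= (1 + r)\, a_t,
\end{aligned}
$$

where:

- $w_t$ is wealth and $y_t$ earnings;
- $\omega$ is the wage and $\underline{c}$ a consumption floor guaranteed by transfers;
- $\tau$ is the tax rate and $T_t$ net taxes (negative values are transfers);
- $a_t$ is end-of-period wealth;
- $\gamma$ is the coefficient of relative risk aversion, $\phi$ the disutility of work, $\beta$ the discount factor, and $r$ the interest rate.

The transfer only kicks in when the agent's resources ($w_t + y_t$) fall below the consumption floor. In that case it tops them up to exactly $\underline{c}$.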
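Retirement (age 65, terminal): the agent consumes all remaining wealth, $c_T = w_T$, so

$$
V_T(w_T) = \frac{w_T^{1-\gamma}}{1-\gamma}.
$$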
from pprint import pprint
import jax.numpy as jnp
import numpy as np
import pandas as pd
import plotly.express as px
from lcm import (
AgeGrid,
DiscreteGrid,
LinSpacedGrid,
LogSpacedGrid,
Model,
Regime,
categorical,
initial_conditions_from_dataframe,
)
from lcm.typing import (
BoolND,
ContinuousAction,
ContinuousState,
DiscreteAction,
FloatND,
ScalarInt,
)Categorical Variables¶
@categorical(ordered=False)
class Work:
    """Discrete work choice in the working-life regime.

    NOTE(review): presumably ``@categorical`` assigns integer codes in field
    declaration order (``no`` before ``yes``) -- confirm before reordering.
    """

    # Does not work: `earnings` returns 0.0 and `utility` charges no work cost.
    no: int
    # Works: `earnings` returns the wage and `utility` subtracts the disutility.
    yes: int
@categorical(ordered=False)
class RegimeId:
    """Identifiers of the model's regimes, returned by ``next_regime``.

    The field names mirror the regime names passed to ``Model(regimes=...)``.
    """

    # Ages at which the agent still chooses work and consumption.
    working_life: int
    # Terminal regime: the agent consumes remaining wealth.
    retirement: intModel Functions¶
# Utility
def utility(
    consumption: ContinuousAction,
    work: DiscreteAction,
    disutility_of_work: float,
    risk_aversion: float,
) -> FloatND:
    """Per-period utility during working life.

    CRRA utility of consumption minus a fixed utility cost of working::

        consumption**(1 - risk_aversion) / (1 - risk_aversion)
            - disutility_of_work * (work == Work.yes)

    For ``risk_aversion == 1`` the CRRA formula would divide by zero, so its
    limit ``log(consumption)`` is used instead.

    Args:
        consumption: Consumption level(s).
        work: Work decision (``Work.no`` or ``Work.yes``).
        disutility_of_work: Utility cost incurred when working.
        risk_aversion: Coefficient of relative risk aversion.

    Returns:
        Utility, broadcast over the inputs.
    """
    is_log = risk_aversion == 1.0
    # Use a dummy exponent of 1 in the log case so the unused CRRA branch does
    # not divide by zero (jnp.where evaluates both branches).
    exponent = jnp.where(is_log, 1.0, 1.0 - risk_aversion)
    consumption_utility = jnp.where(
        is_log,
        jnp.log(consumption),
        consumption**exponent / exponent,
    )
    return consumption_utility - disutility_of_work * (work == Work.yes)
def utility_retirement(wealth: ContinuousState, risk_aversion: float) -> FloatND:
    """Terminal-period utility: CRRA utility of consuming all remaining wealth.

    Computes ``wealth**(1 - risk_aversion) / (1 - risk_aversion)``. For
    ``risk_aversion == 1`` that formula would divide by zero, so its limit
    ``log(wealth)`` is used instead. For ``risk_aversion >= 1``, zero wealth
    yields ``-inf``.

    Args:
        wealth: Wealth at the start of retirement.
        risk_aversion: Coefficient of relative risk aversion.

    Returns:
        Utility, broadcast over ``wealth``.
    """
    is_log = risk_aversion == 1.0
    # Use a dummy exponent of 1 in the log case so the unused CRRA branch does
    # not divide by zero (jnp.where evaluates both branches).
    exponent = jnp.where(is_log, 1.0, 1.0 - risk_aversion)
    return jnp.where(is_log, jnp.log(wealth), wealth**exponent / exponent)
# Auxiliary functions
def earnings(work: DiscreteAction, wage: float) -> FloatND:
    """Labor earnings: the full wage when working, nothing otherwise."""
    works = work == Work.yes
    return jnp.where(works, wage, 0.0)
def taxes_transfers(
    earnings: FloatND,
    wealth: ContinuousState,
    consumption_floor: float,
    tax_rate: float,
) -> FloatND:
    """Net taxes paid by the agent; negative values are transfers received.

    * If earnings are at or above the consumption floor, the part of earnings
      exceeding the floor is taxed at ``tax_rate``.
    * Otherwise no tax is levied. If wealth plus earnings fall short of the
      floor, a transfer (a negative tax) tops resources up to exactly the floor.
    """
    resources_minus_floor = wealth + earnings - consumption_floor
    transfer = jnp.minimum(0.0, resources_minus_floor)
    tax_on_excess_earnings = tax_rate * (earnings - consumption_floor)
    earns_above_floor = earnings >= consumption_floor
    return jnp.where(earns_above_floor, tax_on_excess_earnings, transfer)
def end_of_period_wealth(
    wealth: ContinuousState,
    earnings: FloatND,
    taxes_transfers: FloatND,
    consumption: ContinuousAction,
) -> FloatND:
    """Wealth carried out of the period, before interest is applied.

    Cash on hand is wealth plus earnings net of taxes; negative
    ``taxes_transfers`` are transfers and so raise cash on hand. Consumption is
    then subtracted.
    """
    cash_on_hand = wealth + earnings - taxes_transfers
    return cash_on_hand - consumption
# State transition
def next_wealth(
    end_of_period_wealth: FloatND,
    interest_rate: float,
) -> ContinuousState:
    """Next period's wealth: end-of-period wealth grown by the gross return."""
    gross_return = 1 + interest_rate
    return gross_return * end_of_period_wealth
# Constraints
def borrowing_constraint_working(end_of_period_wealth: FloatND) -> BoolND:
    """Feasibility check: the agent may not end the period in debt."""
    return 0 <= end_of_period_wealth
# Regime transition
def next_regime(age: float, last_working_age: float) -> ScalarInt:
    """Regime the agent will be in next period.

    From ``last_working_age`` onward the agent moves to retirement; at earlier
    ages it stays in working life.
    """
    retires_next = age >= last_working_age
    return jnp.where(retires_next, RegimeId.retirement, RegimeId.working_life)


# Regimes and Model
# Decision ages 20 years apart from 25 to 65 (the solve log shows 25, 45, 65).
age_grid = AgeGrid(start=25, stop=65, step="20Y")
# The last grid age is the retirement (terminal) age.
retirement_age = age_grid.exact_values[-1]
# Working life: wealth is the only state. Each period the agent chooses whether
# to work and how much to consume; `next_regime` decides when to retire.
working_life = Regime(
    transition=next_regime,
    # Active at every grid age before the retirement age.
    active=lambda age: age < retirement_age,
    states={
        "wealth": LinSpacedGrid(start=0, stop=50, n_points=25),
    },
    state_transitions={
        "wealth": next_wealth,
    },
    actions={
        "work": DiscreteGrid(Work),
        # NOTE(review): consumption choices never go below 4. That is above the
        # consumption floor of 2.0 set in `params`, so consuming exactly at the
        # floor is never an option on this grid. Confirm this is intended.
        "consumption": LogSpacedGrid(start=4, stop=50, n_points=100),
    },
    # NOTE(review): presumably lcm connects these by matching argument names to
    # function names (e.g. `taxes_transfers` takes `earnings`). Confirm this in
    # the lcm docs.
    functions={
        "utility": utility,
        "earnings": earnings,
        "taxes_transfers": taxes_transfers,
        "end_of_period_wealth": end_of_period_wealth,
    },
    # End-of-period wealth must be non-negative (no borrowing).
    constraints={
        "borrowing_constraint_working": borrowing_constraint_working,
    },
)
# Retirement: terminal regime with no transition and no actions; utility
# depends on wealth alone.
retirement = Regime(
    transition=None,
    # Active only at the retirement age, the last age on the grid.
    active=lambda age: age >= retirement_age,
    states={
        # NOTE(review): this grid includes wealth = 0. There
        # `utility_retirement` is -inf for risk_aversion > 1, which is likely
        # the cause of the NaN/Inf warning logged for 'retirement' when solving.
        "wealth": LinSpacedGrid(start=0, stop=50, n_points=25),
    },
    functions={"utility": utility_retirement},
)
# Combine the regimes; the fields of `RegimeId` mirror the regime names used here.
model = Model(
    regimes={
        "working_life": working_life,
        "retirement": retirement,
    },
    ages=age_grid,
    regime_id_class=RegimeId,
    description="A tiny three-period consumption-savings model.",
)Parameters¶
Use model.get_params_template() to see what parameters the model expects, organized
by regime and function.
# Print the nested parameter template (regime -> function -> parameter name ->
# expected type); the captured output follows the call.
pprint(dict(model.get_params_template())){'retirement': {'utility': {'risk_aversion': 'float'}},
'working_life': {'H': {'discount_factor': 'float'},
'borrowing_constraint_working': {},
'earnings': {'wage': 'float'},
'end_of_period_wealth': {},
'next_regime': {'last_working_age': 'float'},
'next_wealth': {'interest_rate': 'float'},
'taxes_transfers': {'consumption_floor': 'float',
'tax_rate': 'float'},
'utility': {'disutility_of_work': 'float',
'risk_aversion': 'float'}}}
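Parameters can be specified at the model level. Here that is done for `risk_aversion`, which both regimes use, and also for `discount_factor` and `interest_rate`, which only `working_life` uses. Parameters that belong to a specific regime go under the regime name, nested by function name.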
# Parameters used only by functions of the working-life regime, keyed by
# function name.
working_life_params = {
    "utility": {"disutility_of_work": 1.0},
    "earnings": {"wage": 20.0},
    "taxes_transfers": {"consumption_floor": 2.0, "tax_rate": 0.2},
    # Second-to-last grid age: the last period of working life, after which
    # `next_regime` switches to retirement.
    "next_regime": {"last_working_age": age_grid.exact_values[-2]},
}
# Model-level parameters first, then the regime-specific block.
params = {
    "discount_factor": 0.95,
    "risk_aversion": 1.5,
    "interest_rate": 0.03,
    "working_life": working_life_params,
}


# Solve and Simulate
# 100 agents, all starting in working life at the first grid age, with initial
# wealth evenly spaced between 1 and 20.
n_agents = 100
initial_df = pd.DataFrame(
    {
        "regime": "working_life",
        "age": float(age_grid.exact_values[0]),
        "wealth": np.linspace(1, 20, n_agents),
    }
)
# Build lcm initial conditions from the DataFrame.
initial_conditions = initial_conditions_from_dataframe(initial_df, model=model)
# Solve backward over ages, then simulate forward; the captured log follows.
result = model.solve_and_simulate(
    params=params,
    initial_conditions=initial_conditions,
)INFO:lcm:Starting solution
WARNING:lcm:NaN/Inf in V_arr for regime 'retirement' at age 65.0
INFO:lcm:Age: 65.0 regimes=1 (0.1s)
INFO:lcm:Age: 45.0 regimes=1 (0.2s)
INFO:lcm:Age: 25.0 regimes=1 (0.1s)
INFO:lcm:Solution complete (0.4s)
INFO:lcm:Starting simulation
INFO:lcm:Age: 25.0 regimes=1 (1.1s)
INFO:lcm:Age: 45.0 regimes=1 (0.1s)
INFO:lcm:Age: 65.0 regimes=1 (0.1s)
INFO:lcm:Simulation complete (1.4s)
# Simulation results as a DataFrame. `additional_targets="all"` presumably also
# adds the regime functions' values (earnings, taxes_transfers, ...) as columns.
df = result.to_dataframe(additional_targets="all")
df["age"] = df["age"].astype(int)
# Retirement has no consumption action; the agent consumes all remaining
# wealth, so report wealth as consumption at that age.
df.loc[df["age"] == retirement_age, "consumption"] = df.loc[
    df["age"] == retirement_age, "wealth"
]
columns = [
    "regime",
    "work",
    "consumption",
    "wealth",
    "earnings",
    "taxes_transfers",
    "end_of_period_wealth",
    "value",
]
# Display the first 20 rows (one row per agent and age).
df.set_index(["subject_id", "age"])[columns].head(20).style.format(
    precision=1,
    na_rep="",
)Source
# Classify agents by work pattern across the two working-life periods
first_working_age = age_grid.exact_values[0]
last_working_age = age_grid.exact_values[-2]
df_working = df[df["regime"] == "working_life"]
# One row per agent, one column per working-life age, holding the work choice.
work_by_age = df_working.pivot_table(
    index="subject_id",
    columns="age",
    values="work",
    aggfunc="first",
)
# One pattern string per agent, e.g. "yes, no" = worked at the first age only.
work_pattern = (
    work_by_age[first_working_age].astype(str)
    + ", "
    + work_by_age[last_working_age].astype(str)
)
# `label_map` below has no entry for working in both periods, so fail early if
# that pattern occurs.
assert "yes, yes" not in work_pattern.to_numpy(), (
    "Plotting assumes that no agent works in both periods of working life."
)
# NOTE(review): the low/medium/high names assume which initial-wealth group
# follows each pattern. The min/max ranges in the summary table check whether
# that holds for the current parameters.
label_map = {
    "yes, no": "low",  # work early, not later
    "no, yes": "medium",  # coast early, work later
    "no, no": "high",  # never work
}
groups = work_pattern.map(label_map).rename("initial_wealth")
# Combined descriptives and work decisions table
initial_wealth = df[df["age"] == first_working_age].set_index("subject_id")["wealth"]
# Range of initial wealth within each group.
group_desc = initial_wealth.groupby(groups).agg(["min", "max"]).round(1)
df_groups = df.copy()
df_groups["initial_wealth"] = df_groups["subject_id"].map(groups)
# Group averages at each age (numeric columns only); also used by the plots.
df_mean = df_groups.groupby(["initial_wealth", "age"], as_index=False).mean(
    numeric_only=True,
)
work_table = df_mean[df_mean["age"] < retirement_age].pivot_table(
    index="initial_wealth",
    columns="age",
    values="earnings",
)
# All agents in a group share one work pattern. With a positive wage, mean
# earnings are therefore positive exactly when the group works at that age.
work_table = (work_table > 0).astype(int)
work_table.columns = [f"works {c}" for c in work_table.columns]
summary = pd.concat([group_desc, work_table], axis=1)
summary.index.name = "initial_wealth"
# Rows ordered low, medium, high; `.loc` raises KeyError if a group is empty.
summary.loc[["low", "medium", "high"]].style.format(precision=1, na_rep="")Source
# Average consumption by age, one line per initial-wealth group.
fig = px.line(
    data_frame=df_mean,
    x="age",
    y="consumption",
    color="initial_wealth",
    template="plotly_dark",
    title="Consumption by Age",
)
fig.show()
# Source
# Average wealth by age, one line per initial-wealth group.
fig = px.line(
    data_frame=df_mean,
    x="age",
    y="wealth",
    color="initial_wealth",
    template="plotly_dark",
    title="Wealth by Age",
)
fig.show()