{ "cells": [ { "cell_type": "markdown", "id": "30779248", "metadata": {}, "source": [ "$$\n", "\\newcommand{\\argmax}{arg\\,max}\n", "\\newcommand{\\argmin}{arg\\,min}\n", "$$" ] }, { "cell_type": "markdown", "id": "573fc071", "metadata": {}, "source": [ "
\n", " \n", " \"QuantEcon\"\n", " \n", "
" ] }, { "cell_type": "markdown", "id": "e93b2d2e", "metadata": {}, "source": [ "# Job Search IV: Fitted Value Function Iteration" ] }, { "cell_type": "markdown", "id": "92f630f7", "metadata": {}, "source": [ "# GPU\n", "\n", "This lecture was built using a machine with access to a GPU, although it will also run without one.\n", "\n", "[Google Colab](https://colab.research.google.com/) has a free tier with GPUs\n", "that you can access as follows:\n", "\n", "1. Click on the “play” icon top right \n", "1. Select Colab \n", "1. Set the runtime environment to include a GPU " ] }, { "cell_type": "markdown", "id": "0f2371e8", "metadata": {}, "source": [ "## Contents\n", "\n", "- [Job Search IV: Fitted Value Function Iteration](#Job-Search-IV:-Fitted-Value-Function-Iteration) \n", "  - [Overview](#Overview) \n", "  - [Model](#Model) \n", "  - [Solution method](#Solution-method) \n", "  - [Implementation](#Implementation) \n", "  - [Simulation](#Simulation) \n", "  - [Exercises](#Exercises) " ] }, { "cell_type": "markdown", "id": "f48ff2cf", "metadata": {}, "source": [ "## Overview\n", "\n", "This lecture follows on from the job search model with separation presented in\n", "the [previous lecture](https://python.quantecon.org/mccall_model_with_separation.html).\n", "\n", "That lecture combined exogenous job separation events and a Markov wage offer\n", "process.\n", "\n", "In this lecture we continue with this setting and, in addition, allow the wage offer process to be continuous rather than discrete.\n", "\n", "In particular,\n", "\n", "$$\n", "W_t = \\exp(X_t)\n", " \\quad \\text{where} \\quad\n", " X_{t+1} = \\rho X_t + \\nu Z_{t+1}\n", "$$\n", "\n", "and $ \\{Z_t\\} $ is IID and standard normal.\n", "\n", "While we already considered continuous wage distributions briefly in\n", "[Job Search I: The McCall Search Model](https://python.quantecon.org/mccall_model.html), the change was relatively trivial in that case.\n", "\n", "The reason is that we were able to reduce the 
problem to solving for a single\n", "scalar value (the continuation value).\n", "\n", "Here, in our Markov setting, the change is less trivial, since a continuous wage\n", "distribution leads to an uncountably infinite state space.\n", "\n", "The infinite state space leads to additional challenges, particularly when it\n", "comes to applying value function iteration (VFI).\n", "\n", "These challenges will lead us to modify VFI by adding an interpolation step.\n", "\n", "The combination of VFI and this interpolation step is called **fitted value\n", "function iteration** (fitted VFI).\n", "\n", "Fitted VFI is very common in practice, so we will take some time to work through\n", "the details.\n", "\n", "In addition to what’s in Anaconda, this lecture will need the following libraries:" ] }, { "cell_type": "code", "execution_count": null, "id": "02faca04", "metadata": { "hide-output": false }, "outputs": [], "source": [ "!pip install quantecon jax" ] }, { "cell_type": "markdown", "id": "7c497c28", "metadata": {}, "source": [ "We will use the following imports:" ] }, { "cell_type": "code", "execution_count": null, "id": "75a98dea", "metadata": { "hide-output": false }, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "import jax\n", "import jax.numpy as jnp\n", "from jax import lax\n", "from typing import NamedTuple\n", "from functools import partial\n", "import quantecon as qe" ] }, { "cell_type": "markdown", "id": "a2f36d62", "metadata": {}, "source": [ "## Model\n", "\n", "For readers familiar with the content of [Job Search III: Search with Separation and Markov Wages](https://python.quantecon.org/mccall_model_with_sep_markov.html), the model can be summarized as follows.\n", "\n", "- Wage offers follow a continuous Markov process: $ W_t = \\exp(X_t) $ where $ X_{t+1} = \\rho X_t + \\nu Z_{t+1} $ \n", "- $ \\{Z_t\\} $ is IID and standard normal \n", "- Jobs terminate with probability $ \\alpha $ each period (separation rate) \n", "- 
Unemployed workers receive compensation $ c $ per period \n", "- Workers have CRRA utility $ u(x) = \\frac{x^{1-\\gamma} - 1}{1-\\gamma} $ \n", "- Future payoffs are discounted by the factor $ \\beta \\in (0,1) $ " ] }, { "cell_type": "markdown", "id": "0cdfccf7", "metadata": {}, "source": [ "## Solution method\n", "\n", "Let’s discuss how we can solve this model.\n", "\n", "The only real change from [Job Search III: Search with Separation and Markov Wages](https://python.quantecon.org/mccall_model_with_sep_markov.html) is that we replace sums with integrals." ] }, { "cell_type": "markdown", "id": "0103c828", "metadata": {}, "source": [ "### Value function iteration\n", "\n", "In the [discrete case](https://python.quantecon.org/mccall_model_with_sep_markov.html), we ended up iterating on the Bellman operator\n", "\n", "\n", "\n", "$$\n", "(Tv_u)(w) =\n", " \\max\n", " \\left\\{\n", " \\frac{1}{1-\\beta(1-\\alpha)} \\cdot\n", " \\left(\n", " u(w) + \\alpha\\beta (Pv_u)(w)\n", " \\right),\n", " u(c) + \\beta(Pv_u)(w)\n", " \\right\\} \\tag{45.1}\n", "$$\n", "\n", "where\n", "\n", "$$\n", "(P v_u)(w) := \\sum_{w'} v_u(w') P(w, w')\n", "$$\n", "\n", "Here we iterate on the same Bellman operator after changing the definition of $ P $ to\n", "\n", "$$\n", "(P v_u)(w) := \\int v_u(w') p(w, w') d w'\n", "$$\n", "\n", "where $ p(w, \\cdot) $ is the conditional density of $ w' $ given $ w $.\n", "\n", "Here we are thinking of $ v_u $ as a function on all of $ \\mathbb{R}_+ $.\n", "\n", "After taking $ \\psi $ to be the standard normal density, we can write the expression above more explicitly as\n", "\n", "$$\n", "(P v_u)(w) := \\int v_u( w^\\rho \\exp(\\nu z) ) \\psi(z) dz\n", "$$\n", "\n", "To understand this expression, recall that $ W_t = \\exp(X_t) $ where $ X_{t+1} = \\rho X_t + \\nu Z_{t+1} $.\n", "\n", "As a result, $ W_{t+1} = \\exp(X_{t+1}) = \\exp(\\rho \\log(W_t) + \\nu Z_{t+1}) = W_t^\\rho \\exp(\\nu Z_{t+1}) $.\n", "\n", "The integral above regards the current wage 
$ W_t $ as fixed at $ w $ and takes the\n", "expectation of $ v_u(w^\\rho \\exp(\\nu Z_{t+1})) $." ] }, { "cell_type": "markdown", "id": "e1edb4be", "metadata": {}, "source": [ "### Fitting\n", "\n", "In theory, we should now proceed as follows:\n", "\n", "1. Begin with a guess $ v $ \n", "1. Apply $ T $ to obtain the update $ v' = Tv $ \n", "1. Unless some stopping condition is satisfied, set $ v = v' $ and go to step 2. \n", "\n", "\n", "However, there is a problem we must confront before we implement this procedure: the iterates of the value function can neither be calculated exactly nor stored on a computer.\n", "\n", "To see the issue, consider [(45.1)](#equation-bell2mcmc).\n", "\n", "Even if $ v $ is a known function, the only way to store its update $ v' $ exactly is to record its value $ v'(w) $ for every $ w \\in \\mathbb R_+ $.\n", "\n", "Clearly, this is impossible." ] }, { "cell_type": "markdown", "id": "f0f9b063", "metadata": {}, "source": [ "### Fitted value function iteration\n", "\n", "What we will do instead is use **fitted value function iteration**.\n", "\n", "The procedure is as follows:\n", "\n", "Let a current guess $ v $ be given.\n", "\n", "We record the value of the function $ v' $ at only finitely many “grid” points $ w_1 < w_2 < \\cdots < w_I $ and then reconstruct $ v' $ from this information when required.\n", "\n", "More precisely, the algorithm is as follows:\n", "\n", "\n", "\n", "1. Begin with an array $ \\mathbf v $ representing the values of an initial guess of the value function on some grid points $ \\{w_i\\} $. \n", "1. Build a function $ v $ on the state space $ \\mathbb R_+ $ by interpolation or approximation, based on $ \\mathbf v $ and $ \\{ w_i\\} $. \n", "1. Obtain and record the values $ v'(w_i) = (Tv)(w_i) $ of the updated function at each grid point $ w_i $. \n", "1. Unless some stopping condition is satisfied, take these values as the new array $ \\mathbf v $ and go to step 2. 
\n", "\n", "\n", "How should we go about step 2?\n", "\n", "This is a problem of function approximation, and there are many ways to approach it.\n", "\n", "What’s important here is that the function approximation scheme must not only produce a good approximation to each $ v $ but also combine well with the broader iteration algorithm described above.\n", "\n", "One good choice in both respects is continuous piecewise linear interpolation.\n", "\n", "This method\n", "\n", "1. combines well with value function iteration (see, e.g.,\n", " [[Gordon, 1995](https://python.quantecon.org/zreferences.html#id63)] or [[Stachurski, 2008](https://python.quantecon.org/zreferences.html#id62)]) and \n", "1. preserves useful shape properties such as monotonicity and concavity/convexity. \n", "\n", "\n", "Linear interpolation will be implemented using JAX’s interpolation function `jnp.interp`.\n", "\n", "The next figure illustrates piecewise linear interpolation of an arbitrary function on grid points $ 0, 0.2, 0.4, 0.6, 0.8, 1 $." ] }, { "cell_type": "code", "execution_count": null, "id": "ca640c74", "metadata": { "hide-output": false }, "outputs": [], "source": [ "def f(x):\n", " y1 = 2 * jnp.cos(6 * x) + jnp.sin(14 * x)\n", " return y1 + 2.5\n", "\n", "c_grid = jnp.linspace(0, 1, 6)\n", "f_grid = jnp.linspace(0, 1, 150)\n", "\n", "def Af(x):\n", " return jnp.interp(x, c_grid, f(c_grid))\n", "\n", "fig, ax = plt.subplots()\n", "\n", "ax.plot(f_grid, f(f_grid), 'b-', label='true function')\n", "ax.plot(f_grid, Af(f_grid), 'g-', label='linear approximation')\n", "ax.vlines(c_grid, c_grid * 0, f(c_grid), linestyle='dashed', alpha=0.5)\n", "\n", "ax.legend(loc=\"upper center\")\n", "\n", "ax.set(xlim=(0, 1), ylim=(0, 6))\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "31f955ec", "metadata": {}, "source": [ "## Implementation\n", "\n", "Let’s code up and solve the model."
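, "\n", "\n", "Before doing so, here is a quick numerical check of the claim above that piecewise linear interpolation preserves monotonicity and concavity.\n", "It is a minimal sketch rather than part of the lecture's code: the test function `g` and the grid sizes are illustrative assumptions.\n", "\n", "```python\n", "import jax.numpy as jnp\n", "\n", "# Illustrative increasing, concave test function (an assumption for this check)\n", "def g(x):\n", "    return jnp.log(1 + x)\n", "\n", "coarse = jnp.linspace(0, 5, 8)      # interpolation nodes\n", "fine = jnp.linspace(0, 5, 1_000)    # evaluation points\n", "\n", "# Piecewise linear interpolant built from the values of g on the coarse grid\n", "g_hat = jnp.interp(fine, coarse, g(coarse))\n", "\n", "# Monotonicity: first differences of the interpolant are nonnegative\n", "increasing = bool(jnp.all(jnp.diff(g_hat) >= 0))\n", "\n", "# Concavity: second differences are nonpositive, up to float32 rounding error\n", "concave = bool(jnp.all(jnp.diff(g_hat, n=2) <= 1e-5))\n", "\n", "print(increasing, concave)\n", "```\n", "\n", "Since shape preservation is a property of piecewise linear interpolation itself, both flags should be `True` for any increasing, concave `g` and any grid, apart from rounding error."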
] }, { "cell_type": "markdown", "id": "77c81b19", "metadata": {}, "source": [ "### Setup\n", "\n", "The first step is to build a JAX-compatible structure for the McCall model with\n", "separation and a continuous wage offer distribution.\n", "\n", "The key computational challenge is evaluating the conditional expectation\n", "$ (Pv_u)(w) = \\int v_u(w') p(w, w') dw' $ at each wage grid point.\n", "\n", "Recall that we have:\n", "\n", "$$\n", "(Pv_u)(w) = \\int v_u(w^\\rho \\exp(\\nu z)) \\psi(z) dz\n", "$$\n", "\n", "where $ \\psi $ is the standard normal density.\n", "\n", "We will approximate this integral using Monte Carlo integration with draws $ \\{Z_i\\} $ from the standard normal distribution:\n", "\n", "$$\n", "(Pv_u)(w) \\approx \\frac{1}{N} \\sum_{i=1}^N v_u(w^\\rho \\exp(\\nu Z_i))\n", "$$\n", "\n", "For this reason, our data structure will include a fixed set of IID $ N(0,1) $ draws $ \\{Z_i\\} $." ] }, { "cell_type": "code", "execution_count": null, "id": "6231e6a9", "metadata": { "hide-output": false }, "outputs": [], "source": [ "class Model(NamedTuple):\n", " c: float # unemployment compensation\n", " α: float # job separation rate\n", " β: float # discount factor\n", " ρ: float # wage persistence\n", " ν: float # wage volatility\n", " γ: float # utility parameter\n", " w_grid: jnp.ndarray # grid of points for fitted VFI\n", " z_draws: jnp.ndarray # draws from the standard normal distribution\n", "\n", "def create_mccall_model(\n", " c: float = 1.0,\n", " α: float = 0.05,\n", " β: float = 0.96,\n", " ρ: float = 0.9,\n", " ν: float = 0.2,\n", " γ: float = 1.5,\n", " grid_size: int = 100,\n", " mc_size: int = 1000,\n", " seed: int = 1234\n", " ):\n", " \"\"\"Factory function to create a McCall model instance.\"\"\"\n", "\n", " key = jax.random.PRNGKey(seed)\n", " z_draws = jax.random.normal(key, (mc_size,))\n", "\n", " # Discretize just to get a suitable wage grid for interpolation\n", " mc = qe.markov.tauchen(grid_size, ρ, ν)\n", " w_grid = 
jnp.exp(jnp.array(mc.state_values))\n", "\n", " return Model(c, α, β, ρ, ν, γ, w_grid, z_draws)" ] }, { "cell_type": "markdown", "id": "c0942e94", "metadata": {}, "source": [ "We use the same CRRA utility function as in the discrete case:" ] }, { "cell_type": "code", "execution_count": null, "id": "0af9375a", "metadata": { "hide-output": false }, "outputs": [], "source": [ "def u(x, γ):\n", "    return (x**(1 - γ) - 1) / (1 - γ)" ] }, { "cell_type": "markdown", "id": "99b45367", "metadata": {}, "source": [ "### Iteration\n", "\n", "Here is the Bellman operator, where we use Monte Carlo integration to evaluate the expectation." ] }, { "cell_type": "code", "execution_count": null, "id": "5f9735a9", "metadata": { "hide-output": false }, "outputs": [], "source": [ "def T(model, v):\n", "    \"\"\"Apply the Bellman operator T to v, an array of values on w_grid.\"\"\"\n", "\n", "    # Unpack model parameters\n", "    c, α, β, ρ, ν, γ, w_grid, z_draws = model\n", "\n", "    # Turn the array v into a function by linear interpolation; jnp.interp\n", "    # uses the endpoint values for points outside w_grid\n", "    vf = lambda x: jnp.interp(x, w_grid, v)\n", "\n", "    def compute_expectation(w):\n", "        # Use Monte Carlo to evaluate the integral (P v)(w) = E[v(W') | W = w],\n", "        # where W' = w^ρ * exp(ν * Z) and Z ~ N(0, 1)\n", "        w_next = w**ρ * jnp.exp(ν * z_draws)\n", "        return jnp.mean(vf(w_next))\n", "\n", "    compute_exp_on_grid = jax.vmap(compute_expectation)\n", "    Pv = compute_exp_on_grid(w_grid)\n", "\n", "    d = 1 / (1 - β * (1 - α))\n", "    v_e = d * (u(w_grid, γ) + α * β * Pv)\n", "    continuation_values = u(c, γ) + β * Pv\n", "    return jnp.maximum(v_e, continuation_values)" ] }, { "cell_type": "markdown", "id": "377156b5", "metadata": {}, "source": [ "Here’s the solver, which computes an approximate fixed point $ v_u $ of $ T $."
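, "\n", "\n", "The solver below stops once successive iterates differ by less than `tolerance` in the sup norm. A standard argument, sketched here rather than taken from the lecture, shows why this rule is sensible. Each branch of the maximum in $ T $ is affine in $ Pv_u $, with coefficient $ \\alpha\\beta / (1 - \\beta(1-\\alpha)) $ or $ \\beta $, and both are at most $ \\beta $ because $ \\alpha \\leq 1 $. Monte Carlo averaging and linear interpolation are both nonexpansive in the sup norm, so the fitted operator is a contraction of modulus $ \\beta $ on grid values. If $ v^* $ denotes its fixed point, the triangle inequality gives\n", "\n", "$$\n", "\\| v_n - v^* \\|_\\infty \\leq \\frac{\\beta}{1-\\beta} \\, \\| v_n - v_{n-1} \\|_\\infty\n", "$$\n", "\n", "With $ \\beta = 0.96 $ and `tolerance = 1e-6`, the returned array is therefore within $ (0.96 / 0.04) \\times 10^{-6} = 2.4 \\times 10^{-5} $ of $ v^* $ at the grid points. This bounds only the iteration error. Interpolation and Monte Carlo errors are separate issues."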
] }, { "cell_type": "code", "execution_count": null, "id": "0cbf542e", "metadata": { "hide-output": false }, "outputs": [], "source": [ "@jax.jit\n", "def vfi(\n", "        model: Model,\n", "        tolerance: float = 1e-6,     # Error tolerance\n", "        max_iter: int = 100_000,     # Max iteration bound\n", "    ):\n", "    \"\"\"\n", "    Compute an approximate fixed point v_u of T, iterating until the sup-norm\n", "    distance between successive iterates falls below tolerance (or until\n", "    max_iter iterations have been performed).\n", "\n", "    \"\"\"\n", "\n", "    v_init = jnp.zeros(model.w_grid.shape)\n", "\n", "    def cond(loop_state):\n", "        v, error, i = loop_state\n", "        return (error > tolerance) & (i <= max_iter)\n", "\n", "    def update(loop_state):\n", "        v, error, i = loop_state\n", "        v_new = T(model, v)\n", "        error = jnp.max(jnp.abs(v_new - v))\n", "        new_loop_state = v_new, error, i + 1\n", "        return new_loop_state\n", "\n", "    initial_state = (v_init, tolerance + 1, 1)\n", "    final_loop_state = lax.while_loop(cond, update, initial_state)\n", "    v_final, error, i = final_loop_state\n", "\n", "    return v_final" ] }, { "cell_type": "markdown", "id": "9dc6f0f2", "metadata": {}, "source": [ "Here’s a function that uses a solution $ v_u $ to compute the remaining functions of\n", "interest: the employment value function $ v_e $ and the continuation value function $ h $.\n", "\n", "We use the same expressions as we did in the [discrete case](https://python.quantecon.org/mccall_model_with_sep_markov.html), after replacing sums with integrals."
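, "\n", "\n", "Both `T` and the function below approximate $ (Pv_u)(w) $ with the same Monte Carlo average. As a quick sanity check of that approximation, we can use a test function whose conditional expectation is known in closed form. For $ v(w) = w $, the lognormal mean formula $ \\mathbb E \\exp(\\nu Z) = \\exp(\\nu^2/2) $ gives $ \\mathbb E [W_{t+1} \\mid W_t = w] = w^\\rho \\exp(\\nu^2 / 2) $. The sketch below uses the default values of $ \\rho $ and $ \\nu $; the number of draws and the test wages are illustrative choices.\n", "\n", "```python\n", "import jax\n", "import jax.numpy as jnp\n", "\n", "ρ_demo, ν_demo = 0.9, 0.2            # same as the default model parameters\n", "z = jax.random.normal(jax.random.PRNGKey(1234), (100_000,))\n", "w = jnp.linspace(0.5, 2.0, 5)         # a few current wages\n", "\n", "# Monte Carlo estimate of E[W' | W = w], where W' = w^ρ exp(ν Z)\n", "mc_est = jax.vmap(lambda w_i: jnp.mean(w_i**ρ_demo * jnp.exp(ν_demo * z)))(w)\n", "\n", "# Closed form from the lognormal mean E[exp(ν Z)] = exp(ν^2 / 2)\n", "exact = w**ρ_demo * jnp.exp(ν_demo**2 / 2)\n", "\n", "rel_err = jnp.max(jnp.abs(mc_est / exact - 1))\n", "print(rel_err)\n", "```\n", "\n", "With 100,000 draws, the relative error should be on the order of $ 10^{-3} $ or smaller. With the lecture's default `mc_size = 1000`, the standard error for this test function is roughly $ \\sqrt{\\exp(\\nu^2) - 1} / \\sqrt{1000} \\approx 0.0064 $ in relative terms. This is worth keeping in mind when interpreting small differences in the results."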
] }, { "cell_type": "code", "execution_count": null, "id": "d8a8a6ac", "metadata": { "hide-output": false }, "outputs": [], "source": [ "def compute_solution_functions(model, v_u):\n", "    \"\"\"Compute v_e and the continuation value function h from v_u.\"\"\"\n", "\n", "    # Unpack model parameters\n", "    c, α, β, ρ, ν, γ, w_grid, z_draws = model\n", "\n", "    # Turn the array v_u into a function by linear interpolation over w_grid\n", "    vf = lambda x: jnp.interp(x, w_grid, v_u)\n", "\n", "    def compute_expectation(w):\n", "        # Use Monte Carlo to evaluate the integral (P v_u)(w) = E[v_u(W') | W = w],\n", "        # where W' = w^ρ * exp(ν * Z) and Z ~ N(0, 1)\n", "        w_next = w**ρ * jnp.exp(ν * z_draws)\n", "        return jnp.mean(vf(w_next))\n", "\n", "    compute_exp_on_grid = jax.vmap(compute_expectation)\n", "    Pv = compute_exp_on_grid(w_grid)\n", "\n", "    d = 1 / (1 - β * (1 - α))\n", "    v_e = d * (u(w_grid, γ) + α * β * Pv)\n", "    h = u(c, γ) + β * Pv\n", "\n", "    return v_e, h" ] }, { "cell_type": "markdown", "id": "09f7c4fb", "metadata": {}, "source": [ "Let’s try solving the model:" ] }, { "cell_type": "code", "execution_count": null, "id": "1bf1befe", "metadata": { "hide-output": false }, "outputs": [], "source": [ "model = create_mccall_model()\n", "c, α, β, ρ, ν, γ, w_grid, z_draws = model\n", "v_u = vfi(model)\n", "v_e, h = compute_solution_functions(model, v_u)" ] }, { "cell_type": "markdown", "id": "fabd2ef6", "metadata": {}, "source": [ "Let’s plot our results."
] }, { "cell_type": "code", "execution_count": null, "id": "544b4ded", "metadata": { "hide-output": false }, "outputs": [], "source": [ "fig, ax = plt.subplots(figsize=(9, 5.2))\n", "ax.plot(w_grid, h, 'g-', linewidth=2,\n", " label=\"continuation value function $h$\")\n", "ax.plot(w_grid, v_e, 'b-', linewidth=2,\n", " label=\"employment value function $v_e$\")\n", "ax.legend(frameon=False)\n", "ax.set_xlabel(r\"$w$\")\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "0c806ecb", "metadata": {}, "source": [ "The reservation wage is at the intersection of the employment value function $ v_e $ and the continuation value function $ h $.\n", "\n", "Here’s a function to compute it explicitly." ] }, { "cell_type": "code", "execution_count": null, "id": "ed98137b", "metadata": { "hide-output": false }, "outputs": [], "source": [ "@jax.jit\n", "def get_reservation_wage(model: Model) -> float:\n", "    \"\"\"\n", "    Calculate the reservation wage for a given model.\n", "\n", "    \"\"\"\n", "    c, α, β, ρ, ν, γ, w_grid, z_draws = model\n", "\n", "    v_u = vfi(model)\n", "    v_e, h = compute_solution_functions(model, v_u)\n", "\n", "    # Optimal policy: True where accepting is at least as good as searching\n", "    σ = v_e >= h\n", "\n", "    # Find first index where policy indicates acceptance\n", "    first_accept_idx = jnp.argmax(σ)  # index of the first True entry\n", "\n", "    # If no acceptance (all False), return infinity\n", "    # Otherwise return the wage at the first acceptance index\n", "    return jnp.where(jnp.any(σ), w_grid[first_accept_idx], jnp.inf)" ] }, { "cell_type": "markdown", "id": "d96ac31a", "metadata": {}, "source": [ "Let’s repeat our plot, but now adding the reservation wage."
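, "\n", "\n", "Note that `get_reservation_wage` returns the first grid point at which $ v_e \\geq h $, so its accuracy is limited by the spacing of `w_grid`. One possible refinement is to locate the crossing by linear interpolation between the two bracketing grid points. The sketch below assumes that $ v_e - h $ changes sign at most once on the grid. The helper `refine_crossing` is hypothetical and not part of the lecture's code, and the arrays in the usage example are synthetic.\n", "\n", "```python\n", "import jax.numpy as jnp\n", "\n", "def refine_crossing(w_grid, v_e, h):\n", "    # Hypothetical helper: estimate where v_e - h crosses zero by linear\n", "    # interpolation between the grid points that bracket the sign change.\n", "    # Assumes v_e - h changes sign at most once, from negative to nonnegative.\n", "    diff = v_e - h\n", "    accept = diff >= 0\n", "    j = jnp.argmax(accept)              # first index with diff >= 0\n", "    i = jnp.maximum(j - 1, 0)           # previous index (diff < 0 when j > 0)\n", "    d0, d1 = diff[i], diff[j]\n", "    # Guard against division by zero when i == j (crossing at the first point)\n", "    denom = jnp.where(d1 == d0, 1.0, d1 - d0)\n", "    t = jnp.where(j > 0, -d0 / denom, 0.0)\n", "    w_star = w_grid[i] + t * (w_grid[j] - w_grid[i])\n", "    # As in get_reservation_wage, return infinity if no offer is accepted\n", "    return jnp.where(jnp.any(accept), w_star, jnp.inf)\n", "\n", "# Synthetic example: v_e - h = w - 1.23 crosses zero at w = 1.23\n", "w_demo = jnp.linspace(0.5, 2.0, 7)   # grid points 0.5, 0.75, ..., 2.0\n", "ve_demo = w_demo\n", "h_demo = jnp.full_like(w_demo, 1.23)\n", "print(refine_crossing(w_demo, ve_demo, h_demo))   # the first-grid-point rule gives 1.25\n", "```\n", "\n", "Because $ v_e $ and $ h $ are themselves only known at grid points, the refined value is still an approximation. However, it is no longer restricted to one of the grid points."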
] }, { "cell_type": "code", "execution_count": null, "id": "9a98169c", "metadata": { "hide-output": false }, "outputs": [], "source": [ "w_bar = get_reservation_wage(model)\n", "\n", "fig, ax = plt.subplots(figsize=(9, 5.2))\n", "ax.plot(w_grid, h, 'g-', linewidth=2,\n", " label=\"continuation value function $h$\")\n", "ax.plot(w_grid, v_e, 'b-', linewidth=2,\n", " label=\"employment value function $v_e$\")\n", "ax.axvline(x=w_bar, color='black', linestyle='--', alpha=0.8,\n", " label=f'reservation wage $\\\\bar{{w}}$')\n", "ax.legend(frameon=False)\n", "ax.set_xlabel(r\"$w$\")\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "7b46bf57", "metadata": {}, "source": [ "## Simulation\n", "\n", "Now we run some simulations with a focus on the unemployment rate." ] }, { "cell_type": "markdown", "id": "595c8eab", "metadata": {}, "source": [ "### Single agent dynamics\n", "\n", "Let’s simulate the employment path of a single agent under the optimal policy.\n", "\n", "We need a function to update the agent’s state by one period."
] }, { "cell_type": "code", "execution_count": null, "id": "a3e6378b", "metadata": { "hide-output": false }, "outputs": [], "source": [ "def update_agent(key, status, wage, model, w_bar):\n", " \"\"\"\n", " Updates an agent's employment status and current wage by one period.\n", "\n", " Parameters:\n", " - key: JAX random key\n", " - status: Current employment status (0 or 1)\n", " - wage: Current wage if employed, current offer if unemployed\n", " - model: Model instance\n", " - w_bar: Reservation wage\n", "\n", " \"\"\"\n", " c, α, β, ρ, ν, γ, w_grid, z_draws = model\n", "\n", " # Draw new wage offer based on current wage\n", " key1, key2 = jax.random.split(key)\n", " z = jax.random.normal(key1)\n", " new_wage = wage**ρ * jnp.exp(ν * z)\n", "\n", " # Check if separation occurs (for employed workers)\n", " separation_occurs = jax.random.uniform(key2) < α\n", "\n", " # Accept if current wage meets or exceeds reservation wage\n", " accepts = wage >= w_bar\n", "\n", " # If employed: status = 1 if no separation, 0 if separation\n", " # If unemployed: status = 1 if accepts, 0 if rejects\n", " next_status = jnp.where(\n", " status,\n", " 1 - separation_occurs.astype(jnp.int32), # employed path\n", " accepts.astype(jnp.int32) # unemployed path\n", " )\n", "\n", " # If employed: wage = current if no separation, new if separation\n", " # If unemployed: wage = current if accepts, new if rejects\n", " next_wage = jnp.where(\n", " status,\n", " jnp.where(separation_occurs, new_wage, wage), # employed path\n", " jnp.where(accepts, wage, new_wage) # unemployed path\n", " )\n", "\n", " return next_status, next_wage" ] }, { "cell_type": "markdown", "id": "8e4c06a5", "metadata": {}, "source": [ "Here’s a function to simulate the employment path of a single agent." 
] }, { "cell_type": "code", "execution_count": null, "id": "5fb222fb", "metadata": { "hide-output": false }, "outputs": [], "source": [ "def simulate_employment_path(\n", "        model: Model,     # Model details\n", "        w_bar: float,     # Reservation wage\n", "        T: int = 2_000,   # Simulation length\n", "        seed: int = 42    # Set seed for simulation\n", "    ):\n", "    \"\"\"\n", "    Simulate employment path for T periods starting from unemployment.\n", "\n", "    \"\"\"\n", "    key = jax.random.PRNGKey(seed)\n", "    c, α, β, ρ, ν, γ, w_grid, z_draws = model\n", "\n", "    # Initial conditions: start unemployed with initial wage draw\n", "    status = 0\n", "    key, subkey = jax.random.split(key)\n", "    wage = jnp.exp(jax.random.normal(subkey) * ν)\n", "\n", "    wage_path = []\n", "    status_path = []\n", "\n", "    for t in range(T):\n", "        wage_path.append(wage)\n", "        status_path.append(status)\n", "\n", "        key, subkey = jax.random.split(key)\n", "        status, wage = update_agent(\n", "            subkey, status, wage, model, w_bar\n", "        )\n", "\n", "    return jnp.array(wage_path), jnp.array(status_path)" ] }, { "cell_type": "markdown", "id": "44d49cd0", "metadata": {}, "source": [ "Let’s create a comprehensive plot of the employment simulation:" ] }, { "cell_type": "code", "execution_count": null, "id": "30df44de", "metadata": { "hide-output": false }, "outputs": [], "source": [ "model = create_mccall_model()\n", "w_bar = get_reservation_wage(model)\n", "\n", "wage_path, employment_status = simulate_employment_path(model, w_bar)\n", "\n", "fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(8, 6))\n", "\n", "# Plot employment status\n", "ax1.plot(employment_status, 'b-', alpha=0.7, linewidth=1)\n", "ax1.fill_between(\n", " range(len(employment_status)), employment_status, alpha=0.3, color='blue'\n", ")\n", "ax1.set_ylabel('employment status')\n", "ax1.set_title('Employment path (0=unemployed, 1=employed)')\n", "ax1.set_yticks((0, 1))\n", "ax1.set_ylim(-0.1, 1.1)\n", "\n", "# Plot wage path with reservation wage\n", 
"ax2.plot(wage_path, 'b-', alpha=0.7, linewidth=1)\n", "ax2.axhline(y=w_bar, color='black', linestyle='--', alpha=0.8,\n", " label=f'Reservation wage: {w_bar:.2f}')\n", "ax2.set_xlabel('time')\n", "ax2.set_ylabel('wage')\n", "ax2.set_title('Wage path (actual and offers)')\n", "ax2.legend()\n", "\n", "# Plot cumulative fraction of time unemployed\n", "unemployed_indicator = (employment_status == 0).astype(int)\n", "cumulative_unemployment = (\n", " jnp.cumsum(unemployed_indicator) /\n", " jnp.arange(1, len(employment_status) + 1)\n", ")\n", "\n", "ax3.plot(cumulative_unemployment, 'r-', alpha=0.8, linewidth=2)\n", "ax3.axhline(y=jnp.mean(unemployed_indicator), color='black',\n", " linestyle='--', alpha=0.7,\n", " label=f'Final rate: {jnp.mean(unemployed_indicator):.3f}')\n", "ax3.set_xlabel('time')\n", "ax3.set_ylabel('cumulative unemployment rate')\n", "ax3.set_title('Cumulative fraction of time spent unemployed')\n", "ax3.legend()\n", "ax3.set_ylim(0, 1)\n", "\n", "plt.tight_layout()\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "8a958b3c", "metadata": {}, "source": [ "The simulation shows the agent cycling between employment and unemployment.\n", "\n", "The agent starts unemployed and receives wage offers according to the Markov process.\n", "\n", "When unemployed, the agent accepts any offer that meets or exceeds the reservation wage.\n", "\n", "When employed, the agent faces job separation with probability $ \\alpha $ each period, after which the agent becomes unemployed and draws a new offer."
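, "\n", "\n", "The function `simulate_employment_path` above records a path with a plain Python loop. This is easy to read, but it dispatches JAX operations one period at a time. One alternative is `jax.lax.scan`, which compiles the loop and stacks the per-period outputs into arrays. To keep the block self-contained, the sketch uses a hypothetical transition function `step` that mirrors the logic of `update_agent`. It uses a fixed, illustrative reservation wage `W_BAR` rather than one computed from the model, and illustrative parameter values.\n", "\n", "```python\n", "from functools import partial\n", "\n", "import jax\n", "import jax.numpy as jnp\n", "from jax import lax\n", "\n", "# Illustrative parameters (assumptions, chosen to resemble the defaults)\n", "ALPHA, RHO, NU, W_BAR = 0.05, 0.9, 0.2, 1.0\n", "\n", "def step(carry, key):\n", "    # Hypothetical one-period transition mirroring update_agent\n", "    status, wage = carry\n", "    key1, key2 = jax.random.split(key)\n", "    new_wage = wage**RHO * jnp.exp(NU * jax.random.normal(key1))\n", "    separated = jax.random.uniform(key2) < ALPHA\n", "    accepts = wage >= W_BAR\n", "    next_status = jnp.where(status == 1,\n", "                            1 - separated.astype(jnp.int32),\n", "                            accepts.astype(jnp.int32))\n", "    next_wage = jnp.where(status == 1,\n", "                          jnp.where(separated, new_wage, wage),\n", "                          jnp.where(accepts, wage, new_wage))\n", "    # Carry the next state forward and record the current one\n", "    return (next_status, next_wage), (status, wage)\n", "\n", "@partial(jax.jit, static_argnames='T')\n", "def simulate_path_scan(key, T=2_000):\n", "    # T is static because it fixes the number of keys and the output length\n", "    keys = jax.random.split(key, T)\n", "    init = (jnp.int32(0), jnp.float32(1.0))   # start unemployed with offer 1.0\n", "    _, (status_path, wage_path) = lax.scan(step, init, keys)\n", "    return wage_path, status_path\n", "\n", "wage_path_scan, status_path_scan = simulate_path_scan(jax.random.PRNGKey(42))\n", "print(wage_path_scan.shape, status_path_scan.mean())\n", "```\n", "\n", "Because `T` determines array shapes, it is marked static. As a result, each new value of `T` triggers a fresh compilation, while repeated calls with the same `T` reuse the compiled loop."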
] }, { "cell_type": "markdown", "id": "1d461a6d", "metadata": {}, "source": [ "### Cross-sectional analysis\n", "\n", "Now let’s simulate many agents simultaneously to examine the cross-sectional unemployment rate.\n", "\n", "To do this efficiently, we need a different approach than `simulate_employment_path` defined above.\n", "\n", "The key differences are:\n", "\n", "- `simulate_employment_path` records the entire history (all T periods) for a single agent, which is useful for visualization but memory-intensive when repeated across many agents \n", "- The new function `sim_agent` below only tracks and returns the final state, which is all we need for cross-sectional statistics \n", "- `sim_agent` uses `lax.fori_loop` instead of a Python loop, making it JIT-compilable and suitable for vectorization across many agents \n", "\n", "\n", "We first define a function that simulates a single agent forward T time steps:" ] }, { "cell_type": "code", "execution_count": null, "id": "b4cc15fe", "metadata": { "hide-output": false }, "outputs": [], "source": [ "@jax.jit\n", "def sim_agent(key, initial_status, initial_wage, model, w_bar, T):\n", "    \"\"\"\n", "    Simulate a single agent forward T time steps using lax.fori_loop.\n", "\n", "    Uses fold_in to generate a new key at each time step.\n", "\n", "    Parameters:\n", "    - key: JAX random key for this agent\n", "    - initial_status: Initial employment status (0 or 1)\n", "    - initial_wage: Initial wage\n", "    - model: Model instance\n", "    - w_bar: Reservation wage\n", "    - T: Number of time periods to simulate\n", "\n", "    Returns:\n", "    - final_status: Employment status after T periods\n", "    - final_wage: Wage after T periods\n", "    \"\"\"\n", "    def update(t, loop_state):\n", "        status, wage = loop_state\n", "        step_key = jax.random.fold_in(key, t)\n", "        status, wage = update_agent(step_key, status, wage, model, w_bar)\n", "        return status, wage\n", "\n", "    initial_loop_state = (initial_status, initial_wage)\n", "    final_loop_state = lax.fori_loop(0, T, update, initial_loop_state)\n", "    final_status, final_wage = final_loop_state\n", "    return final_status, final_wage\n", "\n", "\n", "# Create vectorized version of sim_agent to process multiple agents in parallel\n", "sim_agents_vmap = jax.vmap(sim_agent, in_axes=(0, 0, 0, None, None, None))\n", "\n", "\n", "def simulate_cross_section(\n", "        model: Model,\n", "        n_agents: int = 100_000,\n", "        T: int = 200,\n", "        seed: int = 42\n", "    ) -> float:\n", "    \"\"\"\n", "    Simulate cross-section of agents and return unemployment rate.\n", "\n", "    This approach:\n", "    1. Generates n_agents random keys\n", "    2. Calls sim_agent for each agent (vectorized via vmap)\n", "    3. Collects the final states to produce the cross-section\n", "\n", "    Returns the cross-sectional unemployment rate.\n", "    \"\"\"\n", "    c, α, β, ρ, ν, γ, w_grid, z_draws = model\n", "\n", "    key = jax.random.PRNGKey(seed)\n", "\n", "    # Solve for optimal reservation wage\n", "    w_bar = get_reservation_wage(model)\n", "\n", "    # Initialize arrays\n", "    init_key, subkey = jax.random.split(key)\n", "    initial_wages = jnp.exp(jax.random.normal(subkey, (n_agents,)) * ν)\n", "    initial_status_vec = jnp.zeros(n_agents, dtype=jnp.int32)\n", "\n", "    # Generate n_agents random keys\n", "    agent_keys = jax.random.split(init_key, n_agents)\n", "\n", "    # Simulate each agent forward T steps (vectorized)\n", "    final_status, final_wages = sim_agents_vmap(\n", "        agent_keys, initial_status_vec, initial_wages, model, w_bar, T\n", "    )\n", "\n", "    unemployment_rate = 1 - jnp.mean(final_status)\n", "    return unemployment_rate" ] }, { "cell_type": "markdown", "id": "73db7ab8", "metadata": {}, "source": [ "Now let’s compare the time-average unemployment rate (from a single agent’s long\n", "simulation) with the cross-sectional unemployment rate (from many agents at a\n", "single point in time)."
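, "\n", "\n", "Why should these two numbers be close? For an ergodic Markov chain, the long-run fraction of time a single agent spends in a state equals that state's stationary probability. This is also the long-run cross-sectional share of agents in that state. The following sketch illustrates the idea in a deliberately simplified setting. It uses a hypothetical two-state chain in which employed workers lose their job with probability $ \\alpha $ (`ALPHA`) and unemployed workers find one with a constant probability $ \\lambda $ (`LAMBDA`), so the stationary unemployment rate is $ \\alpha / (\\alpha + \\lambda) $. In the lecture's model the job-finding probability depends on the current wage offer, so this formula does not carry over directly.\n", "\n", "```python\n", "import jax\n", "import jax.numpy as jnp\n", "from jax import lax\n", "\n", "# Illustrative separation and job-finding probabilities (assumptions)\n", "ALPHA, LAMBDA = 0.05, 0.3\n", "u_star = ALPHA / (ALPHA + LAMBDA)    # stationary unemployment rate\n", "\n", "def next_state(e, key):\n", "    # One step of the chain: e = 1 (employed) or e = 0 (unemployed)\n", "    draw = jax.random.uniform(key)\n", "    stay_employed = (draw >= ALPHA).astype(jnp.int32)\n", "    find_job = (draw < LAMBDA).astype(jnp.int32)\n", "    return jnp.where(e == 1, stay_employed, find_job)\n", "\n", "def time_average_unemployment(key, T):\n", "    # Fraction of T periods that one agent, starting unemployed, spends unemployed\n", "    def body(e, k):\n", "        return next_state(e, k), e\n", "    _, path = lax.scan(body, jnp.int32(0), jax.random.split(key, T))\n", "    return 1 - path.mean()\n", "\n", "def cross_section_unemployment(key, n_agents, T):\n", "    # Share of n_agents (all starting unemployed) who are unemployed after T periods\n", "    def run(k):\n", "        return lax.fori_loop(\n", "            0, T, lambda t, e: next_state(e, jax.random.fold_in(k, t)), jnp.int32(0))\n", "    final = jax.vmap(run)(jax.random.split(key, n_agents))\n", "    return 1 - final.mean()\n", "\n", "key_a, key_b = jax.random.split(jax.random.PRNGKey(0))\n", "time_avg = time_average_unemployment(key_a, 50_000)\n", "cross_sec = cross_section_unemployment(key_b, 20_000, 200)\n", "print(u_star, time_avg, cross_sec)\n", "```\n", "\n", "Both estimates should be close to `u_star` (about 0.143), with the remaining gap due to sampling noise."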
] }, { "cell_type": "code", "execution_count": null, "id": "8197577c", "metadata": { "hide-output": false }, "outputs": [], "source": [ "model = create_mccall_model()\n", "cross_sectional_unemp = simulate_cross_section(\n", " model, n_agents=20_000, T=200\n", ")\n", "\n", "time_avg_unemp = jnp.mean(unemployed_indicator)\n", "print(f\"Time-average unemployment rate (single agent, T=2000): \"\n", " f\"{time_avg_unemp:.4f}\")\n", "print(f\"Cross-sectional unemployment rate (at t=200): \"\n", " f\"{cross_sectional_unemp:.4f}\")\n", "print(f\"Difference: {abs(time_avg_unemp - cross_sectional_unemp):.4f}\")" ] }, { "cell_type": "markdown", "id": "45d01622", "metadata": {}, "source": [ "Part of the difference above is sampling noise in the single-agent time average, which should typically shrink as that simulation gets longer.\n", "\n", "Let’s rerun the single-agent simulation with T = 10,000 periods." ] }, { "cell_type": "code", "execution_count": null, "id": "ddfdbb81", "metadata": { "hide-output": false }, "outputs": [], "source": [ "wage_path_long, employment_status_long = simulate_employment_path(model, w_bar, T=10_000)\n", "unemployed_indicator_long = (employment_status_long == 0).astype(int)\n", "time_avg_unemp_long = jnp.mean(unemployed_indicator_long)\n", "\n", "print(f\"Time-average unemployment rate (single agent, T=10000): \"\n", " f\"{time_avg_unemp_long:.4f}\")\n", "print(f\"Cross-sectional unemployment rate (at t=200): \"\n", " f\"{cross_sectional_unemp:.4f}\")\n", "print(f\"Difference: {abs(time_avg_unemp_long - cross_sectional_unemp):.4f}\")" ] }, { "cell_type": "markdown", "id": "668ef5fc", "metadata": {}, "source": [ "### Visualization\n", "\n", "This function generates a histogram showing the distribution of employment status across many agents:" ] }, { "cell_type": "code", "execution_count": null, "id": "bb4abdbc", "metadata": { "hide-output": false }, "outputs": [], "source": [ "def plot_cross_sectional_unemployment(\n", " model: Model, # Model instance with parameters\n", " t_snapshot: int = 200, # Time for cross-sectional snapshot\n", " n_agents: int = 20_000 # 
Number of agents to simulate\n", " ):\n", " \"\"\"\n", " Generate histogram of cross-sectional unemployment at a specific time.\n", "\n", " \"\"\"\n", " c, α, β, ρ, ν, γ, w_grid, z_draws = model\n", "\n", " # Get final employment state directly\n", " key = jax.random.PRNGKey(42)\n", " w_bar = get_reservation_wage(model)\n", "\n", " # Initialize arrays\n", " init_key, subkey = jax.random.split(key)\n", " initial_wages = jnp.exp(jax.random.normal(subkey, (n_agents,)) * ν)\n", " initial_status_vec = jnp.zeros(n_agents, dtype=jnp.int32)\n", "\n", " # Generate n_agents random keys\n", " agent_keys = jax.random.split(init_key, n_agents)\n", "\n", " # Simulate each agent forward T steps (vectorized)\n", " final_status, _ = sim_agents_vmap(\n", " agent_keys, initial_status_vec, initial_wages, model, w_bar, t_snapshot\n", " )\n", "\n", " # Calculate unemployment rate\n", " unemployment_rate = 1 - jnp.mean(final_status)\n", "\n", " fig, ax = plt.subplots(figsize=(8, 5))\n", "\n", " # Plot histogram as density (bars sum to 1)\n", " weights = jnp.ones_like(final_status) / len(final_status)\n", " ax.hist(final_status, bins=[-0.5, 0.5, 1.5],\n", " alpha=0.7, color='blue', edgecolor='black',\n", " density=True, weights=weights)\n", "\n", " ax.set_xlabel('employment status (0=unemployed, 1=employed)')\n", " ax.set_ylabel('density')\n", " ax.set_title(f'Cross-sectional distribution at t={t_snapshot}, ' +\n", " f'unemployment rate = {unemployment_rate:.3f}')\n", " ax.set_xticks([0, 1])\n", "\n", " plt.tight_layout()\n", " plt.show()" ] }, { "cell_type": "markdown", "id": "85a9da10", "metadata": {}, "source": [ "Let’s plot the cross-sectional distribution:" ] }, { "cell_type": "code", "execution_count": null, "id": "fe110b9b", "metadata": { "hide-output": false }, "outputs": [], "source": [ "plot_cross_sectional_unemployment(model)" ] }, { "cell_type": "markdown", "id": "c0d86624", "metadata": {}, "source": [ "## Exercises" ] }, { "cell_type": "markdown", "id": "ff7ac074", 
"metadata": {}, "source": [ "## Exercise 45.1\n", "\n", "Use the code above to explore what happens to the reservation wage when $ c $ changes." ] }, { "cell_type": "markdown", "id": "f2db5ea7", "metadata": {}, "source": [ "## Solution\n", "\n", "Here is one solution" ] }, { "cell_type": "code", "execution_count": null, "id": "8a99596b", "metadata": { "hide-output": false }, "outputs": [], "source": [ "def compute_res_wage_given_c(c):\n", " model = create_mccall_model(c=c)\n", " w_bar = get_reservation_wage(model)\n", " return w_bar\n", "\n", "c_vals = jnp.linspace(0.0, 2.0, 15)\n", "w_bar_vals = jax.vmap(compute_res_wage_given_c)(c_vals)\n", "\n", "fig, ax = plt.subplots()\n", "ax.set(xlabel='unemployment compensation', ylabel='reservation wage')\n", "ax.plot(c_vals, w_bar_vals, label=r'$\\bar w$ as a function of $c$')\n", "ax.legend()\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "c192e64b", "metadata": {}, "source": [ "As unemployment compensation increases, the reservation wage also increases.\n", "\n", "This makes economic sense: when the value of being unemployed rises (through higher $ c $), workers become more selective about which job offers to accept." ] }, { "cell_type": "markdown", "id": "622446c2", "metadata": {}, "source": [ "## Exercise 45.2\n", "\n", "Create a plot that shows how the reservation wage changes with the risk aversion parameter $ \\gamma $.\n", "\n", "Use `γ_vals = jnp.linspace(1.2, 2.5, 15)` and keep all other parameters at their default values.\n", "\n", "How do you expect the reservation wage to vary with $ \\gamma $? Why?" 
] }, { "cell_type": "markdown", "id": "235f3603", "metadata": {}, "source": [ "## Solution\n", "\n", "We compute the reservation wage for different values of the risk aversion parameter:" ] }, { "cell_type": "code", "execution_count": null, "id": "e3b65d64", "metadata": { "hide-output": false }, "outputs": [], "source": [ "γ_vals = jnp.linspace(1.2, 2.5, 15)\n", "w_bar_vec = jnp.empty_like(γ_vals)\n", "\n", "for i, γ in enumerate(γ_vals):\n", " model = create_mccall_model(γ=γ)\n", " w_bar = get_reservation_wage(model)\n", " w_bar_vec = w_bar_vec.at[i].set(w_bar)\n", "\n", "fig, ax = plt.subplots(figsize=(9, 5.2))\n", "ax.plot(γ_vals, w_bar_vec, linewidth=2, alpha=0.6,\n", " label='reservation wage')\n", "ax.legend(frameon=False)\n", "ax.set_xlabel(r'$\\gamma$')\n", "ax.set_ylabel(r'$\\bar{w}$')\n", "ax.set_title('Reservation wage as a function of risk aversion')\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "f2f89d90", "metadata": {}, "source": [ "As risk aversion ($ \\gamma $) increases, the reservation wage decreases.\n", "\n", "This occurs because more risk-averse workers place a higher value on the security\n", "of employment relative to the uncertainty of continued search.\n", "\n", "With higher $ \\gamma $, the utility cost of unemployment (forgone consumption)\n", "becomes more severe, making workers more willing to accept lower wages rather\n", "than continue searching." ] } ], "metadata": { "date": 1770028421.002884, "filename": "mccall_fitted_vfi.md", "kernelspec": { "display_name": "Python", "language": "python3", "name": "python3" }, "title": "Job Search IV: Fitted Value Function Iteration" }, "nbformat": 4, "nbformat_minor": 5 }