{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Occasionally binding constraint model - Solved with ANN (Flux.jl)\n",
"\n",
"This notebook solves the Bianchi (2011) model with occasionally binding collateral constraints using an Artificial Neural Network approach.\n",
"\n",
"**Method**: Fischer-Burmeister complementarity formulation with neural network approximation"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Setup: Load Required Packages"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"using Flux\n",
"using Plots\n",
"using ProgressMeter\n",
"using LaTeXStrings\n",
"using Random\n",
"using Statistics\n",
"using Distributions\n",
"using Flux.NNlib\n",
"using CUDA\n",
"using cuDNN"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## GPU Device Selection\n",
"\n",
"Check if CUDA is available for GPU acceleration. The code will automatically fall back to CPU if not."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"if CUDA.has_cuda()\n",
" @info \"CUDA is on\"\n",
" CUDA.allowscalar(false)\n",
" const device = gpu\n",
"else\n",
" @info \"CUDA is off\"\n",
" const device = cpu\n",
"end"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Model Calibration\n",
"\n",
"Standard parameters from Bianchi (2011):\n",
"- σ: Risk aversion\n",
"- κ: Borrowing constraint parameter\n",
"- β: Discount factor\n",
"- ω: Weight on traded goods\n",
"- η: Elasticity of substitution"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model calibrated with σ=2.0, κ=0.2808, β=0.91\n"
]
}
],
"source": [
"# Precision type (Float32 for faster GPU computation)\n",
"const FTYPE = Float32\n",
"\n",
"# Structural parameters\n",
"const σ = FTYPE(2.0) # Inverse of intertemporal elasticity\n",
"const κ = FTYPE(0.2808) # Borrowing constraint parameter\n",
"const β = FTYPE(0.91) # Discount factor\n",
"const ω = FTYPE(0.31) # Weight on traded goods\n",
"const η = FTYPE(0.5) # Elasticity of substitution\n",
"const δ = FTYPE(0.14) # Bond duration\n",
"const τ = FTYPE(0.2) # Tax rate\n",
"\n",
"# Endowments\n",
"const yT = FTYPE(1.0) # Traded goods\n",
"const yN = FTYPE(1.0) # Non-traded goods\n",
"\n",
"# Debt grid\n",
"const l_min, l_max = FTYPE(0.1), FTYPE(0.9)\n",
"const n_l = 300\n",
"const l_grid = collect(range(l_min, l_max, length=n_l))\n",
"\n",
"# Interest rate\n",
"const i_l = FTYPE(0.05)\n",
"\n",
"println(\"Model calibrated with σ=$σ, κ=$κ, β=$β\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Stochastic Shocks\n",
"\n",
"The model includes two shocks:\n",
"1. **ν**: Government default cost\n",
"2. **φ**: Private sector default rate\n",
"\n",
"These are discretized into 25 states with transition matrix Π."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Shocks discretized: 25 states\n"
]
}
],
"source": [
"# Pre-computed transition matrix (25x25)\n",
"const Π = FTYPE.([\n",
" 2.1212e-01 5.3030e-01 0.0000e+00 0.0000e+00 0.0000e+00 1.2121e-01 1.3636e-01 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00;\n",
" 1.7268e-02 4.6834e-01 2.5432e-01 5.2329e-04 0.0000e+00 7.8493e-03 1.6431e-01 8.6866e-02 5.2329e-04 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00;\n",
" 0.0000e+00 8.7662e-02 5.5571e-01 8.9786e-02 0.0000e+00 0.0000e+00 3.1087e-02 2.0429e-01 3.1280e-02 1.9309e-04 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00;\n",
" 0.0000e+00 0.0000e+00 2.4401e-01 4.7631e-01 2.0886e-02 0.0000e+00 0.0000e+00 8.1508e-02 1.6913e-01 8.1508e-03 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00;\n",
" 0.0000e+00 0.0000e+00 0.0000e+00 3.9423e-01 3.0769e-01 0.0000e+00 0.0000e+00 9.6154e-03 1.2500e-01 1.6346e-01 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00;\n",
" 3.3784e-03 5.0676e-03 0.0000e+00 0.0000e+00 0.0000e+00 3.3615e-01 5.1464e-01 1.0135e-02 0.0000e+00 0.0000e+00 5.0113e-02 7.9392e-02 1.1261e-03 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00;\n",
" 3.0430e-04 8.0991e-03 4.0963e-03 2.3408e-05 0.0000e+00 2.0458e-02 5.4219e-01 2.8665e-01 1.1938e-03 0.0000e+00 3.3473e-03 8.6234e-02 4.7213e-02 1.8726e-04 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00;\n",
" 0.0000e+00 1.4610e-03 8.4155e-03 1.3157e-03 1.7087e-05 1.1107e-04 1.0535e-01 6.4296e-01 1.0445e-01 1.2815e-04 5.1262e-05 1.5806e-02 1.0315e-01 1.6771e-02 1.7087e-05 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00;\n",
" 0.0000e+00 2.3579e-05 4.1500e-03 8.2999e-03 4.9517e-04 0.0000e+00 1.1082e-03 2.8830e-01 5.3869e-01 2.2330e-02 0.0000e+00 3.3011e-04 4.7701e-02 8.5640e-02 2.9238e-03 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00;\n",
" 0.0000e+00 0.0000e+00 0.0000e+00 7.2765e-03 4.1580e-03 0.0000e+00 0.0000e+00 6.2370e-03 5.0260e-01 3.4719e-01 0.0000e+00 0.0000e+00 5.1975e-04 7.6403e-02 5.5613e-02 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00;\n",
" 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 2.1251e-02 2.8537e-02 4.0478e-04 0.0000e+00 0.0000e+00 3.6733e-01 5.3269e-01 9.1075e-03 0.0000e+00 0.0000e+00 1.4572e-02 2.6108e-02 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00;\n",
" 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 1.3778e-03 3.1569e-02 1.7214e-02 1.0333e-04 0.0000e+00 2.2389e-02 5.7034e-01 3.0483e-01 1.4036e-03 0.0000e+00 1.4208e-03 3.1638e-02 1.7610e-02 1.0333e-04 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00;\n",
" 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 1.2406e-05 6.0946e-03 3.7033e-02 6.1442e-03 1.5508e-05 1.0545e-04 1.1019e-01 6.7863e-01 1.0985e-01 1.3337e-04 3.1016e-06 6.3862e-03 3.9043e-02 6.3613e-03 9.3047e-06 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00;\n",
" 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 9.4763e-05 1.6859e-02 3.1746e-02 1.3611e-03 0.0000e+00 1.5076e-03 3.0492e-01 5.6999e-01 2.3053e-02 0.0000e+00 5.1689e-05 1.7006e-02 3.1780e-02 1.6282e-03 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00;\n",
" 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 8.0727e-04 2.7245e-02 1.9576e-02 0.0000e+00 0.0000e+00 9.0817e-03 5.3966e-01 3.5015e-01 0.0000e+00 0.0000e+00 0.0000e+00 3.1887e-02 2.1594e-02 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00;\n",
" 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 5.8564e-02 8.0663e-02 2.2099e-03 0.0000e+00 0.0000e+00 3.4088e-01 4.9558e-01 9.3923e-03 0.0000e+00 0.0000e+00 4.4199e-03 8.2873e-03 0.0000e+00 0.0000e+00 0.0000e+00;\n",
" 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 3.3449e-03 8.5006e-02 4.7474e-02 2.5375e-04 0.0000e+00 2.0992e-02 5.3746e-01 2.9186e-01 1.4994e-03 0.0000e+00 2.5375e-04 7.6125e-03 4.1753e-03 6.9204e-05 0.0000e+00;\n",
" 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 2.5060e-05 1.6982e-02 1.0494e-01 1.6581e-02 4.1767e-05 1.4201e-04 1.0634e-01 6.3816e-01 1.0536e-01 1.2530e-04 0.0000e+00 1.3532e-03 8.5288e-03 1.4117e-03 8.3534e-06;\n",
" 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 2.0451e-04 4.5219e-02 8.4735e-02 3.7948e-03 0.0000e+00 1.7270e-03 2.8756e-01 5.4286e-01 2.3041e-02 0.0000e+00 4.5446e-05 3.4539e-03 7.0896e-03 2.7268e-04;\n",
" 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 9.3781e-02 4.8144e-02 0.0000e+00 0.0000e+00 1.2036e-02 5.0752e-01 3.2297e-01 0.0000e+00 0.0000e+00 0.0000e+00 9.0271e-03 6.5196e-03;\n",
" 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 2.0000e-01 1.1667e-01 0.0000e+00 0.0000e+00 0.0000e+00 2.1667e-01 4.5000e-01 1.6667e-02 0.0000e+00 0.0000e+00;\n",
" 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 8.3420e-03 1.6788e-01 9.2284e-02 1.0428e-03 0.0000e+00 1.4599e-02 4.5203e-01 2.6173e-01 2.0855e-03 0.0000e+00;\n",
" 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 2.7276e-02 1.8443e-01 2.8902e-02 1.8064e-04 0.0000e+00 9.2666e-02 5.7225e-01 9.4111e-02 1.8064e-04;\n",
" 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 8.6089e-02 1.8058e-01 6.8241e-03 0.0000e+00 1.0499e-03 2.6772e-01 4.3832e-01 1.9423e-02;\n",
" 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 1.3793e-01 9.1954e-02 0.0000e+00 0.0000e+00 1.1494e-02 4.9425e-01 2.6437e-01;\n",
"])\n",
"\n",
"# Shock grid (25 x 2): [ν, φ]\n",
"const S = FTYPE.([\n",
" -1.1021e-03 3.3916e-03;\n",
" -1.1021e-03 8.4394e-03;\n",
" -1.1021e-03 2.1000e-02;\n",
" -1.1021e-03 5.2255e-02;\n",
" -1.1021e-03 1.3003e-01;\n",
" 9.0945e-01 3.3916e-03;\n",
" 9.0945e-01 8.4394e-03;\n",
" 9.0945e-01 2.1000e-02;\n",
" 9.0945e-01 5.2255e-02;\n",
" 9.0945e-01 1.3003e-01;\n",
" 1.8200e+00 3.3916e-03;\n",
" 1.8200e+00 8.4394e-03;\n",
" 1.8200e+00 2.1000e-02;\n",
" 1.8200e+00 5.2255e-02;\n",
" 1.8200e+00 1.3003e-01;\n",
" 2.7306e+00 3.3916e-03;\n",
" 2.7306e+00 8.4394e-03;\n",
" 2.7306e+00 2.1000e-02;\n",
" 2.7306e+00 5.2255e-02;\n",
" 2.7306e+00 1.3003e-01;\n",
" 3.6411e+00 3.3916e-03;\n",
" 3.6411e+00 8.4394e-03;\n",
" 3.6411e+00 2.1000e-02;\n",
" 3.6411e+00 5.2255e-02;\n",
" 3.6411e+00 1.3003e-01;\n",
"])\n",
"\n",
"const ν_grid = S[:,1] # Government default cost\n",
"const ϕ_grid = S[:,2] # Private default rate\n",
"const n_s = size(S,1) # Number of states (25)\n",
"\n",
"println(\"Shocks discretized: $n_s states\")"
]
},
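{
"cell_type": "markdown",
"metadata": {},
"source": [
"The orientation of Π matters when taking expectations. The sketch below uses a hypothetical 3-state matrix `P` (not the notebook's Π) and assumes the row index is the current state and the column index is the next state, which is what rows summing to one suggest. Under that assumption, the conditional expectation of a next-period vector `f` is `P * f`.\n",
"\n",
"```julia\n",
"# Hypothetical row-stochastic matrix: P[i, j] = Prob(next = j | current = i)\n",
"P = [0.9 0.1 0.0;\n",
"     0.2 0.7 0.1;\n",
"     0.0 0.3 0.7]\n",
"\n",
"# Each row is a conditional distribution, so every row should sum to one\n",
"row_sums = vec(sum(P, dims=2))\n",
"@assert all(isapprox.(row_sums, 1.0; atol=1e-8))\n",
"\n",
"# Next-period payoff in each of the three states (illustrative numbers)\n",
"f = [1.0, 2.0, 4.0]\n",
"\n",
"# Conditional expectation E[f(s') | s = i] for every current state i\n",
"Ef = P * f\n",
"println(Ef)\n",
"```\n",
"\n",
"Applied to the notebook's matrix, the same row-sum check would be `all(isapprox.(sum(Π, dims=2), 1; atol=1f-3))`; the loose tolerance allows for the rounding of the tabulated entries."
]
},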
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Economic Functions\n",
"\n",
"Key equilibrium conditions:\n",
"- Marginal utility of traded consumption\n",
"- Non-traded goods price (from intratemporal FOC)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"price_nt (generic function with 1 method)"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Marginal utility of traded consumption\n",
"∂u∂cT(ct, yn; ω_p=ω, η_p=η, σ_p=σ) = ct > 0 ? \n",
" ω_p * ct^(-1/η_p) * (ω_p*ct^((η_p-1)/η_p) + (1-ω_p)*yn^((η_p-1)/η_p))^((1-σ_p*η_p)/(η_p-1)) : \n",
" FTYPE(999_999.0)\n",
"\n",
"# Price of non-traded goods (from intratemporal optimality)\n",
"price_nt(ct, yn; ω_p=ω, η_p=η) = (1-ω_p)/ω_p * (max(ct, FTYPE(1e-9)) / yn)^(1/η_p)"
]
},
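{
"cell_type": "markdown",
"metadata": {},
"source": [
"The closed form in `∂u∂cT` can be checked against a CES aggregator nested in CRRA utility. The derivation below is a sketch that assumes this specification, since the notebook does not state the utility function explicitly; $c^N$ denotes non-traded consumption, which equals `yN` in the code.\n",
"\n",
"$$\n",
"\\begin{aligned}\n",
"c &= \\left[\\omega\\,(c^T)^{\\frac{\\eta-1}{\\eta}} + (1-\\omega)\\,(c^N)^{\\frac{\\eta-1}{\\eta}}\\right]^{\\frac{\\eta}{\\eta-1}}, \\qquad u(c) = \\frac{c^{1-\\sigma}}{1-\\sigma} \\\\\n",
"\\frac{\\partial u}{\\partial c^T} &= c^{-\\sigma}\\,\\frac{\\partial c}{\\partial c^T} = \\omega\\,(c^T)^{-\\frac{1}{\\eta}} \\left[\\omega\\,(c^T)^{\\frac{\\eta-1}{\\eta}} + (1-\\omega)\\,(c^N)^{\\frac{\\eta-1}{\\eta}}\\right]^{\\frac{1-\\sigma\\eta}{\\eta-1}} \\\\\n",
"p^N &= \\frac{\\partial u/\\partial c^N}{\\partial u/\\partial c^T} = \\frac{1-\\omega}{\\omega}\\left(\\frac{c^T}{c^N}\\right)^{\\frac{1}{\\eta}}\n",
"\\end{aligned}\n",
"$$\n",
"\n",
"With the calibration σ = 2 and η = 0.5, the exponent (1 − ση)/(η − 1) is zero, so the bracket drops out and the marginal utility reduces to ω (c^T)^(−2)."
]
},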
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Neural Network Architecture\n",
"\n",
"**Input**: 3D state (l, ν, φ) \n",
"**Output**: Policy l' ∈ [l_min, l_max] \n",
"**Architecture**: 3 → 256 → 128 → 1 with ReLU activations"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"const state_dim = 3\n",
"const policy_dim = 1\n",
"const Q1 = 256\n",
"const Q2 = 128\n",
"activation_f = relu\n",
"const T_sigmoid = FTYPE(1.0)\n",
"\n",
"# Build network: maps normalized state to policy\n",
"model = Chain(\n",
" Dense(state_dim, Q1, activation_f),\n",
" Dense(Q1, Q2, activation_f),\n",
" Dense(Q2, policy_dim),\n",
" x -> l_min .+ (l_max - l_min) .* Flux.sigmoid(x ./ T_sigmoid)\n",
")\n",
"\n",
"println(\"Network created with $(sum(length, Flux.params(model))) parameters\")"
]
},
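{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a sanity check on the architecture, the sketch below counts the weights and biases of a 3 → 256 → 128 → 1 dense network by hand and shows how a scaled sigmoid keeps any raw output inside the debt bounds. It is plain Julia with no Flux dependency; `dense_params`, `logistic`, and `bounded_policy` are illustrative helper names that do not appear in the notebook, and the temperature is left out because `T_sigmoid` is 1.\n",
"\n",
"```julia\n",
"# Number of parameters of a dense layer: weight matrix plus bias vector\n",
"dense_params(n_in, n_out) = n_in * n_out + n_out\n",
"\n",
"layer_sizes = [3, 256, 128, 1]\n",
"n_params = sum(dense_params(layer_sizes[i], layer_sizes[i + 1])\n",
"               for i in 1:length(layer_sizes) - 1)\n",
"println(\"Total parameters: \", n_params)   # 1024 + 32896 + 129 = 34049\n",
"\n",
"# Logistic function and the scaled output map applied after the last layer\n",
"logistic(x) = 1 / (1 + exp(-x))\n",
"bounded_policy(x; lo=0.1, hi=0.9) = lo + (hi - lo) * logistic(x)\n",
"\n",
"# Raw outputs of any magnitude are mapped into [lo, hi]\n",
"for x in (-50.0, 0.0, 50.0)\n",
"    println(x, \" -> \", bounded_policy(x))\n",
"end\n",
"```\n",
"\n",
"Because the bounds are built into the output layer, the loss never has to penalize policies outside the debt grid."
]
},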
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Training Setup\n",
"\n",
"**Optimizer**: AdaBelief with two-stage learning rate \n",
"**Epochs**: 50,000 (warmup at 30,000) \n",
"**Batch size**: 32"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(layers = ((weight = \u001b[32mLeaf(AdaBelief(eta=0.001, beta=(0.9, 0.999), epsilon=1.0e-16), \u001b[39m(Float32[0.0 0.0 0.0; 0.0 0.0 0.0; … ; 0.0 0.0 0.0; 0.0 0.0 0.0], Float32[0.0 0.0 0.0; 0.0 0.0 0.0; … ; 0.0 0.0 0.0; 0.0 0.0 0.0], (0.9, 0.999))\u001b[32m)\u001b[39m, bias = \u001b[32mLeaf(AdaBelief(eta=0.001, beta=(0.9, 0.999), epsilon=1.0e-16), \u001b[39m(Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 … 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 … 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], (0.9, 0.999))\u001b[32m)\u001b[39m, σ = ()), (weight = \u001b[32mLeaf(AdaBelief(eta=0.001, beta=(0.9, 0.999), epsilon=1.0e-16), \u001b[39m(Float32[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0], Float32[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0], (0.9, 0.999))\u001b[32m)\u001b[39m, bias = \u001b[32mLeaf(AdaBelief(eta=0.001, beta=(0.9, 0.999), epsilon=1.0e-16), \u001b[39m(Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 … 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 … 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], (0.9, 0.999))\u001b[32m)\u001b[39m, σ = ()), (weight = \u001b[32mLeaf(AdaBelief(eta=0.001, beta=(0.9, 0.999), epsilon=1.0e-16), \u001b[39m(Float32[0.0 0.0 … 0.0 0.0], Float32[0.0 0.0 … 0.0 0.0], (0.9, 0.999))\u001b[32m)\u001b[39m, bias = \u001b[32mLeaf(AdaBelief(eta=0.001, beta=(0.9, 0.999), epsilon=1.0e-16), \u001b[39m(Float32[0.0], Float32[0.0], (0.9, 0.999))\u001b[32m)\u001b[39m, σ = ()), ()),)"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"const initial_learning_rate = FTYPE(1e-3)\n",
"const final_learning_rate = FTYPE(1e-6)\n",
"const epochs_warmup = 30000\n",
"const epochs = 50000\n",
"const batch_size = 32\n",
"\n",
"opt = AdaBelief(initial_learning_rate)\n",
"st = Flux.setup(opt, model)"
]
},
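{
"cell_type": "markdown",
"metadata": {},
"source": [
"A two-stage learning rate does not require rebuilding the optimiser state. The sketch below assumes a Flux version that provides the explicit-style `Flux.setup` and `Flux.adjust!`; the one-layer model `toy`, the random data, and the step counts are illustrative only and are unrelated to the notebook's network.\n",
"\n",
"```julia\n",
"using Flux\n",
"\n",
"toy = Dense(2 => 1)                       # illustrative one-layer model\n",
"x = rand(Float32, 2, 16)\n",
"y = rand(Float32, 1, 16)\n",
"toy_loss(m) = Flux.mse(m(x), y)\n",
"\n",
"state = Flux.setup(AdaBelief(1f-3), toy)  # stage 1: larger learning rate\n",
"n_steps, switch_at = 200, 100\n",
"\n",
"for step in 1:n_steps\n",
"    if step == switch_at\n",
"        # Stage 2: lower the learning rate in place, keeping the moment estimates\n",
"        Flux.adjust!(state, 1f-4)\n",
"    end\n",
"    grads = Flux.gradient(toy_loss, toy)\n",
"    Flux.update!(state, toy, grads[1])\n",
"end\n",
"\n",
"println(\"final loss: \", toy_loss(toy))\n",
"```\n",
"\n",
"Adjusting the existing state keeps the accumulated gradient statistics, whereas calling `Flux.setup` again would reset them to zero."
]
},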
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## State Normalization\n",
"\n",
"Normalize shocks but keep debt in original scale for better training."
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"sample_states (generic function with 1 method)"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"const shock_mean = mean(S, dims=1)'\n",
"const shock_std = std(S, dims=1)'\n",
"\n",
"function normalize_state(batch)\n",
" l = batch[1:1, :]\n",
" shocks = batch[2:3, :]\n",
" normalized_shocks = (shocks .- shock_mean) ./ (2 .* shock_std)\n",
" return vcat(l, normalized_shocks)\n",
"end\n",
"\n",
"function sample_states(batch_size::Int)\n",
" # Sample debt uniformly\n",
" loans = rand(FTYPE, 1, batch_size) .* (l_max - l_min) .+ l_min\n",
" # Sample shocks\n",
" shock_indices = rand(1:n_s, batch_size)\n",
" shocks_ν = S[shock_indices, 1]'\n",
" shocks_ϕ = S[shock_indices, 2]'\n",
" return vcat(loans, shocks_ν, shocks_ϕ), shock_indices\n",
"end"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Loss Function: Fischer-Burmeister\n",
"\n",
"The FB function handles complementarity: \n",
"FB(a, b) = a + b - √(a² + b²) \n",
"\n",
"Where:\n",
"- a = Euler equation residual\n",
"- b = Constraint slack\n",
"\n",
"At optimum, FB = 0 whether constraint binds or not."
]
},
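{
"cell_type": "markdown",
"metadata": {},
"source": [
"The zero set of the FB function can be checked numerically. The sketch below is plain Julia; `fb` and the sample values are illustrative and are not taken from the notebook's loss.\n",
"\n",
"```julia\n",
"# Fischer-Burmeister function: zero iff a >= 0, b >= 0 and a * b == 0\n",
"fb(a, b) = a + b - sqrt(a^2 + b^2)\n",
"\n",
"# Illustrative (a, b) pairs: a plays the Euler residual, b the constraint slack\n",
"cases = [\n",
"    (\"slack constraint, Euler equation holds\", 0.0, 0.3),\n",
"    (\"binding constraint, positive multiplier\", 0.2, 0.0),\n",
"    (\"both strictly positive (not complementary)\", 0.2, 0.3),\n",
"    (\"negative Euler residual (infeasible)\", -0.1, 0.3),\n",
"]\n",
"\n",
"for (label, a, b) in cases\n",
"    println(rpad(label, 45), \" FB = \", fb(a, b))\n",
"end\n",
"```\n",
"\n",
"The first two cases give a residual of zero (up to floating-point error), the third is positive, and the fourth is negative, so driving the residual to zero rules out both kinds of violation."
]
},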
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"ann_loss (generic function with 1 method)"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"const huber_delta = FTYPE(0.1)\n",
"\n",
"function huber_loss(x, δ)\n",
" abs_x = abs.(x)\n",
" return Flux.mean(ifelse.(abs_x .< δ, FTYPE(0.5) .* x.^2, δ .* (abs_x .- FTYPE(0.5) * δ)))\n",
"end\n",
"\n",
"function ann_loss(m, batch::AbstractMatrix, shock_indices::Vector{Int})\n",
" current_batch_size = size(batch, 2)\n",
" normalized_batch = normalize_state(batch)\n",
" l, ν, ϕ = batch[1:1, :], batch[2:2, :], batch[3:3, :]\n",
"\n",
" # Current policy\n",
" lp = m(normalized_batch)\n",
" ct = yT*(1-τ) .- l .* (FTYPE(1.0) .- ϕ) .+ lp ./ (FTYPE(1.0) + i_l)\n",
" pn = price_nt.(ct, yN)\n",
"\n",
" # Vectorized expectation: evaluate ALL future states at once\n",
" lp_expanded = repeat(lp, inner=(1, n_s))\n",
" S_tiled = repeat(S', 1, current_batch_size)\n",
" state_prime_unnormalized = vcat(lp_expanded, S_tiled)\n",
" lpp = m(normalize_state(state_prime_unnormalized))\n",
" ϕ_p_tiled = state_prime_unnormalized[3:3, :]\n",
" ct_prime = yT*(1-τ) .- lp_expanded .* (FTYPE(1.0) .- ϕ_p_tiled) .+ lpp ./ (FTYPE(1.0) + i_l)\n",
" future_marg_utils = ∂u∂cT.(ct_prime, yN)\n",
" \n",
" # Compute expectations\n",
" future_marg_utils_reshaped = reshape(future_marg_utils, n_s, current_batch_size)\n",
" E_λp = Π' * future_marg_utils_reshaped\n",
" cartesian_indices = CartesianIndex.(shock_indices, 1:current_batch_size)\n",
" E_λp_final = E_λp[cartesian_indices]'\n",
"\n",
" # Fischer-Burmeister residual\n",
" λ = ∂u∂cT.(ct, yN)\n",
" borr_const = κ .* (pn .* yN .+ yT*(1-τ))\n",
" epsilon = FTYPE(1e-9)\n",
" a_norm = FTYPE(1.0) .- (β .* E_λp_final .* (FTYPE(1.0) + i_l)) ./ (λ .+ epsilon)\n",
" b_norm = FTYPE(1.0) .- lp ./ (borr_const .+ epsilon)\n",
" fb_residual = a_norm .+ b_norm .- sqrt.(a_norm.^2 .+ b_norm.^2)\n",
" \n",
" return huber_loss(fb_residual, huber_delta)\n",
"end"
]
},
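{
"cell_type": "markdown",
"metadata": {},
"source": [
"The vectorized expectation relies on two `repeat` calls lining up column by column. The sketch below reproduces that layout on a toy problem with 3 shock states and a batch of 2; `S_demo` and all numbers are made up for illustration.\n",
"\n",
"```julia\n",
"n_s, B = 3, 2                       # illustrative sizes: 3 shock states, batch of 2\n",
"lp = [10.0 20.0]                    # 1 x B row of next-period debt choices\n",
"S_demo = [1.0 0.1;                  # n_s x 2 grid of hypothetical shock values\n",
"          2.0 0.2;\n",
"          3.0 0.3]\n",
"\n",
"# Repeat each policy n_s times in a row: [10 10 10 20 20 20]\n",
"lp_expanded = repeat(lp, inner=(1, n_s))\n",
"# Tile the transposed shock grid B times: columns cycle through states 1, 2, 3\n",
"S_tiled = repeat(S_demo', 1, B)\n",
"\n",
"# Column (b - 1) * n_s + s pairs batch element b with future state s\n",
"state_prime = vcat(lp_expanded, S_tiled)\n",
"display(state_prime)\n",
"\n",
"# A per-column quantity reshapes to n_s x B: rows are future states, columns are batch elements\n",
"g = state_prime[1:1, :] .+ state_prime[2:2, :]\n",
"G = reshape(g, n_s, B)\n",
"display(G)\n",
"```\n",
"\n",
"Because Julia arrays are column-major, the reshape puts all future states of one batch element in a single column, which is the shape a transition matrix can multiply from the left."
]
},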
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Multiplier Recovery Function\n",
"\n",
"After training, we can recover the Kuhn-Tucker multiplier μ to see where the constraint binds."
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"recover_multiplier (generic function with 1 method)"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"function recover_multiplier(m, batch::AbstractMatrix, shock_indices::Vector{Int})\n",
" current_batch_size = size(batch, 2)\n",
" normalized_batch = normalize_state(batch)\n",
" l, ν, ϕ = batch[1:1, :], batch[2:2, :], batch[3:3, :]\n",
" lp = m(normalized_batch)\n",
" ct = yT*(1-τ) .- l .* (FTYPE(1.0) .- ϕ) .+ lp ./ (FTYPE(1.0) + i_l)\n",
"\n",
" # Same vectorized expectation as in loss\n",
" lp_expanded = repeat(lp, inner=(1, n_s))\n",
" S_tiled = repeat(S', 1, current_batch_size)\n",
" state_prime_unnormalized = vcat(lp_expanded, S_tiled)\n",
" lpp = m(normalize_state(state_prime_unnormalized))\n",
" ϕ_p_tiled = state_prime_unnormalized[3:3, :]\n",
" ct_prime = yT*(1-τ) .- lp_expanded .* (FTYPE(1.0) .- ϕ_p_tiled) .+ lpp ./ (FTYPE(1.0) + i_l)\n",
" future_marg_utils = ∂u∂cT.(ct_prime, yN)\n",
" future_marg_utils_reshaped = reshape(future_marg_utils, n_s, current_batch_size)\n",
" E_λp = Π' * future_marg_utils_reshaped\n",
" cartesian_indices = CartesianIndex.(shock_indices, 1:current_batch_size)\n",
" E_λp_final = E_λp[cartesian_indices]'\n",
"\n",
" # Multiplier from Euler equation\n",
" λ = ∂u∂cT.(ct, yN)\n",
" μ = λ ./ (FTYPE(1.0) + i_l) .- β .* E_λp_final\n",
" return μ\n",
"end"
]
},
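{
"cell_type": "markdown",
"metadata": {},
"source": [
"The link between `recover_multiplier` and the loss can be written out. The block below restates the Euler equation implied by the code, with $\\lambda_t$ the current marginal utility of traded consumption, $i$ the interest rate `i_l`, and $\\mu_t$ the multiplier on the borrowing constraint; the time-subscript notation is chosen here for illustration.\n",
"\n",
"$$\n",
"\\begin{aligned}\n",
"\\frac{\\lambda_t}{1+i} &= \\beta\\,\\mathbb{E}_t\\left[\\lambda_{t+1}\\right] + \\mu_t, \\qquad \\mu_t \\ge 0 \\\\\n",
"\\mu_t &= \\frac{\\lambda_t}{1+i} - \\beta\\,\\mathbb{E}_t\\left[\\lambda_{t+1}\\right] \\\\\n",
"a_t &= 1 - \\frac{\\beta\\,(1+i)\\,\\mathbb{E}_t\\left[\\lambda_{t+1}\\right]}{\\lambda_t} = \\frac{(1+i)\\,\\mu_t}{\\lambda_t}\n",
"\\end{aligned}\n",
"$$\n",
"\n",
"So `a_norm` in the loss is the multiplier rescaled by (1 + i)/λ, apart from the small `epsilon` in the denominator, and it is zero exactly when μ is zero."
]
},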
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Training Loop\n",
"\n",
"Train for 50,000 epochs with learning rate decay at epoch 30,000."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"println(\"--- Starting ANN Training ---\")\n",
"losses = FTYPE[]\n",
"p_bar = Progress(epochs; desc=\"Training Policy ANN:\")\n",
"\n",
"for epoch in 1:epochs\n",
" # Switch to lower learning rate at warmup point\n",
" if epoch == epochs_warmup\n",
" println(\"\\n--- Switched to final learning rate: $(final_learning_rate) ---\")\n",
" global opt = ADAM(final_learning_rate)\n",
" global st = Flux.setup(opt, model)\n",
" end\n",
" \n",
" # Sample batch and compute loss\n",
" batch, shock_indices = sample_states(batch_size)\n",
" loss, grads = Flux.withgradient(m -> ann_loss(m, batch, shock_indices), model)\n",
" Flux.update!(st, model, grads[1])\n",
" push!(losses, loss)\n",
"\n",
" ProgressMeter.next!(p_bar; showvalues=[(:epoch, epoch), (:loss, loss)])\n",
"end\n",
"\n",
"println(\"--- Training Complete ---\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Visualize Results\n",
"\n",
"Plot:\n",
"1. Training loss convergence\n",
"2. Learned policy function\n",
"3. Kuhn-Tucker multiplier"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "",
"image/svg+xml": [
"\n",
"\n"
],
"text/html": [
"\n",
"\n"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Loss plot\n",
"p1 = plot(losses,\n",
" label=\"Training Loss\",\n",
" xlabel=\"Epoch\",\n",
" ylabel=\"Loss (log scale)\",\n",
" title=\"ANN Training Progress\",\n",
" yaxis=:log10,\n",
" legend=:topright)\n",
"\n",
"# Prepare plotting grid (at lowest shocks)\n",
"println(\"Plotting learned policy function...\")\n",
"plot_l_grid = l_grid\n",
"plot_ν_grid = fill(ν_grid[1], n_l)\n",
"plot_ϕ_grid = fill(ϕ_grid[1], n_l)\n",
"plot_grid_unnormalized = vcat(plot_l_grid', plot_ν_grid', plot_ϕ_grid')\n",
"plot_shock_indices = fill(1, n_l)\n",
"\n",
"learned_policy = model(normalize_state(plot_grid_unnormalized))\n",
"\n",
"# Policy plot\n",
"p2 = plot(l_grid, learned_policy',\n",
" label=\"Learned Policy l'(l)\",\n",
" xlabel=\"Current Loans (l)\",\n",
" ylabel=\"Next Period's Loans (l')\",\n",
" title=\"Policy Function (at lowest shocks)\",\n",
" linewidth=2,\n",
" legend=:topleft)\n",
"plot!(l_grid, l_grid,\n",
" linestyle=:dash,\n",
" color=:black,\n",
" label=\"45-degree line\")\n",
"\n",
"# Multiplier plot\n",
"println(\"Recovering and plotting KT multiplier...\")\n",
"recovered_mu = recover_multiplier(model, plot_grid_unnormalized, plot_shock_indices)\n",
"\n",
"p3 = plot(l_grid, recovered_mu',\n",
" label=\"Multiplier μ(l)\",\n",
" xlabel=\"Current Loans (l)\",\n",
" ylabel=\"KT Multiplier (μ)\",\n",
" title=\"Recovered KT Multiplier (at lowest shocks)\",\n",
" linewidth=2,\n",
" legend=:topleft)\n",
"plot!(l_grid, zeros(FTYPE, n_l),\n",
" linestyle=:dash,\n",
" color=:black,\n",
" label=\"\")\n",
"\n",
"# Combine plots\n",
"plt = plot(p1, p2, p3, layout=(1,3), size=(1100, 600))\n",
"display(plt)\n",
"\n",
"# Save figure\n",
"# savefig(plt, \"Collateral_Constraint_ANN_Results.png\")\n",
"# println(\"Figure saved as Collateral_Constraint_ANN_Results.png\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Interpretation\n",
"\n",
"**Policy Function**: Shows characteristic kink where borrowing constraint starts to bind\n",
"\n",
"**Multiplier μ**: \n",
"- μ = 0: Constraint not binding (interior solution)\n",
"- μ > 0: Constraint binds (household wants to borrow more)\n",
"\n",
"The kink location indicates the threshold debt level where sudden stops occur."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Julia 1.11.2",
"language": "julia",
"name": "julia-1.11"
},
"language_info": {
"file_extension": ".jl",
"mimetype": "application/julia",
"name": "julia",
"version": "1.11.2"
}
},
"nbformat": 4,
"nbformat_minor": 4
}