How to speed up vectorized operations for dynamic programming
Alessandro
on 14 Sep 2024
Edited: Alessandro
on 14 Sep 2024
I would like to speed up the following code, which solves a discrete dynamic programming problem using the method of successive approximations, as described, for example, in Bertsekas.
The algorithm has two steps. In step 1, I precompute the payoff array R(a',a,z), where a' is the action and (a,z) are the states. In step 2, I compute the value function by successive approximations: I start from a guess V0, compute an updated V1, and check whether ||V1-V0|| is below a tolerance level. If it is, I stop; otherwise I set V0=V1 and repeat.
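For reference, the update in step 2 can be written out as below. This is a sketch read off the MWE in this post rather than a quotation from Bertsekas: it assumes P(z,z') is the probability of moving from shock z to shock z' (the rows of PZ), u is the utility function fun_u, and the sup-norm is used for the stopping test. Note that in the MWE the payoff array is stored as Ret, while the variable R holds 1+r.

```latex
\[
V_{n+1}(a,z) \;=\; \max_{a'} \Big\{ R(a',a,z) + \beta \sum_{z'} P(z,z')\, V_n(a',z') \Big\},
\]
\[
R(a',a,z) \;=\;
\begin{cases}
u\big((1+r)\,a + z - a'\big) & \text{if } (1+r)\,a + z - a' > 0,\\[2pt]
-\infty & \text{otherwise,}
\end{cases}
\]
\[
\text{stop when } \max_{a,z}\, \big| V_{n+1}(a,z) - V_n(a,z) \big| \le \text{tol}.
\]
```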
I profiled the code (an MWE is below), and the two most time-consuming lines are the following:
(1) RHS = Ret+beta*permute(EV,[1,3,2]);
(2) [max_val,max_ind] = max(RHS,[],1);
Line (1) takes up 63% of the total running time and line (2) takes 32%. As you can see, I have already vectorized all loops.
I would be very grateful for any suggestions. The MWE is posted below. (Note that I deliberately set n_a, the number of grid points, to a low value so that interested users can run the example quickly. In my actual code, n_a=10000 or more.)
%% Solve income fluctuation problem CPU
clear;clc;close all
%% Economic parameters
sigma = 2;
r = 0.03;
beta = 0.96;
PZ = [0.60 0.40;
0.05 0.95];
z_grid = [0.5 1.0]';
n_z = length(z_grid);
b = 0; %lower bound for asset holdings a
grid_max = 4;
n_a = 500; % IN PRACTICE THIS IS EQUAL TO 5000-10000
R = 1+r;
a_grid = linspace(-b,grid_max,n_a)';
if sigma==1
fun_u = @(c) log(c);
else
fun_u = @(c) c.^(1-sigma)/(1-sigma);
end
%% Computational parameters
verbose = 0;
tiny = 1e-8; %very small positive number
tol = 1e-6; %tolerance for VFI and TI
max_iter = 500; %maximum num. of iterations for both VFI and TI
%% Start timing
tic
%% STEP 1- Precompute current payoff array R(a',a,z)
a_tomorrow = a_grid; %(a',1,1)
a_today = a_grid'; %(1,a,1)
z_today = shiftdim(z_grid,-2); %(1,1,z)
cons = (1+r)*a_today+z_today-a_tomorrow;
Ret = fun_u(cons); %size: [n_a,n_a,n_z]
Ret(cons<=0) = -inf;
%% STEP 2 - Value function iteration
iter = 1;
err = tol+1;
V0 = zeros(n_a,n_z);
while err>tol && iter<=max_iter
EV = V0*PZ'; %(a',z)
RHS = Ret+beta*permute(EV,[1,3,2]);
[max_val,max_ind] = max(RHS,[],1);
V1 = squeeze(max_val);
pol_ind_ap = squeeze(max_ind);
err = max(abs(V0(:)-V1(:)));
if verbose==1
fprintf('iter = %d, err = %f \n',iter,err)
end
iter = iter+1;
V0 = V1;
end
if err>tol
error('VFI did not converge!')
else
fprintf('VFI converged after %d iterations \n',iter-1) % iter was incremented once more after the last pass
end
pol_ap = a_grid(pol_ind_ap);
pol_c = (1+r)*a_grid+z_grid'-pol_ap;
%% End timing
toc
%% Figures
figure
plot(a_grid,pol_c(:,1),'linewidth',2)
hold on
plot(a_grid,pol_c(:,2),'linewidth',2)
legend('Low shock','High shock','Location','NorthWest')
xlabel('asset level')
ylabel('consumption')
title('Consumption Policy Function')
figure
plot(a_grid,a_grid,'--','linewidth',2)
hold on
plot(a_grid,pol_ap(:,1),'linewidth',2)
hold on
plot(a_grid,pol_ap(:,2),'linewidth',2)
legend('45 line','Low shock','High shock','Location','NorthWest')
xlabel('Current period assets')
ylabel('Next-period assets')
title('Assets Policy Function')
Accepted Answer
Matt J
on 14 Sep 2024
Edited: Matt J
on 14 Sep 2024
This might be a little faster.
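% Assumes Ret, V0, iter, err, tol, max_iter, n_a and n_z are set up as in the question
betaPZtransp = beta*PZ'; % fold beta into the transition matrix once, outside the loop
tic
while err>tol && iter<=max_iter
% reshape to (a',1,z) replaces permute(EV,[1,3,2])
RHS = Ret + reshape(V0*betaPZtransp,n_a,1,n_z);
% only the maximum values are needed inside the loop, not the indices
V1 = max(RHS,[],1);
err = norm( V0(:)-V1(:) ,inf);
if verbose
fprintf('iter = %d, err = %f \n',iter,err)
end
iter = iter+1;
V0 = reshape(V1,n_a,n_z);
end
toc
% recover the policy indices once, after the loop has finished
[V1,pol_ind_ap] = max(RHS,[],1);
pol_ind_ap = reshape(pol_ind_ap,n_a,n_z);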
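Before timing the two versions against each other, it may be worth checking that they produce the same update. The following is a minimal sketch of such a check for a single iteration. The small grid size, the test value function V0, and the variable names ending in _q and _a are assumptions made for illustration; they do not appear in the thread. The payoff array is built as in the question, which relies on implicit expansion (R2016b or later).

```matlab
% Illustrative sizes only (hypothetical; the question uses n_a = 500 or more)
n_a = 50; n_z = 2;
beta = 0.96; r = 0.03; sigma = 2;
PZ = [0.60 0.40; 0.05 0.95];
z_grid = [0.5 1.0]';
a_grid = linspace(0,4,n_a)';

% Payoff array R(a',a,z), built as in the question
cons = (1+r)*a_grid' + shiftdim(z_grid,-2) - a_grid;   % size [n_a,n_a,n_z]
Ret = cons.^(1-sigma)/(1-sigma);
Ret(cons<=0) = -inf;

% Arbitrary deterministic test value function (an assumption for this check)
V0 = [sqrt(a_grid), sqrt(a_grid)+0.5];

% One update as in the question: permute + squeeze
EV = V0*PZ';
RHS_q = Ret + beta*permute(EV,[1,3,2]);
max_val = max(RHS_q,[],1);
V1_q = squeeze(max_val);

% One update as in the answer: precomputed beta*PZ' + reshape
betaPZtransp = beta*PZ';
RHS_a = Ret + reshape(V0*betaPZtransp,n_a,1,n_z);
V1_a = reshape(max(RHS_a,[],1),n_a,n_z);

% The two can differ by floating-point rounding, because beta is applied
% at a different point in the computation
max_diff = max(abs(V1_q(:)-V1_a(:)));
fprintf('largest difference between the two updates: %g\n',max_diff)
```

If max_diff is at the level of rounding error, any speed difference measured between the two loops reflects the implementation rather than a change in the result.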