`worldStates` is a MATLAB MxNxL 3D array (tensor) containing L states of an MxN grid of binary values.
`ps` is a length-L list of probabilities associated with the different states.
The function `[worldStates, ps] = StateMerge(worldStates, ps)` should remove duplicate world states and add the probabilities of the merged states to the single state that remains. Duplicate states are states with exactly the same configuration of binary values.
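To make the intended behaviour concrete, here is a tiny hand-made input in the same layout (the 2x2 states and the probabilities are made up for illustration), where states 1 and 3 are duplicates:

```matlab
% Hypothetical 2x2 example with L = 3 states; states 1 and 3 are identical.
worldStates = cat(3, [1 0; 0 1], [1 1; 0 0], [1 0; 0 1]);
ps = reshape([0.2 0.5 0.3], 1, 1, 3);   % 1x1xL, as in the EDIT examples

% "Duplicate" means exact equality of the binary configuration
isDup = isequal(worldStates(:,:,1), worldStates(:,:,3));

% A correct merge therefore leaves two states (their order is not specified):
%   [1 0; 0 1] with probability 0.2 + 0.3 = 0.5
%   [1 1; 0 0] with probability 0.5
```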
Here is the current implementation of this function:
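```matlab
function [worldStates, ps] = StateMerge(worldStates, ps)
    % Map from the string form of a state to its accumulated probability
    M = containers.Map;
    for i = 1:length(ps)
        s = worldStates(:,:,i);
        s = mat2str(s);            % state -> string key (the slow step)
        if isKey(M, s)
            M(s) = M(s) + ps(i);   % duplicate: accumulate its probability
        else
            M(s) = ps(i);          % first occurrence of this state
        end
    end
    stringStates = keys(M);
    n = length(stringStates);
    sz = size(worldStates);
    worldStates = zeros([sz(1:2), n]);
    ps = zeros(1, 1, n);
    for i = 1:n
        worldStates(:,:,i) = eval(stringStates{i});   % string key -> state
        ps(i) = M(stringStates{i});
    end
end
```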
It uses a Map to remove duplicates in O(L) time, with the states as keys and the probabilities as values. Since MATLAB maps do not accept arbitrary data structures as keys, the states are converted into string representations to serve as keys and are later converted back to arrays with the `eval` function.
It turns out this code is far too slow for my needs, as I will want to process many states (on the order of 10^6) many times (on the order of 10^3). The problem lies in converting the matrix to a string, which takes a substantial amount of time and scales poorly with state size. An example for small 25x25 states is given below:
How could I create keys more efficiently? Is there another solution, aside from using a map, that would yield better results?
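One cheaper key, sketched here under the assumption that every entry is exactly 0 or 1, is a character row with one `'0'`/`'1'` per cell. Building it is a single vectorised conversion instead of `mat2str`'s number formatting, and decoding it needs no `eval`. The variable names (`key`, `sBack`, `nRows`, `nCols`) are illustrative only:

```matlab
% Sketch: a compact string key for a binary state, avoiding mat2str/eval.
% Assumes every entry of the state is exactly 0 or 1.
s = double(rand(25, 25) > 0.5);          % one example 25x25 state
[nRows, nCols] = size(s);

% Encode: one '0'/'1' character per cell, in column-major order
key = char(s(:).' + '0');

% Decode: char minus '0' gives a double 0/1 row; reshape restores the grid
sBack = reshape(key - '0', nRows, nCols);
```

Inside `StateMerge`, `s = mat2str(s)` would become `s = char(s(:).' + '0')` and the `eval` call would become `reshape(stringStates{i} - '0', sz(1), sz(2))`. The key still has M*N characters, so hashing cost still grows with state size, and how much faster this is than `mat2str` has not been measured here.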
EDIT: Runnable code as requested. This example makes merges very unlikely:
```matlab
worldStates = double(rand(25,25, 1000) > 0.5);
weights = rand(1,1, 1000);
ps = weights./sum(weights);
[worldStates, ps] = StateMerge(worldStates, ps);
```
In this example there will be lots of merges:
```matlab
worldStates = double(rand(25,25) > 0.5) .* ones(1,1,1000);
worldStates(1:2,1:2,:) = rand(2,2,1000) > 0.5;
weights = rand(1,1, 1000);
ps = weights./sum(weights);
[worldStates, ps] = StateMerge(worldStates, ps);
```
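A Map-free alternative, sketched on the input of the second example: flatten each state into one row, let `unique(..., 'rows')` find the duplicates, and let `accumarray` sum their probabilities. `mergedStates` and `mergedPs` are illustrative names, and the code is written as a script rather than a function with `StateMerge`'s signature. This swaps hashing for a sort, so it costs O(L log L) row comparisons rather than O(L) lookups, the merged states come out in sorted-row order, and it has not been benchmarked against the Map version:

```matlab
% Sketch: merge duplicate states without containers.Map, using
% unique(..., 'rows') to find duplicates and accumarray to sum probabilities.
worldStates = double(rand(25,25) > 0.5) .* ones(1,1,1000);
worldStates(1:2,1:2,:) = rand(2,2,1000) > 0.5;
weights = rand(1,1, 1000);
ps = weights./sum(weights);

[nRows, nCols, L] = size(worldStates);

% One state per row: X is L x (nRows*nCols)
X = reshape(worldStates, nRows * nCols, L).';

% C = distinct rows (sorted); ic(k) = index in C of original state k
[C, ~, ic] = unique(X, 'rows');

% Sum the probabilities of all original states that share a row of C
mergedPs = reshape(accumarray(ic(:), ps(:)), 1, 1, []);

% Turn each distinct row back into an nRows x nCols grid
mergedStates = reshape(C.', nRows, nCols, []);
```

At the target scale (about 10^6 states of 625 cells), a double `X` would need roughly 5 GB, so storing it in a smaller integer type such as `uint8` may be necessary.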