%Main code for parallelized subdivision

%% Initialization
tic;
% Symbolic placeholders for the defining functions (f, g, and an unused
% distance helper d) and the spatial coordinates, which are assumed real.
syms f_syms g_syms d_syms
syms x y z real

% %Example 1
% f_syms(x,y,z) = x.^2+y.^2-z;
% g_syms(x,y,z) = x.^2+y.^2+z.^2-1;
% B0 = freecurvebox([-ones(3,1),ones(3,1)]);
% MAXEPS = inf;

% %Example 2 - Same as Example 1, but with a limit imposed on box sizes by MAXEPS
% f_syms(x,y,z) = x.^2+y.^2-z;
% g_syms(x,y,z) = x.^2+y.^2+z.^2-1;
% B0 = freecurvebox([-ones(3,1),ones(3,1)]);
% MAXEPS = 0.05;

% %Example 3
% f_syms(x,y,z) = x.^4+2*x.^2*y.^2+y.^4 -2*(x.^2+y.^2) +1-z;
% g_syms(x,y,z) = 0.5-z;
% % B0 = freecurvebox([zeros(3,1),ones(3,1)]);
% B0 = freecurvebox([-ones(3,1),ones(3,1)]).scale(1.2);
% % B0 = freecurvebox([-1,-0.8;-1,-0.8;0.4,0.6]);
% % B0 = freecurvebox([-ones(2,1),ones(2,1);[0.1,2.1]]);
% MAXEPS = inf;

% % %Example 4
% f_syms(x,y,z) = x.^2+y.^2-z.^2 - 2;
% g_syms(x,y,z) = x.^2-y.^2+z.^2 - 1;
% % B0 = freecurvebox([zeros(3,1),ones(3,1)]);
% B0 = freecurvebox([-ones(3,1),ones(3,1)]).scale(3);
% % B0 = freecurvebox([-1,-0.8;-1,-0.8;0.4,0.6]);
% % B0 = freecurvebox([-ones(2,1),ones(2,1);[0.1,2.1]]);
% MAXEPS = inf;

%Example 5: trisector
% Three lines in space: line k passes through p(:,k) with direction
% q(:,k)-p(:,k). f vanishes where a point is equally far from lines 1 and 2,
% g vanishes where it is equally far from lines 1 and 3.
% Point-to-line distance formula:
% https://mathworld.wolfram.com/Point-LineDistance3-Dimensional.html
% TODO(review): the distances contain norm(), i.e. a square root, so box
% (interval) evaluation presumably needs an interval sqrt -- confirm.
p = [0, 0,1;...
     0, 0,1;...
     4,-4,0];
direction = [4, 4,4;...
             3,-3,1;...
             0, 0,4];
q = p+direction;
% Earlier, more compact formulations:
% d_syms(p,q,x,y,z) = norm(cross(q-p,p-[x;y;z]))/norm(q-p);
% f_syms(x,y,z) = norm(cross(q(:,1)-p(:,1),p(:,1)-[x;y;z]))/norm(q(:,1)-p(:,1))-norm(cross(q(:,2)-p(:,2),p(:,2)-[x;y;z]))/norm(q(:,2)-p(:,2));
% g_syms(x,y,z) = norm(cross(q(:,1)-p(:,1),p(:,1)-[x;y;z]))/norm(q(:,1)-p(:,1))-norm(cross(q(:,3)-p(:,3),p(:,3)-[x;y;z]))/norm(q(:,3)-p(:,3));

% Per-line base points and direction vectors.
qp = q - p;
qp1 = qp(:,1);  p1 = p(:,1);
qp2 = qp(:,2);  p2 = p(:,2);
qp3 = qp(:,3);  p3 = p(:,3);

% Distance from the symbolic point (x,y,z) to the line through pk with direction dk.
pt = [x;y;z];
linedist = @(pk,dk) norm(cross(dk,pk-pt))/norm(dk);
f_syms(x,y,z) = linedist(p1,qp1) - linedist(p2,qp2);
g_syms(x,y,z) = linedist(p1,qp1) - linedist(p3,qp3);

% Initial box [-10,10]^3; no upper bound on accepted box size.
B0 = freecurvebox(repmat([-10, 10], 3, 1));
MAXEPS = inf;


%##################################################
%00000000000000000000000000000000000000000000000000
%##################################################


% Symbolic gradients of f and g with respect to (x,y,z).
df_syms = gradient(f_syms,[x,y,z]);
dg_syms = gradient(g_syms,[x,y,z]);
% Convert f, g and both gradients to numeric function handles (same order).
symhandles = cellfun(@matlabFunction, {f_syms,g_syms,df_syms,dg_syms}, 'UniformOutput', false);
[f_mat,g_mat,df_mat,dg_mat] = symhandles{:};
% Wrap the handles with funmanipulation.boxfunction (presumably so they can be
% evaluated on boxes) -- these f, g, df, dg are what the predicates below use.
[f,g,df,dg] = funmanipulation.boxfunction(f_mat,g_mat,df_mat,dg_mat);

[fig,ax] = createfigure(B0);
% Draw the zero sets of f (blue) and g (yellow) semi-transparently over the
% region of B0 scaled by 4, then outline B0 itself.
plotspace = reshape(B0.scale(4).boxdimensions',1,[]);
alphavalue = 0.2;
surfaceopts = {'EdgeColor','none','FaceAlpha',alphavalue};
hf = fimplicit3(ax,f,plotspace,surfaceopts{:},'FaceColor','b');
hg = fimplicit3(ax,g,plotspace,surfaceopts{:},'FaceColor','y');
B0.plotbox(ax);
drawnow;

%% for trisector example
% Draw the three input lines (magenta): each segment runs from
% p + 100*(q-p) to q + 100*(p-q), one column per line.
M = cat(3,p+100*(q-p),q+100*(p-q));
linex = squeeze(M(1,:,:))';
liney = squeeze(M(2,:,:))';
linez = squeeze(M(3,:,:))';
hline1 = plot3(ax,linex,liney,linez,'m');
%% end trisector part

% Location of the 'parallel_local_tracecurve' file on disk (not used elsewhere in this script).
filename = which('parallel_local_tracecurve');
[filepath,name,ext] = fileparts(filename);

%% Settings
% Slots of the per-box test results:
%   1: C0(f)     2: C0(g)    3: C1(f)       4: C1(g)
%   5: Jacobian  6: MK       7: MK_face     8: C0_face
% Slots 1-4 are marked as inherited test results.
B0.testresults = cell(1,8);
B0.inherittestindices = 1:4;

% Limits: maximum subdivision depth in phase 1, maximum MK iterations in phase 2.
[depthlimit, numiterMKlimit] = deal(8, 6);

tStart = tic; 

%% Subdivision Phase 1
% Breadth-first subdivision of B0, one depth level per while-iteration.
% Each box in the current level Q is
%   - discarded if either C0 test (on f or g) returns true (C0 excludes it),
%   - accepted into QJac if radius < MAXEPS and the C1 tests on df, dg and
%     the Jacobian test all pass,
%   - otherwise split, with its children forming the next level.
% The trailing integer passed to each predicate matches the testresults slot
% numbering listed in the Settings section (1:C0(f), 2:C0(g), 3:C1(f), 4:C1(g),
% 5:Jacobian).
% NOTE(review): the boxes are processed on parfor workers. If freecurvebox is a
% handle class, testresults/split changes there act on worker copies -- confirm
% that leaves(B0) later sees the subdivision produced here.
Q = B0; %Input of the first phase of subdivision
QJac = []; %Output of first phase of subdivision

% Depth for phase 1
depth = 0;

% Create a subdivision of boxes which all satisfy the predicates up to the C1 tests
% and the Jacobian test (boxes excluded by C0 are dropped at each level)
while ~isempty(Q) && depth <= depthlimit
    % One cell per box of this level; empty cells mean "nothing produced".
    Q_next = cell(length(Q),1);
    QJac_add = cell(length(Q),1);
    
    disp(['Phase 1: depth = ', num2str(depth), ' | length(Q) = ', num2str(length(Q))]);  

    %Parallel Subdivision
    parfor i = 1:length(Q)
        B_par = Q(i);
        if ~local_predicate.C0(B_par,f,1) && ~local_predicate.C0(B_par,g,2)
            if B_par.radius<MAXEPS && local_predicate.C1(B_par,df,3) && local_predicate.C1(B_par,dg,4) && ... %&& local_predicate.C1cross(B_par,df,dg,5)
               local_predicate.Jaccobian(B_par,df,dg,5)
                QJac_add{i} = B_par; 
            else
                % Undecided: refine and retry the children at the next depth.
                children = B_par.split;
                Q_next{i} = children;
            end
        end
    end
    
    % Only used for the progress message below.
    accepted = [QJac_add{:}];

    disp(['# accepted boxes = ', num2str(length(accepted))]);

    %Collection of results (concatenation skips the empty cells)
    Q = [Q_next{:}];
    QJac = [QJac, QJac_add{:}];

    depth = depth+1;
end

% The loop only exits with a non-empty Q when the depth limit was exceeded.
if ~isempty(Q) && depth == depthlimit + 1
    disp("Phase 1 stopped due to depth limit");
end

disp(['Time for Phase 1: ',num2str(toc(tStart)),'s']);

disp("Phase 1 finalized.");
disp("Proceeding to phase 2...");

tPhase2 = tic;

%% Subdivision Phase 2
% Starting from the boxes accepted in phase 1, each box is
%   - discarded if either C0 test (on f or g) returns true,
%   - kept in Qcurve if the Jacobian and MK_face tests pass and the MK_face
%     result stored in testresults slot 7 contains a true/non-zero entry,
%   - silently dropped if both tests pass but slot 7 has no true entry,
%   - otherwise split, with the children examined in the next iteration.
% Iterates at most numiterMKlimit times.
QMK = QJac; %Input of the second phase of subdivision
Qcurve = []; %Output of second phase of subdivision
%%
%Depth for phase 2
numiterMK = 1;

while ~isempty(QMK) && numiterMK <= numiterMKlimit
    %Iteration cell arrays
    QMK_next = cell(length(QMK),1);
    Qcurve_add = cell(length(QMK),1);
    
    disp(['Phase 2: numiterMK = ', num2str(numiterMK), ' | length(QMK) = ', num2str(length(QMK))]); 

    %Parallel Subdivision
    parfor i = 1:length(QMK)
        B_par = QMK(i);
        if ~local_predicate.C0(B_par,f,1) && ~local_predicate.C0(B_par,g,2)
            if local_predicate.Jaccobian(B_par,df,dg,5) && local_predicate.MK_face(B_par,f,df,g,dg,7) %MK_face result is stored in slot 7 and inspected below
                if any(B_par.testresults{7})
                    Qcurve_add{i} = B_par;
                end %Else: MK test has succeeded not in finding a root, but in excluding all in internal call to C0_faces
            else
                children = B_par.split;
                QMK_next{i} = children;
            end
        end
    end
    
    % Only used for the progress message below.
    accepted = [Qcurve_add{:}];

    disp(['# accepted boxes = ', num2str(length(accepted))]);

    %Collection of results (concatenation skips the empty cells)
    QMK = [QMK_next{:}];
    Qcurve = [Qcurve, Qcurve_add{:}];

    numiterMK = numiterMK+1;
end

% The loop only exits with a non-empty QMK when the iteration limit was exceeded.
if ~isempty(QMK) && numiterMK == numiterMKlimit + 1
    disp("Phase 2 stopped due to depth limit");
end

disp(['Time for Phase 2: ',num2str(toc(tPhase2)),'s']);

disp("Phase 2 finalized.");

%% Outputs
disp(['Total time for subvision: ',num2str(toc(tStart)),'s']);

% Report the size of each box group and outline its boxes, in this order:
%   Q      (red)    - still undecided when phase 1 stopped
%   QMK    (yellow) - passed phase 1 but still undecided when phase 2 stopped
%   Qcurve (green)  - accepted by the MK face test in phase 2
boxgroups = {Q,      'r', 'Number of boxes in Q: '; ...
             QMK,    'y', 'Number of boxes in QMK: '; ...
             Qcurve, 'g', 'Number of accepted boxes: '};
for grp = 1:size(boxgroups,1)
    [groupboxes, groupcolor, grouplabel] = boxgroups{grp,:};
    disp([grouplabel, num2str(length(groupboxes))]);
    for k = 1:length(groupboxes)
        groupboxes(k).plotbox(ax,groupcolor);
    end
end

%% ####################





%% Curve Construction
tic;

% Enumerate the leaves of the final subdivision:
%   - store each leaf's index as its boxid (the node ID used in the graphs below),
%   - cache its center,
%   - flag it as near the curve when any(testresults{7}) holds (MK_face slot).
% NOTE: this is deliberately a plain for-loop, not parfor. boxid is read back
% below through leavessub(i).neighbors. If the boxes are handle objects,
% property assignments made on parfor workers only modify worker-side copies
% and never reach those handles. The loop body is also too cheap to benefit
% from parallelism.
leavessub = leaves(B0);
n = length(leavessub);
nodesnearcurvelogical = false(n,1);
centers = zeros(3,n);
for i = 1:n
    leavessub(i).boxid = i;
    centers(:,i) = leavessub(i).center;
    %all boxes in the final subdivision which pass the MK test also pass Jacobian
    nodesnearcurvelogical(i) = any(leavessub(i).testresults{7});
end
nodes = 1:n;
nodesnearcurve = nodes(nodesnearcurvelogical); %list of nodes according to initial IDs

% Leaf adjacency from each leaf's neighbors (symmetrised afterwards).
% Index pairs are collected first and the sparse matrix is assembled in a
% single sparse() call. Growing a sparse matrix one entry at a time is slow.
rowids = cell(1,n);
colids = cell(1,n);
for i = 1:n
    neighbors = leavessub(i).neighbors;
    if isempty(neighbors)
        continue
    end
    ids = [neighbors.boxid];
    rowids{i} = repmat(i,1,numel(ids));
    colids{i} = ids;
end
rowids = [rowids{:}];
colids = [colids{:}];
% Duplicate pairs are summed by sparse(); "~= 0" turns them back into true.
Afull = sparse(rowids,colids,ones(size(rowids)),n,n) ~= 0;

% graph() needs a symmetric adjacency matrix, so OR it with its transpose.
Afull = Afull | Afull';
Gfull = graph(Afull);

% Keep only the leaves flagged as near the curve and split them into
% connected components; each component is traced separately below.
Gnearcurve = subgraph(Gfull,nodesnearcurve);

bins = conncomp(Gnearcurve);
numcomps = numel(unique(bins));
curvepieces = cell(1,numcomps);

% Build a single string so the count prints on one line. A ["text", num2str(...)]
% string array would instead display as two separate quoted elements.
disp("Numcomps: " + numcomps)

% For every connected component of near-curve leaves: orient its edges into a
% digraph, plot it, and extract a cycle (or, failing that, a path) through it.
for comp = 1:numcomps
    nodescomponentlogical = bins == comp;
    nodescomponent = nodesnearcurve(nodescomponentlogical);
    Gcomponent = subgraph(Gnearcurve,nodescomponentlogical);

    ndirected = length(nodescomponent);

    disp("Comp " + comp + ":")
    disp("Component size: " + ndirected)

    if ndirected == 1
        disp("Skipped comp of size 1.")
        disp("-----ooooo-----")
        continue
    end

    % Orient each undirected edge B1-B2:
    %   - find the axis along which the two boxes touch,
    %   - draw B1->B2 if both boxes' slot-5 (Jacobian) results projected on that
    %     axis are all >= 0, B2->B1 if they are all <= 0,
    %   - otherwise leave the edge unoriented.
    Adirected = logical(sparse(ndirected,ndirected));
    % EndNodes are local node indices into nodescomponent. Reading that column
    % directly also works if the Edges table carries extra columns.
    edges = Gcomponent.Edges.EndNodes;
    for i = 1:size(edges,1)
        edge = edges(i,:);
        B1 = leavessub(nodescomponent(edge(1)));
        B2 = leavessub(nodescomponent(edge(2)));
        center1 = B1.center;
        center2 = B2.center;
        % Sum of the two half side lengths. Assumes cubic boxes whose radius is
        % the half-diagonal (radius = sqrt(3)*half-width) -- TODO confirm.
        halfwidthsum = (B1.radius+B2.radius)/sqrt(3);
        % Tolerance scaled to the coordinate magnitude (never tighter than 10*eps).
        tol = 10*eps(max([1; abs(center1(:)); abs(center2(:))]));
        % The boxes touch along axis k when |center2(k)-center1(k)| equals the
        % sum of half-widths. Edge endpoints are ordered by node index, not by
        % position, so B2 may lie on either side of B1 -- hence the abs().
        axisk = find(abs(abs(center2-center1) - halfwidthsum) < tol);
        if numel(axisk) ~= 1
            warning('parallel_local_tracecurve:edgeaxis', ...
                'No unique contact axis between leaves %d and %d; edge left unoriented.', ...
                nodescomponent(edge(1)), nodescomponent(edge(2)));
            continue
        end
        % Unit vector along the contact axis pointing from B2 towards B1:
        % -e_k when B1 has the smaller coordinate along k, +e_k otherwise.
        direction = zeros(size(center1));
        direction(axisk) = (-1)^(center1(axisk)<center2(axisk));
        if all((direction')*B1.testresults{5} >= 0) && all((direction')*B2.testresults{5} >=0)
            Adirected(edge(1),edge(2)) = true;
        elseif all((direction')*B1.testresults{5} <= 0) && all((direction')*B2.testresults{5} <=0)
            Adirected(edge(2),edge(1)) = true;
        end
    end
    Gdirected = digraph(Adirected);
    hGdirected = plot(ax,Gdirected,'XData',centers(1,nodescomponent),'YData',centers(2,nodescomponent),'ZData',centers(3,nodescomponent));
    hGdirected.ArrowSize = 5; %7.5;
    hGdirected.LineWidth = 0.5; %0.75;

    %%Find cycle or path
    %Use BFS to create a tree with shortest paths
    %Then run DFS to check for paths that have all boxes nearby (neighbors of neighbors)
    %Repeat for all neighbors of the starting vertex and select the shortest path
    % Stored as curvepath (not path) so the built-in path() is not shadowed.
    [curvepath,foundcycle] = findcycle(Gdirected);
    if ~foundcycle
        disp("Cycle not found!")
        curvepath = findpath(Gdirected);
    else
        disp("Cycle found.")
    end
    disp("Path length: " + length(curvepath))
    disp("-----ooooo-----")
    curvepieces{comp} = curvepath;
    
    %Plot the curve through the centers of the boxes along the path
    Coordinates = centers(:,nodescomponent(curvepath));
    plot3(ax,Coordinates(1,:),Coordinates(2,:),Coordinates(3,:),'-k','Linewidth',3);
end
disp(['time for graph algorithm: ',num2str(toc),'s']);