classdef spherical_design < handle
    %SPHERICAL_DESIGN Sampling grid on a sphere together with its tessellation.
    %   A spherical_design stores a set of sampling points on a sphere in
    %   Cartesian (x0) and spherical (Phi0) coordinates, the Delaunay
    %   triangulation of those points, the circumcircle of every triangle and
    %   the Voronoi cells (including per-point cell areas dS). It offers
    %   interpolation over the sphere, "dummy" node handling and real
    %   spherical-harmonic transforms.
    %
    %   Construction (see the constructor help for details):
    %       sd = spherical_design('Direct', P, name)
    %       sd = spherical_design('Design', N0, 'T' | 'Gauss', R)
    %       sd = spherical_design('File', fname, R)
    %
    %   Depends on the project helpers sphere_delaunay, circlefit3d,
    %   get_voronoi_polygons, get_spherical_polygon_area, get_rotation_mx,
    %   getSpherHarmMx and interpolate2quadrature being on the MATLAB path.

    properties
        name    % Design name: name of HRTF-set, or the design type
        N0      % number of sampling points (rows of x0)
        x0      % sampling positions, Cartesian, N0-by-3
        Phi0    % sampling positions, spherical [azimuth, elevation, radius]; angles in rad
        n_out   % outward unit normals, N0-by-3
        polar_gap       % NOTE(review): not assigned anywhere in this class
        dummy_ixs       % row indices of x0 treated as dummy nodes (see set_dummy_ixs)
        dummy_weights   % cell; {m} = [vertex indices into the non-dummy nodes, weights]

        delauney_trias      % 3-by-M Delaunay triangles; entries index rows of x0
        delauney_circles    % struct: center (M-by-3 unit vectors), fi (M-by-1 angles)
        delauney_vertices   % cell; {n}{t} = 3-by-2 edge vectors of the t-th triangle at point n
        delauney_matrices   % cell; {n} = stacked 2-by-3 pseudo-inverses of tangential edges
        voronoi_cells       % struct: vertices, connectivity, dS (cell areas), n_in

        SHT_mx      % cached order-3 pseudo-inverse SH matrix (SHT, mode 'pinv')
        ISHT_mx     % cached order-3 SH synthesis matrix (ISHT / get_SH_mx)
    end

    methods
        function obj = spherical_design(varargin)
            %SPHERICAL_DESIGN Build a sampling grid and its tessellation.
            %   sd = spherical_design('Direct', P, name)
            %       P is an N-by-3 matrix. When std(abs(P(:,3))) < 0.1 the rows
            %       are read as [azimuth, elevation, radius]: the radius is
            %       replaced by its mean, and angles are converted from degrees
            %       to radians when any |azimuth| exceeds 2*pi. If every
            %       elevation is ~0 (a 2.5D ring) both poles are appended.
            %       Otherwise the rows are Cartesian [x y z]. name is optional.
            %   sd = spherical_design('Design', N0, type, R)
            %       type 'T'    : reads the file in
            %                     Data/Spherical_designs/t_design_points whose
            %                     point count (second '.'-separated token of
            %                     the file name) is closest to N0; radius R.
            %       type 'Gauss': N0 equally spaced azimuths times
            %                     2*round(N0/4)+1 elevations; radius R.
            %   sd = spherical_design('File', fname, R)
            %       reads x y z triplets from Data/Spherical_designs/user/fname,
            %       scales them by R and rotates the set with
            %       get_rotation_mx(point 13, [1 0 0]).
            %
            %   N0 is always set to the number of points actually created.
            def_mode = varargin{1};
            switch def_mode
                case 'Direct'
                    P = varargin{2};
                    if std(abs(P(:,3))) < 1e-1
                        % Spherical input with (nearly) constant radius.
                        R0 = mean(P(:,3));
                        P(:,3) = R0;
                        is_ring = all(abs(P(:,2)) < 1e-6);
                        % Convert to radians BEFORE appending the poles so the
                        % poles are always expressed in the grid's own unit.
                        if max(abs(P(:,1))) > 2*pi
                            P(:,1:2) = P(:,1:2)*pi/180;
                        end
                        % 2.5D to 3D design: close the horizontal ring with both poles.
                        if is_ring
                            P = [P; 0, -pi/2, R0; 0, pi/2, R0];
                        end
                        obj.Phi0 = P;
                        [x0, y0, z0] = sph2cart(P(:,1), P(:,2), P(:,3));
                        X = [x0 y0 z0];
                        obj.x0 = X*R0^2./sum(X.^2, 2);
                    else
                        obj.x0 = P;
                        [azim, elev, R] = cart2sph(P(:,1), P(:,2), P(:,3));
                        obj.Phi0 = [azim, elev, R];
                    end
                    if numel(varargin) >= 3
                        obj.name = varargin{3};
                    end

                case 'Design'
                    N_req = varargin{2};
                    design_type = varargin{3};
                    R = varargin{4};
                    obj.name = [design_type, '_', mat2str(N_req), '_R', mat2str(R)];
                    switch design_type
                        case 'T'
                            obj.x0 = R*spherical_design.read_closest_t_design(N_req);
                            [azim0, elev0, r0] = cart2sph(obj.x0(:,1), obj.x0(:,2), obj.x0(:,3));
                            obj.Phi0 = [azim0, elev0, r0];

                        case 'Gauss'
                            azim = linspace(0, 2*pi - 2*pi/N_req, N_req);
                            Ne = round(N_req/4);
                            elev = (-Ne:Ne)/Ne*(pi/2 - 1/Ne);
                            [Azim, Elev] = meshgrid(azim, elev);
                            obj.Phi0 = [Azim(:), Elev(:), R*ones(numel(Azim), 1)];
                            [x0, y0, z0] = sph2cart(obj.Phi0(:,1), obj.Phi0(:,2), obj.Phi0(:,3));
                            X = [x0 y0 z0];
                            obj.x0 = X*mean(obj.Phi0(:,3))^2./sum(X.^2, 2);

                        otherwise
                            error('spherical_design:unknownDesign', ...
                                'Unknown design type "%s" (expected ''T'' or ''Gauss'').', design_type);
                    end

                case 'File'
                    fname = varargin{2};
                    R = varargin{3};
                    X = R*spherical_design.read_point_file(fullfile('Data/Spherical_designs/user/', fname));
                    % NOTE(review): point 13 is used as the reference ("main")
                    % direction of the rotation; the index is hard-coded.
                    if size(X, 1) < 13
                        error('spherical_design:tooFewPoints', ...
                            'Design file "%s" must contain at least 13 points.', fname);
                    end
                    Rot = get_rotation_mx(X(13,:), [1 0 0]);
                    obj.x0 = X*Rot';
                    [azim0, elev0, r0] = cart2sph(obj.x0(:,1), obj.x0(:,2), obj.x0(:,3));
                    obj.Phi0 = [azim0, elev0, r0];
                    obj.name = fname;

                otherwise
                    error('spherical_design:unknownMode', ...
                        'Unknown definition mode "%s" (expected ''Direct'', ''Design'' or ''File'').', def_mode);
            end
            obj.N0 = size(obj.x0, 1);

            % Tessellation: normals, Delaunay triangulation, circumcircles,
            % per-point edge data and Voronoi cells.
            obj.n_out = obj.x0./sqrt(sum(obj.x0.^2, 2));
            [~, tria_faces] = sphere_delaunay(size(obj.x0, 1), obj.x0'/mean(obj.Phi0(:,3)));
            obj.delauney_trias = tria_faces;
            obj.get_delauney_circles;
            obj.get_delauney_vertices;

            [v_vor, K, n_in, dS] = get_voronoi_polygons(obj.delauney_trias, obj.x0);
            % Build a scalar struct first: struct() with a cell value would
            % create a struct array.
            obj.voronoi_cells = struct('vertices', [], 'connectivity', [], 'dS', [], 'n_in', []);
            obj.voronoi_cells(1).vertices = v_vor;
            obj.voronoi_cells(1).connectivity = K';
            obj.voronoi_cells(1).dS = dS';
            obj.voronoi_cells(1).n_in = n_in';

            obj.check_polar_gap;
        end

        function obj = check_polar_gap(obj)
            %CHECK_POLAR_GAP Correct Voronoi areas around an empty south-pole cap.
            %   Runs only when no sampling point lies within 1e-2 of the south
            %   pole (0, 0, -R0), R0 being the mean radius. The points bounding
            %   the south cap are the convex hull of the stereographic
            %   projection (from the south pole) of the unit-normalised grid.
            %   The spherical area of that hull polygon is the gap area. If it
            %   exceeds 10% of the sphere's area, it is subtracted from the
            %   boundary points' Voronoi areas, distributed in proportion to
            %   their current areas.
            R0 = mean(sqrt(sum(obj.x0.^2, 2)));
            polar_dist = min(sqrt(sum((obj.x0 - [0, 0, -1]*R0).^2, 2)));
            if ~(polar_dist > 1e-2)
                return
            end
            U = obj.x0./sqrt(sum(obj.x0.^2, 2));
            px = U(:,1)./(1 + U(:,3));
            py = U(:,2)./(1 + U(:,3));
            K = {convhull([px py]).'};
            gap_area = get_spherical_polygon_area(K, obj.x0');
            if gap_area/(4*pi*R0^2) > 0.1
                % convhull returns a closed loop (first index repeated at the
                % end); count every boundary cell exactly once.
                ixb = unique(K{1});
                dSb = obj.voronoi_cells.dS(ixb);
                total = sum(dSb) - gap_area;
                obj.voronoi_cells.dS(ixb) = dSb/sum(dSb)*total;
            end
        end

        function get_delauney_circles(obj)
            %GET_DELAUNEY_CIRCLES Circumcircle of every Delaunay triangle.
            %   For each triangle circlefit3d fits the circle through its three
            %   vertices (assumed to return centre and radius as the first two
            %   outputs). The stored centre is normalised to unit length;
            %   fi = atan(r/|c|) is the half-angle of the cone from the origin
            %   to that circle.
            M = size(obj.delauney_trias, 2);
            xc = zeros(M, 3);
            r = zeros(M, 1);
            fi0 = zeros(M, 1);
            for n = 1:M
                xv = obj.x0(obj.delauney_trias(:, n), :);
                [xc(n,:), r(n), ~, ~] = circlefit3d(xv(1,:), xv(2,:), xv(3,:));
                fi0(n) = atan(r(n)/norm(xc(n,:)));
            end
            xc = xc./sqrt(sum(xc.^2, 2));
            obj.delauney_circles = struct('center', xc, 'fi', fi0);
        end

        function [ix_tria, x_intersection] = find_triangle2point(obj, xp)
            %FIND_TRIANGLE2POINT Delaunay triangle hit by a ray from the origin.
            %   xp is a 3-by-1 direction vector. ix_tria lists every triangle
            %   (column of delauney_trias) the ray hits; x_intersection is the
            %   3-by-1 hit point on the plane of the first one.
            %   Errors when no triangle is hit.
            [d, ~, ix_tria] = spherical_design.ray_triangle_hits(obj.x0, obj.delauney_trias, xp);
            if isempty(ix_tria)
                error('spherical_design:noTriangle', ...
                    'No Delaunay triangle is intersected by the given direction.');
            end
            x_intersection = xp/norm(xp)*d(ix_tria(1));
        end

        function [f_out, x_interp] = interpolate(obj, x_target, f_in, mode)
            %INTERPOLATE Evaluate sampled data in the direction x_target.
            %   f_in holds one row (first dimension) per sampling point.
            %   x_target is a 3-by-1 direction vector.
            %   mode 'linear' : combines the three vertices of the Delaunay
            %                   triangle hit by x_target, with weights w
            %                   solving vertices' * w = hit point.
            %   mode 'nearest': takes the sampling point closest to x_target
            %                   projected to the mean radius.
            %   x_interp is the 3-by-1 position the value belongs to (hit point
            %   or nearest sampling point). Both modes use a non-conjugating
            %   transpose, so complex data is returned unchanged.
            switch mode
                case 'linear'
                    [ixout, x_interp] = obj.find_triangle2point(x_target);
                    ix = obj.delauney_trias(:, ixout(1));
                    w = obj.x0(ix,:).'\x_interp;
                    f_out = squeeze(sum(f_in(ix,:,:).*w, 1)).';
                case 'nearest'
                    R0 = mean(obj.Phi0(:,3));
                    [~, ix] = min(sum((obj.x0 - x_target.'/norm(x_target)*R0).^2, 2));
                    f_out = squeeze(f_in(ix,:,:)).';
                    x_interp = obj.x0(ix,:).';
                otherwise
                    error('spherical_design:unknownInterpMode', ...
                        'Unknown interpolation mode "%s" (expected ''linear'' or ''nearest'').', mode);
            end
        end

        function [data_out] = interpolate_dummy_nodes(obj, data_in)
            %INTERPOLATE_DUMMY_NODES Fill dummy-node columns from real nodes.
            %   data_in has one column per non-dummy node (in ascending node
            %   order) and time samples along dimension 1. data_out has one
            %   column per sampling point. Each dummy column is synthesised in
            %   the frequency domain: magnitudes and unwrapped phases of the
            %   surrounding nodes are combined with dummy_weights, then
            %   transformed back with a symmetric (real) inverse FFT.
            N = size(obj.x0, 1);
            data_out = zeros(size(data_in, 1), N);
            ix0 = setdiff(1:N, obj.dummy_ixs);
            data_out(:, ix0) = data_in;

            F = fft(data_in, [], 1);
            for n = 1:numel(obj.dummy_ixs)
                ixv = obj.dummy_weights{n}(:,1);   % indices into the non-dummy columns
                wv = obj.dummy_weights{n}(:,2);
                Fn = F(:, ixv);
                amp_out = abs(Fn)*wv;
                % unwrap(P, tol, dim): unwrap along the frequency axis.
                phase_out = unwrap(angle(Fn), [], 1)*wv;
                data_out(:, obj.dummy_ixs(n)) = ifft(amp_out.*exp(1i*phase_out), 'symmetric');
            end
        end

        function obj = get_delauney_vertices(obj)
            %GET_DELAUNEY_VERTICES Per-point edge data of adjacent triangles.
            %   For every sampling point n and every triangle containing it,
            %   stores the two edge vectors leaving n (3-by-2) in
            %   delauney_vertices{n}. delauney_matrices{n} stacks the 2-by-3
            %   pseudo-inverses of those edges after projecting them onto the
            %   tangent plane at n (consumed by get_delauney_dx).
            N = size(obj.x0, 1);
            obj.delauney_vertices = cell(1, N);
            obj.delauney_matrices = cell(1, N);
            for n = 1:N
                % tri_ix: triangles containing n; pos: n's position (1..3) in them.
                [tri_ix, pos] = find(obj.delauney_trias' == n);
                p0 = obj.x0(n,:)';
                nout = p0/norm(p0);
                nt = numel(tri_ix);
                V = cell(1, nt);
                Vmx = zeros(2*nt, 3);
                for ti = 1:nt
                    others = setdiff(1:3, pos(ti));
                    tri = obj.delauney_trias(:, tri_ix(ti));
                    E = obj.x0(tri(others), :)' - p0;   % columns: p1 - p0, p2 - p0
                    V{ti} = E;
                    % Remove the normal component before pseudo-inverting.
                    Vmx(2*ti-1:2*ti, :) = pinv(E - (E'*nout)'.*nout);
                end
                obj.delauney_vertices{n} = V;
                obj.delauney_matrices{n} = Vmx;
            end
        end

        function [dx] = get_delauney_dx(obj, kin)
            %GET_DELAUNEY_DX Edge length along a direction for every point.
            %   kin has one row (3 components) per sampling point. Row n is
            %   decomposed in the tangential edge bases of point n's triangles;
            %   the first triangle giving non-negative coefficients is used,
            %   and dx(n) is the length of the edge-combination with
            %   L1-normalised coefficients. NaN results are set to 0.
            %   Errors when no triangle yields a non-negative decomposition.
            dx = zeros(size(kin, 1), 1);
            for n = 1:size(kin, 1)
                w = reshape(obj.delauney_matrices{n}*kin(n,:)', 2, []);
                ix_sol = find(w(1,:) >= -1e-10 & w(2,:) >= -1e-10, 1);
                if isempty(ix_sol)
                    error('spherical_design:noDecomposition', ...
                        'No adjacent triangle of point %d contains the given direction.', n);
                end
                ws = w(:, ix_sol);
                dx(n) = norm(obj.delauney_vertices{n}{ix_sol}*ws/norm(ws, 1));
                if isnan(dx(n))
                    dx(n) = 0;
                end
            end
        end

        function [x0, dS, n_in] = append_point(obj, xp)
            %APPEND_POINT Preview the grid with one extra point in direction xp.
            %   Does not modify obj. Returns the extended positions x0 (new
            %   point at the mean radius), the areas dS as a column (the new
            %   point takes a weighted share of the areas of the vertices of
            %   the Delaunay triangle it falls into), and inward normals n_in.
            dS = obj.voronoi_cells.dS(:);
            [ixout, x_extrap] = obj.find_triangle2point(xp);
            ix_pos = obj.delauney_trias(:, ixout(1));
            weights = round(1000*(obj.x0(ix_pos,:)'\x_extrap))/1e3;
            dS0 = weights'*dS(ix_pos);
            dS(ix_pos) = (1 - weights).*dS(ix_pos);
            dS(end+1) = dS0;
            R0 = mean(sqrt(sum(obj.x0.^2, 2)));
            x0 = [obj.x0; xp'/norm(xp)*R0];
            n_in = [-obj.n_out; -xp'/norm(xp)];
        end

        function [SHT_out] = SHT(obj, mode, in, Nmax)
            %SHT Real spherical-harmonic transform of sampled data.
            %   mode 'quadrature': in is FFT'd along dimension 3, resampled to
            %       quadrature points by interpolate2quadrature and integrated
            %       with the quadrature weights against order-Nmax real SHs.
            %   mode 'pinv': in has one column per sampling point; the result
            %       is in * pinv(Y)' with Y the order-Nmax real SH matrix.
            %   Nmax is optional and defaults to 3. The order-3 'pinv' matrix is
            %   cached in SHT_mx; other orders are computed on demand.
            if nargin < 4 || isempty(Nmax)
                Nmax = 3;
            end
            switch mode
                case 'quadrature'
                    spec = fft(in, [], 3);
                    [quadrature_points, output] = interpolate2quadrature(spec, obj.delauney_trias, obj.x0);
                    qp = quadrature_points.quad_pos;
                    [azim, elev] = cart2sph(qp(:,1), qp(:,2), qp(:,3));
                    zenith = pi/2 - elev;
                    S = 4*pi*mean(obj.Phi0(:,3))^2;
                    Y = getSpherHarmMx(zenith, azim, Nmax, 'real');
                    Wq = (4*pi*quadrature_points.quad_w.*Y./S)';   % loop-invariant
                    SHT_out = zeros(size(Y, 2), size(in, 2), size(in, 3));
                    for m = 1:size(in, 3)
                        SHT_out(:,:,m) = Wq*squeeze(output(:,:,m));
                    end
                case 'pinv'
                    if Nmax == 3
                        if isempty(obj.SHT_mx)
                            obj.SHT_mx = pinv(getSpherHarmMx(pi/2 - obj.Phi0(:,2), obj.Phi0(:,1), 3, 'real'));
                        end
                        Pmx = obj.SHT_mx;
                    else
                        Pmx = pinv(getSpherHarmMx(pi/2 - obj.Phi0(:,2), obj.Phi0(:,1), Nmax, 'real'));
                    end
                    SHT_out = in*Pmx';
                otherwise
                    error('spherical_design:unknownSHTMode', ...
                        'Unknown SHT mode "%s" (expected ''quadrature'' or ''pinv'').', mode);
            end
        end

        function yout = ISHT(obj, in)
            %ISHT Inverse order-3 real SH transform: yout = in * Y'.
            yout = in*obj.get_SH_mx()';
        end

        function Ymx = get_SH_mx(obj)
            %GET_SH_MX Order-3 real SH matrix at the sampling points (cached).
            if isempty(obj.ISHT_mx)
                obj.ISHT_mx = getSpherHarmMx(pi/2 - obj.Phi0(:,2), obj.Phi0(:,1), 3, 'real');
            end
            Ymx = obj.ISHT_mx;
        end

        function obj = set_dummy_ixs(obj, ixs)
            %SET_DUMMY_IXS Declare dummy nodes and precompute their weights.
            %   The non-dummy nodes are re-triangulated. For every dummy node
            %   the triangle hit by its direction is found, and the weights
            %   solving vertices' * w = hit point are stored in
            %   dummy_weights{m} = [vertex indices, weights]. Vertex indices
            %   refer to the non-dummy nodes in ascending order, matching the
            %   columns of data_in in interpolate_dummy_nodes.
            obj.dummy_ixs = ixs;
            ix0 = setdiff(1:size(obj.x0, 1), ixs);
            Xr = obj.x0(ix0, :);   % triangle indices below refer to rows of Xr
            [~, tria_faces] = sphere_delaunay(size(Xr, 1), Xr'/mean(obj.Phi0(:,3)));
            obj.dummy_weights = cell(1, numel(ixs));
            for m = 1:numel(ixs)
                xp = obj.x0(ixs(m), :)';
                [d, ~, hits] = spherical_design.ray_triangle_hits(Xr, tria_faces, xp);
                if isempty(hits)
                    error('spherical_design:noTriangle', ...
                        'Dummy node %d does not project onto any triangle of the remaining nodes.', ixs(m));
                end
                ix_tria = hits(1);
                x_intersection = xp/norm(xp)*d(ix_tria);
                ix_vert = tria_faces(:, ix_tria);
                weights = Xr(ix_vert, :)'\x_intersection;
                obj.dummy_weights{m} = [ix_vert, weights];
            end
        end


        % Plotting functions
        function scatter_sphere(obj)
            %SCATTER_SPHERE Plot the sampling points as black dots.
            scatter3(obj.x0(:,1), obj.x0(:,2), obj.x0(:,3), 10, 'k', 'filled');
            axis equal tight
            view(3)
        end

        function hp = plot_voronoi(varargin)
            %PLOT_VORONOI Draw the Voronoi cells as grey patches.
            %   hp = plot_voronoi(obj) or hp = plot_voronoi(obj, face_alpha).
            %   face_alpha defaults to 1. Returns one patch handle per cell.
            obj = varargin{1};
            face_alpha = 1;
            if numel(varargin) >= 2
                face_alpha = varargin{2};
            end
            K = obj.voronoi_cells.connectivity;
            vertices = obj.voronoi_cells.vertices;
            hp = gobjects(1, size(K, 1));
            for n = 1:size(K, 1)
                hp(n) = patch(vertices(1,K{n}), vertices(2,K{n}), vertices(3,K{n}), ...
                    [0.8, 0.8, 0.8], 'FaceAlpha', face_alpha);
                hold on
            end
            axis equal tight
            view(3)
            hold on
%            scatter3(obj.x0(:,1),obj.x0(:,2),obj.x0(:,3),10,'filled','black');
            % a = [1:size(obj.x0,1)]'; b = num2str(a); c = cellstr(b);
            % dx = 0.005;
            % text_pos = obj.x0 + dx*obj.n_out;
            % text(text_pos(:,1), text_pos(:,2),text_pos(:,3),c,'FontSize',6);
            xlabel('x', 'FontSize', 8)
            ylabel('y', 'FontSize', 8)
            zlabel('z', 'FontSize', 8)
            ax = gca;
            ax.FontSize = 8;
        end

        function s = plot_data(obj, in)
            %PLOT_DATA Surface plot of one value per sampling point on the triangulation.
            s = trisurf(obj.delauney_trias.', obj.x0(:,1), obj.x0(:,2), obj.x0(:,3), in);
            axis equal tight
            grid off
        end
    end

    methods (Static, Access = private)
        function X = read_closest_t_design(N_req)
            % Read the t-design file whose point count is closest to N_req.
            %   The point count of a file is the second '.'-separated token of
            %   its name. Sub-folders ('.', '..') and hidden files are skipped
            %   instead of assuming the first two listing entries are '.'/'..'.
            folder = 'Data/Spherical_designs/t_design_points';
            files = dir(folder);
            files = files(~[files.isdir]);
            files = files(~startsWith({files.name}, '.'));
            N_available = nan(1, numel(files));
            for k = 1:numel(files)
                tokens = split(files(k).name, '.');
                if numel(tokens) >= 2
                    N_available(k) = str2double(tokens{2});
                end
            end
            if all(isnan(N_available))
                error('spherical_design:noDesigns', 'No t-design files found in "%s".', folder);
            end
            [~, file_ix] = min(abs(N_available - N_req));   % min ignores NaN entries
            X = spherical_design.read_point_file(fullfile(files(file_ix).folder, files(file_ix).name));
        end

        function X = read_point_file(fname)
            % Read whitespace-separated x y z triplets into an N-by-3 matrix.
            fid = fopen(fname, 'r');
            if fid < 0
                error('spherical_design:fileNotFound', 'Cannot open point file "%s".', fname);
            end
            raw = fscanf(fid, '%f');
            fclose(fid);
            if mod(numel(raw), 3) ~= 0
                error('spherical_design:badPointFile', ...
                    'Point file "%s" does not contain x y z triplets.', fname);
            end
            X = reshape(raw, 3, []).';
        end

        function [d, w, hits] = ray_triangle_hits(X, trias, xp)
            % Intersect the ray from the origin along xp with every triangle plane.
            %   X: point rows; trias: 3-by-M indices into the rows of X;
            %   xp: 3-by-1 direction. d(n) is the distance along the unit ray
            %   to triangle n's plane; w(:,n) are the hit point's coordinates
            %   in the triangle's edge basis. hits lists the triangles whose
            %   hit lies inside the triangle and in front of the origin.
            u = xp/norm(xp);
            M = size(trias, 2);
            d = zeros(1, M);
            w = zeros(2, M);
            for n = 1:M
                p0 = X(trias(1,n), :)';
                v1 = X(trias(2,n), :)' - p0;
                v2 = X(trias(3,n), :)' - p0;
                nrm = cross(v1, v2);
                nrm = nrm./sqrt(sum(nrm.^2));
                d(n) = dot(p0, nrm)/dot(u, nrm);
                w(:, n) = [v1, v2]\(d(n)*u - p0);
            end
            % Round away floating-point noise so hits on shared edges/vertices count.
            w = round(1e8*w)/1e8;
            hits = find(w(1,:) >= 0 & w(2,:) >= 0 & sum(w, 1) <= 1 & d > 0);
        end
    end
end

