classdef interval < handle & matlab.mixin.Copyable
    %INTERVAL Interval arithmetic on closed and unbounded real intervals.
    %
    %   Every element stores its end points in the property "bounds" = [l,u]:
    %       l <= u : the closed interval [l,u]
    %       u <  l : the unbounded set [-inf,u] \cup [l,inf], i.e. the real
    %                line with the open gap (u,l) removed (imagine the two
    %                ends of the number line being connected at infinity)
    %       [inf,-inf] represents the empty interval.
    %
    %   Overloaded operators
    %       a+b, a-b, -a, +a
    %       a.*b
    %       a*b       matrix product (equal to a.*b if a or b is scalar)
    %       a./b, a.\b
    %       a/b       only for a scalar divisor b (then equal to a./b)
    %       a.^b
    %       a^b       only for scalar a and b (then equal to a.^b)
    %       a<b, a>b, a<=b, a>=b, a~=b, a==b
    %       a', a.', [a,b], [a;b], a(s1,...,sn), a(s1,...,sn) = b, b(a)
    %                 use MATLAB's built-in behaviour for object arrays
    %   Not overloaded: a\b, a&b, a|b, ~a, a:d:b, a:b
    %
    %   Functions
    %       sqrt, nthroot, abs, sign, inverse (= 1./a)
    %       norm(a)         L_2 norm of an interval vector
    %       norm_p(a,p)     L_p norm of an interval vector
    %       disp, int2str
    %       extractbounds   numeric array of the bounds (for matrices of intervals)
    %       sum, min, max
    %       cross(a,b)      cross product of two interval vectors with 3 elements
    %       det             determinant of a 1x1 or 2x2 interval matrix
    %       a.cup(b)        union of the intervals a and b (an enclosure of it)
    %       a.cap(b)        intersection of a and b - NOT yet implemented
    %       a.subset(b)     a \subset b (true if it holds for all elements)
    %       a.superset(b)   a \supset b
    %       a.element(x)    x \in a, evaluated elementwise
    %       scale(factor)   bounded [l,u]: multiplies the length by factor while
    %                       keeping the midpoint; unbounded: multiplies the
    %                       length of the uncovered gap (u,l) by factor while
    %                       keeping the midpoint of the gap
    %       linspace(n)     n uniformly spaced values in every interval
    properties
        bounds % [l,u]; l<=u: [l,u]; u<l: [-inf,u] \cup [l,inf]; [inf,-inf]: empty
    end

    methods
        function this = interval(lowerbounds,upperbounds)
            %INTERVAL Construct an interval or an array of intervals.
            %   interval()     object with empty bounds (used for preallocation)
            %   interval(x)    degenerate intervals [x(k),x(k)]
            %   interval(L,U)  intervals [L(k),U(k)]; L and U must have equal size
            if nargin > 0
                if nargin == 1
                    upperbounds = lowerbounds;
                end
                v = size(lowerbounds);
                % isequal also copes with arrays of different rank
                if isequal(v,size(upperbounds))
                    numv = prod(v);
                    if numv > 1
                        for ind = numv:-1:1
                            this(ind) = interval(lowerbounds(ind),upperbounds(ind));
                        end
                        this = reshape(this,v);
                    else
                        this.bounds = [lowerbounds,upperbounds];
                    end
                else
                    warning('wrong interval definition');
                end
            end
        end
        %Overloading operators with MATLAB:
        %https://ch.mathworks.com/help/matlab/matlab_oop/implementing-operators-for-your-class.html
        function result = plus(a,b)
            %a+b
            fct = @interval.plus_element;
            result = interval.elementwiseoperator(fct,a,b);
        end
        function result = minus(a,b)
            %a-b
            result = a+(-b);
        end
        function result = uminus(a)
            %-a
            fct = @interval.uminus_element;
            result = interval.elementwiseoperator(fct,a);
        end
        function result = uplus(a)
            %+a
            result = a;
        end
        function result = times(a,b)
            %a.*b
            fct = @interval.times_element;
            result = interval.elementwiseoperator(fct,a,b);
        end
        function result = mtimes(a,b)
            %a*b: matrix product of interval (or double) matrices.
            %A scalar operand gives the elementwise product, as in MATLAB.
            if isscalar(a) || isscalar(b)
                result = times(a,b);
                return
            end
            [i,j] = size(a);
            [j2,k] = size(b);
            if j ~= j2
                error('The matrix dimensions are not matching');
            end
            if ~isa(a,'interval')
                a = interval(a);
            end
            if ~isa(b,'interval')
                b = interval(b);
            end
            result = interval(zeros([i,k]));
            for iind = 1:i
                for kk = 1:k
                    currentsum = interval(0,0);
                    for jind = 1:j
                        currentsum = currentsum+interval.times_element(a(iind,jind).bounds,b(jind,kk).bounds);
                    end
                    result(iind,kk) = currentsum;
                end
            end
        end
        function result = rdivide(a,b)
            %a./b, computed as a.*(1./b)
            if isa(b, 'interval')
                result = a.*inverse(b);
            else
                % elementwise reciprocal so that vector divisors work too
                result = a.*(1./b);
            end
        end
        function result = ldivide(a,b)
            %a.\b = b./a
            result = rdivide(b,a);
        end
        function result = mrdivide(a,b)
            %a/b, only implemented for a scalar divisor b (then a/b = a./b)
            if isscalar(b)
                result = rdivide(a,b);
            else
                error('interval method a/b is only implemented for a scalar divisor b');
            end
        end
        function result = inverse(a)
            %1./a
            fct = @interval.inverse_element;
            result = interval.elementwiseoperator(fct,a);
        end
        function result = power(a,b)
            %a.^b
            fct = @interval.power_element;
            result = interval.elementwiseoperator(fct,a,b);
        end
        function result = mpower(a,b)
            %a^b, only implemented for scalar a and b (then a^b = a.^b)
            if isscalar(a) && isscalar(b)
                result = power(a,b);
            else
                error('interval method a^b is only implemented for scalar a and b');
            end
        end
        function result = lt(a,b)
            %a < b (both bounded and every element of a below every element of b)
            fct = @interval.lt_element;
            result = interval.elementwiseoperator(fct,a,b);
        end
        function result = gt(a,b)
            %a > b
            result = b<a;
        end
        function result = le(a,b)
            %a <= b
            fct = @interval.le_element;
            result = interval.elementwiseoperator(fct,a,b);
        end
        function result = ge(a,b)
            %a >= b
            result = b <= a;
        end
        function result = ne(a,b)
            %a ~= b
            result = ~(a==b);
        end
        function result = eq(a,b)
            %a == b (end points agree up to a small tolerance)
            fct = @interval.eq_element;
            result = interval.elementwiseoperator(fct,a,b);
        end

        function result = inverse_elementwise(this)
            %Alias of inverse: 1./this
            fct = @interval.inverse_element;
            result = interval.elementwiseoperator(fct,this);
        end

        function result = sqrt(a)
            %sqrt(a), requires non-negative bounded intervals
            fct = @interval.power_element;
            result = interval.elementwiseoperator(fct,a,1/2);
        end
        function result = nthroot(a,n)
            %nthroot(a,n): real n-th root of every interval in a
            if isa(n,'double') && isscalar(n) && n > 0 && mod(n,2) == 1
                % odd root: x -> nthroot(x,n) is increasing on the whole real
                % line, so it maps the bounds to the bounds (also for
                % negative values and for the unbounded representation)
                fct = @interval.oddroot_element;
                result = interval.elementwiseoperator(fct,a,n);
            else
                fct = @interval.power_element;
                result = interval.elementwiseoperator(fct,a,1/n);
            end
        end

        function result = abs(a)
            %abs(a)
            fct = @interval.abs_element;
            result = interval.elementwiseoperator(fct,a);
        end

        function result = norm(a)
            %norm(a): L_2 norm of an interval vector
            result = norm_p(a,2);
        end
        function result = norm_p(a,p)
            %norm_p(a,p): L_p norm (sum_i |a_i|^p)^(1/p) of an interval vector a
            %for a scalar p
            if ndims(a) ~= 2
                error('invalid input a for norm_p(a,p)')
            elseif ~isscalar(p)
                error('norm_p(a,b) function requires a single scalar p')
            end
            sizes = size(a);
            if sizes(1) == 1 %row vector
                dim = 2;
            elseif sizes(2) == 1 %column vector
                dim = 1;
            else %matrix
                error('matrix norm not yet implemented')
            end
            % |a|^p is non-negative, so the final root of the sum is well defined
            result = sum(abs(a).^p,dim).^(1/p);
        end

        function disp(this)
            disp('Interval');
            dispmat = interval.int2mat(this);
            disp(dispmat);
        end

        function str = int2str(this)
            str = mat2str(interval.int2mat(this));
        end

        function result = summinmax(this,directions,to_be_evaluated)
            %Generic reduction of this along the dimensions in directions.
            %to_be_evaluated is a MATLAB statement that is eval'd once per
            %element; it can use result, this, Index (linear index into this)
            %and Index1 (linear index into result). result starts as zeros.
            %Relies on arrayindexing.sub2indV / arrayindexing.updateindexvec,
            %which are defined outside this file.
            if nargin<2
                directions = 1;
            end
            s = size(this);
            d = length(s);
            snew = s;
            snew(directions) = ones(1,length(directions));
            result = interval(zeros(snew));
            specifieddirections = false(1,d);
            specifieddirections(directions) = true(1,length(directions));

            vLim = s;
            vLim1 = s;
            vLim1(directions) = [];

            v1    = ones(1, length(vLim1));
            ready = false;
            while ~ready
                Index1 = arrayindexing.sub2indV(vLim1, v1);

                vLim2 = s(directions);
                v2    = ones(1, length(vLim2));
                ready = false;
                while ~ready
                    v = zeros(1,d);
                    v(~specifieddirections) = v1;
                    v(specifieddirections) = v2;
                    Index = arrayindexing.sub2indV(vLim, v);
                    eval(to_be_evaluated);
                    % Update the index vector of the reduced dimensions:
                    [v2,ready] = arrayindexing.updateindexvec(v2,vLim2);
                end
                % Update the index vector of the kept dimensions:
                [v1,ready] = arrayindexing.updateindexvec(v1,vLim1);
            end
        end
        function result = sum(this,directions)
            %sum(this) or sum(this,directions); like MATLAB's sum, the default
            %is the first non-singleton dimension
            if nargin < 2
                directions = find(size(this) ~= 1, 1);
                if isempty(directions)
                    directions = 1;
                end
            end
            to_be_evaluated = 'result(Index1) = result(Index1)+this(Index);';
            result = summinmax(this,directions,to_be_evaluated);
        end
        function result = min(this,varargin)
            %min(A), min(A,[],dim) or min(A,B) for intervals.
            %Unbounded elements are replaced by their hull [-inf,inf].
            result = interval.reduce_minmax(@min,this,varargin{:});
        end
        function result = max(this,varargin)
            %max(A), max(A,[],dim) or max(A,B) for intervals.
            %Unbounded elements are replaced by their hull [-inf,inf].
            result = interval.reduce_minmax(@max,this,varargin{:});
        end

        function result = cross(a,b)
            %cross(a,b) for two interval vectors with 3 elements each
            if numel(a) == 3 && numel(b) == 3
                result = [(a(2)*b(3)) - (a(3)*b(2)), (a(3)*b(1)) - (a(1)*b(3)), (a(1)*b(2)) - (a(2)*b(1))];
            else
                error('cross function is only implemented for two interval vectors of size 3')
            end
        end

        function result = det(this)
            %det of a 1x1 or 2x2 interval matrix (linear indexing is column-major)
            if isequal(size(this),[1,1])
                result = interval(this.bounds(1),this.bounds(2));
            elseif isequal(size(this),[2,2])
                result = this(1)*this(4)-this(2)*this(3);
            else
                error('det function has to be implemented for non-2x2 matrices')
            end
        end
        function result = cup(this,b)
            %Elementwise union (enclosure) of this and b
            fct = @interval.cup_element;
            result = interval.elementwiseoperator(fct,this,b);
        end

        function result = sign(this)
            fct = @interval.sign_element;
            result = interval.elementwiseoperator(fct,this);
        end

        function result = extractbounds(this)
            %Numeric array of size [size(this),2] (trailing singleton removed)
            %holding the lower/upper bounds of every element
            s = size(this);
            n = prod(s);
            result = zeros(n,2);
            for i = 1:n
                result(i,:) = this(i).bounds;
            end
            if s(end) == 1
                s(end) = [];
            end
            result = reshape(result,[s,2]);
        end

        function result = subset(a,b)
            %a \subset b, true only if it holds for every element
            fct = @interval.subset_element;
            elementwise = interval.elementwiseoperator(fct,a,b);
            result = all(elementwise(:));
        end
        function result = superset(this,bi)
            %this \supset bi
            result = subset(bi,this);
        end
        function result = element(this,x)
            %x \in this, evaluated elementwise
            fct = @interval.subset_element;
            result = interval.elementwiseoperator(fct,x,this);
        end

        function result = scale(this,factor)
            %Bounded [l,u]: the length is multiplied by factor, the midpoint kept.
            %Unbounded (u<l): the length of the gap (u,l) is multiplied by
            %factor, the midpoint of the gap kept.
            result = interval(zeros(size(this)));
            s = 1+(factor-1)/2;
            for i=1:numel(this)
                result(i) = interval(s*this(i).bounds(1)+(1-s)*this(i).bounds(2),(1-s)*this(i).bounds(1)+s*this(i).bounds(2));
            end
        end
        function result = linspace(this,n)
            %n uniformly spaced values in every interval, stored along a new
            %last dimension
            result = zeros([numel(this),n]);
            for i=1:numel(this)
                result(i,:) = linspace(this(i).bounds(1),this(i).bounds(2),n);
            end
            s = size(this);
            if length(s) == 2 && s(2) == 1
                s(2) = [];
            end
            result = reshape(result,[s,n]);
        end
    end



    methods(Static)
        function result = elementwiseoperator(fct,a,b)
            %Computes elementwise fct(a,b), or fct(a) if b is not given.
            %fct receives the bounds of the elements and may return an
            %interval or a logical. Double operands are converted to
            %degenerate intervals. One operand may be scalar, similar to
            %1+[2,3] = [3,4]; otherwise the sizes must agree.
            binary = nargin >= 3;
            s = size(a);
            if isa(a,'double')
                a = interval(a);
            end
            if binary
                if numel(a) == 1
                    s = size(b);
                elseif numel(b) ~= 1 && ~isequal(size(a),size(b))
                    error('interval operands must have the same size or one of them must be scalar');
                end
                if isa(b,'double')
                    b = interval(b);
                end
            end
            % loop downwards so the first assignment preallocates result
            for ind = prod(s):-1:1
                if ~binary
                    args = {a(ind).bounds};
                elseif numel(a) == 1
                    args = {a.bounds,b(ind).bounds};
                elseif numel(b) == 1
                    args = {a(ind).bounds,b.bounds};
                else
                    args = {a(ind).bounds,b(ind).bounds};
                end
                result(ind) = fct(args{:});
            end
            result = reshape(result,s);
        end

        function result = plus_element(a,b)
            %a+b for bounds a and b
            if (a(2)<a(1) && b(2)<b(1)) || (a(1)+b(1) <= a(2)+b(2) && (a(2)<a(1) || b(2) < b(1)))
                % the gaps are filled up: the sum covers the real line
                bounds = [-inf,inf];
            else
                bounds = a+b;
            end
            result = interval(bounds(1),bounds(2));
        end
        function result = uminus_element(a)
            %-a
            result = interval(-a(2),-a(1));
        end
        function result = times_element(a,b)
            %a*b for bounds a and b (bounded or unbounded representation)
            values = a'*b; % the four end point products
            if a(1) <= a(2) && b(1) <= b(2)
                % both bounded: hull of the end point products
                bounds = [min(values(:)),max(values(:))];
            else
                % make a the unbounded operand
                if a(1) <= a(2)
                    [a,b] = deal(b,a);
                end
                zero_in_a = interval.subset_element([0,0],a);
                zero_in_b = interval.subset_element([0,0],b);
                if zero_in_b || (b(2) < b(1) && zero_in_a)
                    bounds = [-inf,inf];
                elseif b(2) < b(1)
                    % both unbounded and 0 in neither gap-complement
                    bounds = [min(values(values>0)),max(values(values<0))];
                elseif ~zero_in_a
                    % a = [-inf,a(2)] \cup [a(1),inf] with a(2) < 0 < a(1);
                    % b bounded and 0 notin b
                    if 0 < b(1) %b positive
                        bounds = [a(1)*b(1),a(2)*b(1)];
                    else %b negative
                        bounds = [a(2)*b(2),a(1)*b(2)];
                    end
                else
                    % 0 in a, so the gap (a(2),a(1)) lies on one side of 0
                    if 0 <= a(2) %gap is positive
                        if 0 < b(1) %b positive
                            bounds = [a(1)*b(1),a(2)*b(2)];
                        else %b negative
                            bounds = [a(2)*b(1),a(1)*b(2)];
                        end
                    else %gap is negative
                        if 0 < b(1) %b positive
                            bounds = [a(1)*b(2),a(2)*b(1)];
                        else %b negative
                            bounds = [a(2)*b(2),a(1)*b(1)];
                        end
                    end
                    % touching or overlapping images cover the whole line
                    if bounds(2) >= bounds(1)
                        bounds = [-inf,inf];
                    end
                end
            end
            result = interval(bounds(1),bounds(2));
        end

        function result = inverse_element(a)
            %Computes the inverse interval a^-1
            result = interval(1/a(2),1/a(1));
        end

        function result = abs_element(a)
            %|a|
            if a(1) <= a(2) %[a(1), a(2)]
                if a(1) >= 0
                    result = interval(a(1), a(2));
                elseif a(2) <= 0
                    result = interval(-a(2), -a(1));
                else %a(1) < 0 < a(2)
                    result = interval(0, max(-a(1),a(2)));
                end
            else %[-inf, a(2)] \cup [a(1), inf]
                if a(2) >= 0 || a(1) <= 0
                    result = interval(0, inf);
                else
                    result = interval(min(abs(a(1)), abs(a(2))), inf);
                end
            end
        end

        function result = power_element(a,b)
            %Computes a.^b for the bounds a and a degenerate exponent
            %interval b = [b(1),b(1)]. Non-integer exponents require a
            %non-negative bounded a.
            if b(1) ~= b(2)
                error("interval method a.^b not implemented for an interval b")
            end
            b = b(1);
            if a(1) > a(2)
                %a = [-inf,a(2)] \cup [a(1),inf]: raise both pieces, join them
                result = cup(interval(-inf,a(2)).^b, interval(a(1),inf).^b);
            elseif rem(b,1) == 0 %b is an integer
                contains_zero = a(1) <= 0 && 0 <= a(2);
                if b >= 0
                    if rem(b,2) == 0 && contains_zero
                        % even power: minimum 0 is attained at x = 0
                        result = interval(0,max(a(1)^b,a(2)^b));
                    else
                        res1 = a(1)^b;
                        res2 = a(2)^b;
                        result = interval(min(res1,res2),max(res1,res2));
                    end
                else % b < 0
                    if rem(b,2) == 0 && contains_zero
                        result = interval(min(a(1)^b,a(2)^b), inf);
                    elseif contains_zero %odd negative power
                        result = interval(-inf, inf);
                    else %a does not include 0
                        res1 = a(1)^b;
                        res2 = a(2)^b;
                        result = interval(min(res1,res2),max(res1,res2));
                    end
                end
            elseif all(a >= 0) %b is not an integer, but a is non-negative
                res1 = a(1)^b;
                res2 = a(2)^b;
                result = interval(min(res1,res2),max(res1,res2));
            else
                error("interval method a.^b not implemented for non-integer b and negative a")
            end
        end
        function result = oddroot_element(a,n)
            %Real n-th root for an odd positive integer n = n(1); the root is
            %increasing, so the bounds (also unbounded ones) map directly
            result = interval(nthroot(a(1),n(1)), nthroot(a(2),n(1)));
        end

        function result = lt_element(a,b)
            %a < b
            result = (a(1)<=a(2)) && (b(1)<=b(2)) && a(2)<b(1);
        end
        function result = le_element(a,b)
            %a <=b
            result = (a(1)<=a(2)) && (b(1)<=b(2)) && a(2)<=b(1);
        end
        function result = eq_element(a,b)
            %a == b: the end points agree up to an absolute tolerance of
            %3*eps; identical end points (also infinite ones) are equal
            result = all(a == b | abs(a - b) <= 3*eps);
        end

        function result = cup_element(a,b)
            %Union of a and b, or an interval enclosing it when the union is
            %not representable
            if a(1) == inf && a(2) == -inf %a is empty
                result = interval(b(1),b(2));
                return
            elseif b(1) == inf && b(2) == -inf %b is empty
                result = interval(a(1),a(2));
                return
            end
            a_bounded = a(2) >= a(1);
            b_bounded = b(2) >= b(1);
            if a_bounded && b_bounded
                % [a(1), a(2)] \cup [b(1), b(2)]: hull of both
                result = interval(min(a(1),b(1)), max(a(2),b(2)));
            elseif ~a_bounded && b_bounded
                % symmetric to the next case
                result = interval.cup_element(b,a);
            elseif a_bounded
                % [a(1), a(2)] \cup ([-inf, b(2)] \cup [b(1), inf])
                if (a(1) <= b(2) && a(2) >= b(1)) || (a(1) >= b(2) && a(2) <= b(1))
                    % a fills the gap, or lies inside it (enclosed by R)
                    result = interval(-inf, inf);
                elseif a(1) <= b(2)
                    % a starts in the left piece and ends before b(1)
                    result = interval(b(1), max(a(2),b(2)));
                else
                    % a ends in the right piece and starts after b(2)
                    result = interval(min(a(1),b(1)), b(2));
                end
            else
                % ([-inf, a(2)] \cup [a(1), inf]) \cup ([-inf, b(2)] \cup [b(1), inf])
                if max(a(2), b(2)) >= min(a(1), b(1))
                    result = interval(-inf, inf);
                else
                    result = interval(min(a(1),b(1)), max(a(2), b(2)));
                end
            end
        end

        function result = sign_element(a)
            result = interval(min(sign(a(1)), sign(a(2))), max(sign(a(1)), sign(a(2))));
        end

        function result = hull_element(a)
            %Smallest bounded-representation interval containing a
            if a(1) > a(2)
                result = [-inf,inf];
            else
                result = a;
            end
        end
        function result = minmax_element(op,a,b)
            %Elementwise min/max (op = @min or @max) of two intervals
            a = interval.hull_element(a);
            b = interval.hull_element(b);
            result = interval(op(a(1),b(1)), op(a(2),b(2)));
        end
        function result = reduce_minmax(op,a,b,dim)
            %Shared implementation of min/max (op = @min or @max):
            %   reduce_minmax(op,a)         along the first non-singleton dimension
            %   reduce_minmax(op,a,[],dim)  along dimension dim
            %   reduce_minmax(op,a,b)       elementwise op(a,b)
            if nargin >= 3 && ~isempty(b)
                fct = @(x,y) interval.minmax_element(op,x,y);
                result = interval.elementwiseoperator(fct,a,b);
                return
            end
            if nargin < 4
                dim = find(size(a) ~= 1, 1);
                if isempty(dim)
                    dim = 1;
                end
            end
            lo = zeros(size(a));
            hi = zeros(size(a));
            for k = 1:numel(a)
                h = interval.hull_element(a(k).bounds);
                lo(k) = h(1);
                hi(k) = h(2);
            end
            % min/max of intervals act independently on both bounds
            result = interval(op(lo,[],dim), op(hi,[],dim));
        end

        function result = subset_element(a,b)
            %a subset of b
            a_bounded = a(1) <= a(2);
            b_bounded = b(1) <= b(2);
            if a_bounded && b_bounded
                result = b(1) <= a(1) && a(2) <= b(2);
            elseif a_bounded
                % b = [-inf,b(2)] \cup [b(1),inf]: a must fit in one piece
                result = a(2) <= b(2) || b(1) <= a(1);
            elseif b_bounded
                % an unbounded a only fits into the whole real line
                result = b(1) == -inf && b(2) == inf;
            else
                % both unbounded: each piece of a must lie in the piece of b
                result = a(2) <= b(2) && b(1) <= a(1);
            end
        end

        function Index = sub2indV(Vlim,X)
            k     = [1, cumprod(Vlim)];
            Index = sum(k(1:length(X)) .* (X - 1)) + 1;
        end
        function v = ind2subV(Vlim, ind)
            ind = ind-1;
            v = zeros(1,0);
            for i = 1:length(Vlim)
                v(i) = 1+mod(ind,Vlim(i));
                ind = (ind-v(i)+1)/Vlim(i);
            end
        end
        function mat = int2mat(I)
            %Gather the interval bounds in a matrix of size [size(I),2]
            %(trailing singleton removed); returns a message string if an
            %element has no bounds yet
            sizeI = size(I);
            if sizeI(end) == 1
                sizeI(end) = [];
            end
            mat = zeros([sizeI,2]);
            num = numel(I);
            for i = 1:num
                if isempty(I(i).bounds)
                    mat = ['empty ',num2str(sizeI),' array of intervals'];
                    return
                end
                mat([i,i+num]) = I(i).bounds;
            end
        end

        function int = zeros(varargin)
            sizes = cell2mat(varargin);
            int = interval(zeros(sizes));
        end
        function int = ones(varargin)
            sizes = cell2mat(varargin);
            int = interval(ones(sizes));
        end
        function int = unit(varargin)
            %Array of unit intervals [0,1]; a single size n gives an n-by-1 array
            sizes = cell2mat(varargin);
            if length(sizes) == 1
                A = zeros([sizes,1]);
                B = ones([sizes,1]);
            else
                A = zeros(sizes);
                B = ones(sizes);
            end
            int = interval(A,B);
        end


        function test()
            interval.test1;
        end
        function test1()
            a = interval(0,1);
            b = interval(3,4);
            disp(a+b);
            disp(a-b);
            disp(a.*b);
            disp(a./b);
        end
    end
end

