From 0a57421622270568bc4a842e3d3a5702641d7690 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Mon, 21 Nov 2022 16:46:19 +0400 Subject: [PATCH] Support empty LowRankFun --- src/Multivariate/LowRankFun.jl | 23 +++++++++++++---------- src/Multivariate/ProductFun.jl | 3 ++- src/Multivariate/VectorFun.jl | 2 +- 3 files changed, 16 insertions(+), 12 deletions(-) diff --git a/src/Multivariate/LowRankFun.jl b/src/Multivariate/LowRankFun.jl index b82dc8a9..bb0a179e 100644 --- a/src/Multivariate/LowRankFun.jl +++ b/src/Multivariate/LowRankFun.jl @@ -34,7 +34,6 @@ struct LowRankFun{S<:Space,M<:Space,SS<:AbstractProductSpace,T<:Number} <: Bivar B::Vector{VFun{M,T}}, space::SS) where {S,M,SS,T} @assert length(A) == length(B) - @assert length(A) > 0 new{S,M,SS,T}(A,B,space) end end @@ -57,29 +56,33 @@ size(f::LowRankFun) = size(f,1),size(f,2) function LowRankFun(X::Array{T},dx::S,dy::M) where {S<:Space,M<:Space,T<:Number} U,Σ,V=svd(X) - m=max(1,count(s->s>10eps(T),Σ)) + m=count(s->s>10eps(T),Σ) - A=VFun{S,T}[Fun(dx,U[:,k].*sqrt(Σ[k])) for k=1:m] - B=VFun{M,T}[Fun(dy,conj(V[:,k]).*sqrt(Σ[k])) for k=1:m] + A=VFun{S,T}[Fun(dx, (@view U[:,k]).*sqrt(Σ[k])) for k=1:m] + B=VFun{M,T}[Fun(dy, conj.(@view V[:,k]).*sqrt(Σ[k])) for k=1:m] - LowRankFun(A,B) + LowRankFun(A, B, dx ⊗ dy) end ## Construction in a TensorSpace via a Vector of Funs function LowRankFun(X::Vector{VFun{S,T}},d::TensorSpace{SV,DD}) where {S,T,DD<:EuclideanDomain{2},SV} - @assert d[1] == space(X[1]) - LowRankFun(X,d[2]) + if !isempty(X) + @assert factor(d, 1) == space(X[1]) + else + @assert factor(d, 1) isa S + end + LowRankFun(X, factor(d, 2), factor(d, 1)) end -function LowRankFun(X::Vector{VFun{S,T}},dy::Space) where {S,T} - m=mapreduce(ncoefficients,max,X) +function LowRankFun(X::Vector{VFun{S,T}},dy::Space,dx::Space = nothing) where {S,T} + m=mapreduce(ncoefficients,max,X, init=0) M=zeros(T,m,length(X)) for k=1:length(X) M[1:ncoefficients(X[k]),k]=X[k].coefficients end - LowRankFun(M,space(X[1]),dy) + LowRankFun(M, dx === nothing ? space(X[1]) : dx, dy) end diff --git a/src/Multivariate/ProductFun.jl b/src/Multivariate/ProductFun.jl index 896927db..62bfa777 100644 --- a/src/Multivariate/ProductFun.jl +++ b/src/Multivariate/ProductFun.jl @@ -340,7 +340,8 @@ end *(f::ProductFun,B::Fun) = transpose(B*transpose(f)) -LowRankFun(f::ProductFun{S,V,SS}) where {S,V,SS<:TensorSpace} = LowRankFun(f.coefficients,factor(space(f),2)) +LowRankFun(f::ProductFun{S,V,SS}) where {S,V,SS<:TensorSpace} = + LowRankFun(f.coefficients, space(f)) LowRankFun(f::Fun) = LowRankFun(ProductFun(f)) function differentiate(f::ProductFun{S,V,SS},j::Integer) where {S,V,SS<:TensorSpace} diff --git a/src/Multivariate/VectorFun.jl b/src/Multivariate/VectorFun.jl index 30288318..b8c63222 100644 --- a/src/Multivariate/VectorFun.jl +++ b/src/Multivariate/VectorFun.jl @@ -198,4 +198,4 @@ end #TODO: fix for complex evaluate(A::AbstractArray{T},x::Number) where {T<:Fun} = - typeof(first(A)(x))[Akj(x) for Akj in A] + [Akj(x) for Akj in A]