Skip to content

Commit 57fb8f1

Browse files
authored
Initial upload
1 parent 825bb1b commit 57fb8f1

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

44 files changed

+7523
-0
lines changed

+grid/ApproxFunction.m

+50
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
classdef ApproxFunction
2+
% abstract superclass for approximating function
3+
4+
% common properties
5+
properties (Abstract, SetAccess=protected)
6+
SSGrid % state space grid
7+
Nof % # functions
8+
Vals % matrix with interpolated/approximated data points
9+
end
10+
11+
% common methods to be implemented
12+
methods (Abstract)
13+
evaluateAt(obj,points) % evaluation at point list
14+
fitTo(obj,points) % fit to new data points
15+
end
16+
17+
% common methods
18+
methods
19+
function plot3D(cf,dispfct,dispdim,val_other,ln)
20+
np=20;
21+
stbtemp=cf.SSGrid.StateBounds(:,dispdim);
22+
totdim=cf.SSGrid.Ndim;
23+
dim_other=setdiff(1:totdim,dispdim);
24+
25+
gridx=linspace(stbtemp(1,1),stbtemp(2,1),np)';
26+
gridy=linspace(stbtemp(1,2),stbtemp(2,2),np)';
27+
28+
indmat=grid.StateSpaceGrid.makeCombinations([np,np]);
29+
30+
plist=zeros(np*np,totdim);
31+
plist(:,dispdim(1))=gridx(indmat(:,2));
32+
plist(:,dispdim(2))=gridy(indmat(:,1));
33+
for i=1:length(val_other)
34+
plist(:,dim_other(i))=ones(np*np,1)*val_other(i);
35+
end
36+
37+
vlist=evaluateAt(cf,plist);
38+
vlist=vlist(dispfct,:);
39+
40+
if ln==1
41+
vlist=exp(vlist);
42+
elseif ln==2
43+
vlist=1./(1+exp(vlist));
44+
end
45+
figure;
46+
mesh(gridx,gridy,reshape(vlist,np,np)');
47+
end
48+
end
49+
50+
end

+grid/LinearInterpFunction.m

+146
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
classdef LinearInterpFunction < grid.ApproxFunction
2+
3+
properties (SetAccess=protected)
4+
% inherited properties (abstract in superclass)
5+
SSGrid
6+
Nof
7+
Vals
8+
% linear interp specific properties
9+
Type % scatter or equidistant
10+
InterpStruct
11+
end
12+
13+
properties (Constant)
14+
SCATTER=10;
15+
EQUI=20;
16+
end
17+
18+
methods
19+
% constructor
20+
function sf=LinearInterpFunction(ssgrid,vals,usemex)
21+
sf.SSGrid=ssgrid;
22+
sf.usemex=usemex;
23+
sf=fitTo(sf,vals);
24+
end
25+
26+
function sf=set.SSGrid(sf,ssg)
27+
% check that grid object at initialization is TensorGrid
28+
if isa(ssg,'grid.TensorGrid')
29+
sf.Type=grid.LinearInterpFunction.EQUI;
30+
elseif isa(ssg,'grid.ScatterGrid')
31+
sf.Type=grid.LinearInterpFunction.SCATTER;
32+
else
33+
error('StateSpaceGrid must be a TensorGrid or a ScatterGrid');
34+
end
35+
sf.SSGrid=ssg;
36+
end
37+
38+
39+
40+
% fit to new values
41+
function sf=fitTo(sf,vals)
42+
[npt,nof]=size(vals);
43+
if npt~=sf.SSGrid.Npt
44+
error('Value matrix must have dimensions (Npt x Nof)');
45+
end
46+
sf.Nof=nof;
47+
sf.Vals=vals;
48+
% distinguish cases
49+
if sf.Type==grid.LinearInterpFunction.EQUI
50+
% SSGrid is TensorGrid object in this case
51+
% reshape values for call to spapi
52+
dimvec=sf.SSGrid.Dimvec';
53+
ndim=sf.SSGrid.Ndim;
54+
vals=reshape(vals,[dimvec,nof]);
55+
vals=permute(vals,[ndim+1,1:ndim]);
56+
baseindex=repmat({':'},1,ndim+1);
57+
sf.InterpStruct=cell(nof,1);
58+
if sf.usemex
59+
for f=1:nof
60+
index=baseindex;
61+
index{1}=f;
62+
fvals=squeeze(vals(index{:}));
63+
fvals=permute(fvals,ndims(fvals):-1:1);
64+
sf.InterpStruct{f}=fvals(:)';
65+
end
66+
else
67+
for f=1:nof
68+
index=baseindex;
69+
index{1}=f;
70+
sf.InterpStruct{f}=griddedInterpolant(sf.SSGrid.Unigrids,squeeze(vals(index{:})),'linear','none');
71+
end
72+
end
73+
else
74+
% SSGrid is ScatterGrid
75+
sf.InterpStruct=[];
76+
end
77+
78+
end
79+
80+
% evaluation
81+
function vals=evaluateAt(sf,points)
82+
[np,ndim]=size(points);
83+
if ndim~=sf.SSGrid.Ndim
84+
error('Point matrix must have dimensions (#points x Ndim)');
85+
end
86+
% extrapolation doesn't work well, so force back into state
87+
% bounds
88+
points_corr=points;
89+
SBlow=ones(np,1)*sf.SSGrid.StateBounds(1,:);
90+
SBhi=ones(np,1)*sf.SSGrid.StateBounds(2,:);
91+
upvio=(points>SBhi);
92+
points_corr(upvio)=SBhi(upvio);
93+
downvio=(points<SBlow);
94+
points_corr(downvio)=SBlow(downvio);
95+
% check case
96+
if sf.Type==grid.LinearInterpFunction.EQUI
97+
vals=zeros(sf.Nof,np);
98+
if sf.usemex
99+
gridlist=sf.SSGrid.Unigrids';
100+
for f=1:sf.Nof
101+
vals(f,:)=linterp_eval2(gridlist, sf.InterpStruct{f}, points_corr);
102+
end
103+
else
104+
for f=1:sf.Nof
105+
vals(f,:)=sf.InterpStruct{f}(points_corr);
106+
end
107+
end
108+
else
109+
% simplex search using saved triangulation
110+
tri=sf.SSGrid.Tessel;
111+
[Tin,bcc] = tsearchn(sf.SSGrid.Pointmat,tri,points_corr);
112+
K = ~isnan(Tin);
113+
vals = zeros(np,sf.Nof);
114+
for d = 1:ndim+1
115+
% delaunay interp in each dimension
116+
vals(K,:) = vals(K,:) + bsxfun(@times,bcc(K,d),sf.Vals(tri(Tin(K),d),:));
117+
end
118+
% for points outside of the convex hull of the standard
119+
% triangulation, use furthest-site triangulation instead
120+
% (option 'Qu' in Qhull)
121+
if any(~K)
122+
% FStri=sf.SSGrid.FSTessel;
123+
% [FSTin,FSbcc]=tsearchn(sf.SSGrid.Pointmat,FStri,points_corr);
124+
% for d = 1:ndim+1
125+
% % delaunay interp in each dimension
126+
% vals(~K,:) = vals(~K,:) + bsxfun(@times,FSbcc(~K,d),sf.Vals(FStri(FSTin(~K),d),:));
127+
% end
128+
[cpi1,d1]=dsearchn(sf.SSGrid.Pointmat,tri,points_corr(~K,:));
129+
redind=setdiff(1:sf.SSGrid.Npt,cpi1);
130+
ptmat2=sf.SSGrid.Pointmat(redind,:);
131+
fctvals2=sf.Vals(redind,:);
132+
[cpi2,d2]=dsearchn(ptmat2,points_corr(~K,:));
133+
b=(ptmat2(cpi2,:)-sf.SSGrid.Pointmat(cpi1,:)).^2;
134+
b=sum(b,2);
135+
x=(d1.^2+b-d2.^2)./(2*b);
136+
vals(~K,:)=sf.Vals(cpi1,:) + repmat(max(0,x),1,sf.Nof).*(fctvals2(cpi2,:)-sf.Vals(cpi1,:));
137+
end
138+
vals=vals';
139+
end
140+
end
141+
142+
143+
end
144+
145+
146+
end

+grid/StateSpaceGrid.m

+80
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
classdef StateSpaceGrid
2+
% abstract superclass
3+
4+
properties (SetAccess=protected)
5+
StateBounds
6+
end
7+
8+
properties (Abstract, SetAccess=protected)
9+
Ndim
10+
Npt
11+
Pointmat
12+
Type
13+
end
14+
15+
%--------------------------------------------------------------------------
16+
% Regular methods (some checks for setter methods)
17+
%--------------------------------------------------------------------------
18+
methods
19+
function ssg=set.StateBounds(ssg,stBounds)
20+
nr=size(stBounds,1);
21+
if nr~=2
22+
error('StateBounds must be a 2xNdim matrix');
23+
end
24+
ssg.StateBounds=stBounds;
25+
end
26+
27+
end
28+
%--------------------------------------------------------------------------
29+
% End regular methods
30+
%--------------------------------------------------------------------------
31+
32+
33+
%--------------------------------------------------------------------------
34+
% Static methods (miscellaenous functions for grids)
35+
%--------------------------------------------------------------------------
36+
methods (Static)
37+
function nod=makeCombinations(x)
38+
39+
dim=length(x);
40+
np=prod(x);
41+
nod=zeros(np,dim);
42+
43+
% first dimension
44+
nod(:,1)=sort(repmat((1:x(1))',np/x(1),1));
45+
% remaining dims
46+
for d=2:dim
47+
% number of points divided by number of combinations from previous
48+
% dimensions
49+
rep_dim=np/prod(x(1:d));
50+
vec_elem=sort(repmat((1:x(d))',rep_dim,1));
51+
nod(:,d)=repmat(vec_elem,np/(rep_dim*x(d)),1);
52+
end
53+
end
54+
55+
56+
function nod=makeCombinations_rev(x)
57+
58+
dim=length(x);
59+
np=prod(x);
60+
nod=zeros(np,dim);
61+
62+
% first dimension
63+
nod(:,dim)=sort(repmat((1:x(dim))',np/x(dim),1));
64+
% remaining dims
65+
for d=dim-1:-1:1
66+
% number of points divided by number of combinations from previous
67+
% dimensions
68+
rep_dim=np/prod(x(d:dim));
69+
vec_elem=sort(repmat((1:x(d))',rep_dim,1));
70+
nod(:,d)=repmat(vec_elem,np/(rep_dim*x(d)),1);
71+
end
72+
end
73+
74+
end
75+
%--------------------------------------------------------------------------
76+
% End static methods
77+
%--------------------------------------------------------------------------
78+
79+
80+
end

+grid/TensorGrid.m

+84
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
classdef TensorGrid < grid.StateSpaceGrid
2+
3+
properties (SetAccess=protected)
4+
% inherited properties (abstract in superclass)
5+
Ndim
6+
Npt
7+
Pointmat
8+
Type
9+
% tensor specific properties
10+
Unigrids
11+
Dimvec
12+
end
13+
14+
methods
15+
% constructor
16+
function ssg = TensorGrid(arg1,arg2)
17+
if nargin < 1
18+
error('Not enough input arguments.');
19+
end
20+
if nargin==1
21+
% if initialized by directly specifying univariate grids
22+
ssg=makeTensorGrid(ssg,arg1);
23+
ssg.StateBounds=zeros(2,ssg.Ndim);
24+
for d=1:ssg.Ndim
25+
gridd=ssg.Unigrids{d};
26+
ssg.StateBounds(:,d)=[gridd(1);gridd(end)];
27+
end
28+
ssg.Type='tensor (direct)';
29+
else
30+
% if initialized by specifying bounds and dimvec
31+
% construct equally space grid in that case
32+
ssg.StateBounds=arg1;
33+
dimvec=arg2;
34+
ndim=length(dimvec);
35+
unigrids=cell(ndim,1);
36+
for d=1:ndim
37+
unigrids{d}=linspace(ssg.StateBounds(1,d),ssg.StateBounds(2,d),dimvec(d))';
38+
end
39+
ssg=makeTensorGrid(ssg,unigrids);
40+
ssg.Type='tensor (equ. sp.)';
41+
end
42+
end
43+
44+
function obj = set.Unigrids(obj,gridarray)
45+
if (~iscell(gridarray) || length(size(gridarray))>2)
46+
error('grid array must be a (Ndim x 1) cell array');
47+
end
48+
if size(gridarray,2)>1
49+
gridarray=permute(gridarray,[2,1]);
50+
end
51+
if size(gridarray,2)>1
52+
error('grid array must be a (Ndim x 1) cell array');
53+
end
54+
obj.Unigrids=gridarray;
55+
end
56+
57+
function grid=getUnigrid(ssg,dim)
58+
grid=ssg.Unigrids{dim};
59+
end
60+
61+
% make tensor grid from univariate grids
62+
function ssg = makeTensorGrid(ssg,unigrids)
63+
ssg.Unigrids = unigrids;
64+
ssg.Ndim=length(ssg.Unigrids);
65+
ssg.Dimvec=zeros(ssg.Ndim,1);
66+
for i=1:ssg.Ndim
67+
ssg.Dimvec(i)=length(ssg.Unigrids{i});
68+
end
69+
% generate point matrix
70+
ssg.Npt=prod(ssg.Dimvec);
71+
indmat=grid.StateSpaceGrid.makeCombinations_rev(ssg.Dimvec);
72+
ssg.Pointmat=zeros(ssg.Npt,ssg.Ndim);
73+
for d=1:ssg.Ndim
74+
gridd=ssg.Unigrids{d};
75+
ssg.Pointmat(:,d)=gridd(indmat(:,d));
76+
end
77+
end
78+
79+
80+
end
81+
82+
83+
84+
end

0 commit comments

Comments
 (0)