Skip to content

Commit 5e95962

Browse files
committed
Actual initial commit
1 parent 52aa89a commit 5e95962

File tree

517 files changed

+155244
-0
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

517 files changed

+155244
-0
lines changed

Fct.h

+72
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
#pragma once
2+
3+
#include <memory>
4+
#include <functional>
5+
#include <math.h>
6+
7+
namespace netn {
8+
9+
template <typename T> class Model;
10+
struct Component;
11+
typedef std::function<double(double)> func_t;
12+
13+
template <typename T>
14+
class Fct : public Model<T> {
15+
public:
16+
Fct(const Model<T> & var, const func_t & func, const func_t & deriv);
17+
Fct(const Fct & other);
18+
virtual ~Fct() = default;
19+
20+
T eval() const override;
21+
T derivPart(const Component & component) const override;
22+
std::shared_ptr<Model<T>> toModel() const override;
23+
private:
24+
func_t _func;
25+
func_t _deriv;
26+
std::shared_ptr<Model<T>> _var;
27+
};
28+
29+
template <typename T>
30+
Fct<T> sin(const Model<T> & model) {
31+
return Fct<T>(model,
32+
[](double x) {return std::sin(x); },
33+
[](double x) {return std::cos(x); });
34+
}
35+
36+
template <typename T>
37+
Fct<T> cos(const Model<T> & model) {
38+
return Fct<T>(model,
39+
[](double x) {return std::cos(x); },
40+
[](double x) {return - std::sin(x); });
41+
}
42+
43+
template <typename T>
44+
Fct<T> exp(const Model<T> & model) {
45+
return Fct<T>(model,
46+
[](double x) {return std::exp(x); },
47+
[](double x) {return std::exp(x); });
48+
}
49+
50+
template <typename T>
51+
Fct<T> log(const Model<T> & model) {
52+
return Fct<T>(model,
53+
[](double x) {return std::log(x); },
54+
[](double x) {return 1 / x; });
55+
}
56+
57+
template <typename T>
58+
Fct<T> pow(const Model<T> & model, double exp) {
59+
return Fct<T>(model,
60+
[exp](double x) {return std::pow(x, exp); },
61+
[exp](double x) {return exp * std::pow(x, exp - 1); });
62+
}
63+
64+
template <typename T>
65+
Fct<T> sigmoid(const Model<T> & model) {
66+
return Fct<T>(model,
67+
[](double x) {return 1 / (1 + std::exp(-x)); },
68+
[](double x) {double ex = std::exp(x); return ex / ((1 + ex) * (1 + ex)); });
69+
}
70+
}
71+
72+
#include "Fct.inl"

Fct.inl

+48
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
#include "armadillo/armadillo"
2+
3+
#include "Model.h"
4+
#include "Var.h"
5+
6+
namespace netn {
7+
8+
template<typename T>
9+
inline Fct<T>::Fct(const Model<T> & var, const func_t & func, const func_t & deriv)
10+
: _var(var.toModel()), _func(func), _deriv(deriv) {}
11+
12+
template<typename T>
13+
inline Fct<T>::Fct(const Fct & other)
14+
: _var(other._var), _func(other._func), _deriv(other._deriv) {}
15+
16+
template<typename T>
17+
inline T Fct<T>::eval() const {
18+
return _func(_var->eval());
19+
}
20+
21+
template<>
22+
inline arma::mat Fct<arma::mat>::eval() const {
23+
arma::mat result(_var->eval());
24+
result.for_each([&](arma::mat::elem_type & elem) {elem = _func(elem); });
25+
return result;
26+
}
27+
28+
template<typename T>
29+
inline T Fct<T>::derivPart(const Component & component) const {
30+
return _var->derivPart(component) * _deriv(_var->eval());
31+
}
32+
33+
template<>
34+
inline arma::mat Fct<arma::mat>::derivPart(const Component & component) const {
35+
arma::mat result(_var->eval());
36+
arma::mat deriv(_var->derivPart(component));
37+
38+
for (int i = 0; i < result.size(); i++) {
39+
result.at(i) = deriv.at(i) * _deriv(result.at(i));
40+
}
41+
return result;
42+
}
43+
44+
template<typename T>
45+
inline std::shared_ptr<Model<T>> Fct<T>::toModel() const {
46+
return std::make_shared<Fct<T>>(*this);
47+
}
48+
}

Maths.sln

+28
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
2+
Microsoft Visual Studio Solution File, Format Version 12.00
3+
# Visual Studio 14
4+
VisualStudioVersion = 14.0.25420.1
5+
MinimumVisualStudioVersion = 10.0.40219.1
6+
Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "Maths", "Maths.vcxproj", "{03ED9FC5-E8FD-4A42-B348-7BCE53EE2F71}"
7+
EndProject
8+
Global
9+
GlobalSection(SolutionConfigurationPlatforms) = preSolution
10+
Debug|x64 = Debug|x64
11+
Debug|x86 = Debug|x86
12+
Release|x64 = Release|x64
13+
Release|x86 = Release|x86
14+
EndGlobalSection
15+
GlobalSection(ProjectConfigurationPlatforms) = postSolution
16+
{03ED9FC5-E8FD-4A42-B348-7BCE53EE2F71}.Debug|x64.ActiveCfg = Debug|x64
17+
{03ED9FC5-E8FD-4A42-B348-7BCE53EE2F71}.Debug|x64.Build.0 = Debug|x64
18+
{03ED9FC5-E8FD-4A42-B348-7BCE53EE2F71}.Debug|x86.ActiveCfg = Debug|Win32
19+
{03ED9FC5-E8FD-4A42-B348-7BCE53EE2F71}.Debug|x86.Build.0 = Debug|Win32
20+
{03ED9FC5-E8FD-4A42-B348-7BCE53EE2F71}.Release|x64.ActiveCfg = Release|x64
21+
{03ED9FC5-E8FD-4A42-B348-7BCE53EE2F71}.Release|x64.Build.0 = Release|x64
22+
{03ED9FC5-E8FD-4A42-B348-7BCE53EE2F71}.Release|x86.ActiveCfg = Release|Win32
23+
{03ED9FC5-E8FD-4A42-B348-7BCE53EE2F71}.Release|x86.Build.0 = Release|Win32
24+
EndGlobalSection
25+
GlobalSection(SolutionProperties) = preSolution
26+
HideSolutionNode = FALSE
27+
EndGlobalSection
28+
EndGlobal

Maths.vcxproj

+134
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
<?xml version="1.0" encoding="utf-8"?>
2+
<Project DefaultTargets="Build" ToolsVersion="14.0" xmlns="http://schemas.microsoft.com/developer/msbuild/2003">
3+
<ItemGroup Label="ProjectConfigurations">
4+
<ProjectConfiguration Include="Debug|Win32">
5+
<Configuration>Debug</Configuration>
6+
<Platform>Win32</Platform>
7+
</ProjectConfiguration>
8+
<ProjectConfiguration Include="Release|Win32">
9+
<Configuration>Release</Configuration>
10+
<Platform>Win32</Platform>
11+
</ProjectConfiguration>
12+
<ProjectConfiguration Include="Debug|x64">
13+
<Configuration>Debug</Configuration>
14+
<Platform>x64</Platform>
15+
</ProjectConfiguration>
16+
<ProjectConfiguration Include="Release|x64">
17+
<Configuration>Release</Configuration>
18+
<Platform>x64</Platform>
19+
</ProjectConfiguration>
20+
</ItemGroup>
21+
<PropertyGroup Label="Globals">
22+
<ProjectGuid>{03ED9FC5-E8FD-4A42-B348-7BCE53EE2F71}</ProjectGuid>
23+
<RootNamespace>Maths</RootNamespace>
24+
<WindowsTargetPlatformVersion>8.1</WindowsTargetPlatformVersion>
25+
</PropertyGroup>
26+
<Import Project="$(VCTargetsPath)\Microsoft.Cpp.Default.props" />
27+
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|Win32'" Label="Configuration">
28+
<ConfigurationType>Application</ConfigurationType>
29+
<UseDebugLibraries>true</UseDebugLibraries>
30+
<PlatformToolset>v140</PlatformToolset>
31+
<CharacterSet>MultiByte</CharacterSet>
32+
</PropertyGroup>
33+
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|Win32'" Label="Configuration">
34+
<ConfigurationType>Application</ConfigurationType>
35+
<UseDebugLibraries>false</UseDebugLibraries>
36+
<PlatformToolset>v140</PlatformToolset>
37+
<WholeProgramOptimization>true</WholeProgramOptimization>
38+
<CharacterSet>MultiByte</CharacterSet>
39+
</PropertyGroup>
40+
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'" Label="Configuration">
41+
<ConfigurationType>Application</ConfigurationType>
42+
<UseDebugLibraries>true</UseDebugLibraries>
43+
<PlatformToolset>v140</PlatformToolset>
44+
<CharacterSet>MultiByte</CharacterSet>
45+
</PropertyGroup>
46+
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'" Label="Configuration">
47+
<ConfigurationType>Application</ConfigurationType>
48+
<UseDebugLibraries>false</UseDebugLibraries>
49+
<PlatformToolset>v140</PlatformToolset>
50+
<WholeProgramOptimization>true</WholeProgramOptimization>
51+
<CharacterSet>MultiByte</CharacterSet>
52+
</PropertyGroup>
53+
<Import Project="$(VCTargetsPath)\Microsoft.Cpp.props" />
54+
<ImportGroup Label="ExtensionSettings">
55+
</ImportGroup>
56+
<ImportGroup Label="Shared">
57+
</ImportGroup>
58+
<ImportGroup Label="PropertySheets" Condition="'$(Configuration)|$(Platform)'=='Debug|Win32'">
59+
<Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" />
60+
</ImportGroup>
61+
<ImportGroup Label="PropertySheets" Condition="'$(Configuration)|$(Platform)'=='Release|Win32'">
62+
<Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" />
63+
</ImportGroup>
64+
<ImportGroup Label="PropertySheets" Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">
65+
<Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" />
66+
</ImportGroup>
67+
<ImportGroup Label="PropertySheets" Condition="'$(Configuration)|$(Platform)'=='Release|x64'">
68+
<Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" />
69+
</ImportGroup>
70+
<PropertyGroup Label="UserMacros" />
71+
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|Win32'">
72+
<IncludePath>$(VC_IncludePath);$(WindowsSDK_IncludePath);$(VC_SourcePath)/armadillo/</IncludePath>
73+
</PropertyGroup>
74+
<ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Debug|Win32'">
75+
<ClCompile>
76+
<WarningLevel>Level3</WarningLevel>
77+
<Optimization>Disabled</Optimization>
78+
<SDLCheck>true</SDLCheck>
79+
</ClCompile>
80+
</ItemDefinitionGroup>
81+
<ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">
82+
<ClCompile>
83+
<WarningLevel>Level3</WarningLevel>
84+
<Optimization>Disabled</Optimization>
85+
<SDLCheck>true</SDLCheck>
86+
</ClCompile>
87+
</ItemDefinitionGroup>
88+
<ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Release|Win32'">
89+
<ClCompile>
90+
<WarningLevel>Level3</WarningLevel>
91+
<Optimization>MaxSpeed</Optimization>
92+
<FunctionLevelLinking>true</FunctionLevelLinking>
93+
<IntrinsicFunctions>true</IntrinsicFunctions>
94+
<SDLCheck>true</SDLCheck>
95+
</ClCompile>
96+
<Link>
97+
<EnableCOMDATFolding>true</EnableCOMDATFolding>
98+
<OptimizeReferences>true</OptimizeReferences>
99+
</Link>
100+
</ItemDefinitionGroup>
101+
<ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'">
102+
<ClCompile>
103+
<WarningLevel>Level3</WarningLevel>
104+
<Optimization>MaxSpeed</Optimization>
105+
<FunctionLevelLinking>true</FunctionLevelLinking>
106+
<IntrinsicFunctions>true</IntrinsicFunctions>
107+
<SDLCheck>true</SDLCheck>
108+
</ClCompile>
109+
<Link>
110+
<EnableCOMDATFolding>true</EnableCOMDATFolding>
111+
<OptimizeReferences>true</OptimizeReferences>
112+
</Link>
113+
</ItemDefinitionGroup>
114+
<ItemGroup>
115+
<ClCompile Include="main.cpp" />
116+
</ItemGroup>
117+
<ItemGroup>
118+
<ClInclude Include="Fct.h" />
119+
<ClInclude Include="MatrixOps.h" />
120+
<ClInclude Include="Model.h" />
121+
<ClInclude Include="Ops.h" />
122+
<ClInclude Include="Type_Ops.h" />
123+
<ClInclude Include="Var.h" />
124+
</ItemGroup>
125+
<ItemGroup>
126+
<None Include="Fct.inl" />
127+
<None Include="MatrixOps.inl" />
128+
<None Include="Model.inl" />
129+
<None Include="Ops.inl" />
130+
</ItemGroup>
131+
<Import Project="$(VCTargetsPath)\Microsoft.Cpp.targets" />
132+
<ImportGroup Label="ExtensionTargets">
133+
</ImportGroup>
134+
</Project>

MatrixOps.h

+27
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
#pragma once
2+
3+
#include <memory>
4+
#include "armadillo/armadillo"
5+
6+
#include "Model.h"
7+
8+
namespace netn {
9+
class MatSum : public Model<double> {
10+
public:
11+
MatSum(const MatSum & other);
12+
MatSum(const Model<arma::mat> & model);
13+
14+
double eval() const override;
15+
double derivPart(const Component & component) const override;
16+
17+
std::shared_ptr<Model<double>> toModel() const override;
18+
private:
19+
std::shared_ptr<Model<arma::mat>> _matrix;
20+
};
21+
22+
MatSum sum(const Model<arma::mat> & model) {
23+
return MatSum(model);
24+
}
25+
}
26+
27+
#include "MatrixOps.inl"

MatrixOps.inl

+25
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
#pragma once
2+
3+
#include "MatrixOps.h"
4+
5+
namespace netn {
6+
MatSum::MatSum(const MatSum & other)
7+
: _matrix(other._matrix) {}
8+
9+
inline MatSum::MatSum(const Model<arma::mat> & model)
10+
: _matrix(model.toModel()) {}
11+
12+
inline double netn::MatSum::eval() const {
13+
auto matrix = _matrix->eval();
14+
return arma::accu(matrix);
15+
}
16+
17+
inline double netn::MatSum::derivPart(const Component & component) const {
18+
auto deriv = _matrix->derivPart(component);
19+
return arma::accu(deriv);
20+
}
21+
22+
inline std::shared_ptr<Model<double>> MatSum::toModel() const {
23+
return std::make_shared<MatSum>(*this);
24+
}
25+
}

Model.h

+36
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
#pragma once
2+
3+
#include <memory>
4+
#include <tuple>
5+
6+
namespace netn {
7+
8+
struct Component;
9+
10+
class IModel {
11+
public:
12+
virtual ~IModel() = default;
13+
};
14+
15+
template <typename T>
16+
class Model : public IModel {
17+
public:
18+
typedef T value_t;
19+
20+
virtual ~Model() = default;
21+
22+
virtual value_t eval() const = 0;
23+
virtual value_t derivPart(const Component & component) const = 0;
24+
virtual std::shared_ptr<Model<T>> toModel() const = 0;
25+
26+
// Calcul de gradients
27+
28+
template <typename Var_T>
29+
Var_T computeGradient(const Var<Var_T> & var);
30+
};
31+
32+
template <typename T, typename... Vars>
33+
std::tuple<Vars...> computeGradients(const Model<T> & model, Vars... vars);
34+
}
35+
36+
#include "Model.inl"

0 commit comments

Comments
 (0)