Skip to content

Commit f12e712

Browse files
committed
initial port + horribly inefficient upscale demo
1 parent 45b6440 commit f12e712

13 files changed

+334
-83
lines changed

.gitignore

+2
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
/test/LAUNCH.cs
2+
13
## Ignore Visual Studio temporary files, build results, and
24
## files generated by popular Visual Studio add-ons.
35
##

README.md

+3-33
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,4 @@
1-
An enhanced template for a C# class library project
1+
C# implementation of the [SIREN neural network](https://vsitzmann.github.io/siren/)
2+
(sinusoid activations).
23

3-
## Features
4-
5-
- class lib + test projects
6-
- solution file referencing common stuff in repository: gitignore, license, readme, CI
7-
- Azure DevOps-based build + test pipeline
8-
- tests: DevOps results and coverage integration
9-
- builds NuGet package with
10-
[Source Link](https://docs.microsoft.com/en-us/dotnet/standard/library-guidance/sourcelink) support
11-
- builds symbols package
12-
13-
## Setting up
14-
15-
- clone this repository into the folder you want your new lib to be
16-
- enter working copy directory
17-
- `git branch --unset-upstream` to detach `master` branch from template; now it belongs to **your** project
18-
- `git remote rename origin template` to preserve the ability to pull template updates
19-
- create a new repository on GitHub/GitLab/etc
20-
- `git remote add origin https://full.url/to/your_new.git`
21-
- `git push --set-upstream origin master` to upload your new project
22-
- go to Azure DevOps and create a new pipeline for your new project, and point it to `CI/Azure-Master.yml`
23-
- remove the sample classes and tests, rename the projects (if needed), and start hacking!
24-
25-
## Updating your project to the latest version of the template
26-
27-
- add the template repository to remotes: `git remote add template https://full.url/to/this_project.git`
28-
- `git pull template master`
29-
30-
## Planned/missing features
31-
32-
- run tests on all platforms
33-
- publish preview versions of NuGet package on successful build
34-
- README status badges
4+
Full image learning demo in app folder.

ClassLib.sln Siren.sln

+16-3
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
21
Microsoft Visual Studio Solution File, Format Version 12.00
32
# Visual Studio Version 16
43
VisualStudioVersion = 16.0.30309.148
@@ -15,9 +14,11 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "CI", "CI", "{8C3BD3D9-5F7E-
1514
CI\Azure-Master.yml = CI\Azure-Master.yml
1615
EndProjectSection
1716
EndProject
18-
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "ClassLib", "src\ClassLib.csproj", "{04BE1707-4235-44E6-AB58-48621D5160D3}"
17+
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Siren", "src\Siren.csproj", "{04BE1707-4235-44E6-AB58-48621D5160D3}"
1918
EndProject
20-
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "ClassLib.Tests", "test\ClassLib.Tests.csproj", "{241AC79D-0399-4B2A-9C01-07C609F74B69}"
19+
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Siren.Tests", "test\Siren.Tests.csproj", "{241AC79D-0399-4B2A-9C01-07C609F74B69}"
20+
EndProject
21+
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "app", "app\app.csproj", "{8E80FDD8-29AE-4F86-9C00-939DAB2A233A}"
2122
EndProject
2223
Global
2324
GlobalSection(SolutionConfigurationPlatforms) = preSolution
@@ -53,6 +54,18 @@ Global
5354
{241AC79D-0399-4B2A-9C01-07C609F74B69}.Release|x64.Build.0 = Release|Any CPU
5455
{241AC79D-0399-4B2A-9C01-07C609F74B69}.Release|x86.ActiveCfg = Release|Any CPU
5556
{241AC79D-0399-4B2A-9C01-07C609F74B69}.Release|x86.Build.0 = Release|Any CPU
57+
{8E80FDD8-29AE-4F86-9C00-939DAB2A233A}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
58+
{8E80FDD8-29AE-4F86-9C00-939DAB2A233A}.Debug|Any CPU.Build.0 = Debug|Any CPU
59+
{8E80FDD8-29AE-4F86-9C00-939DAB2A233A}.Debug|x64.ActiveCfg = Debug|Any CPU
60+
{8E80FDD8-29AE-4F86-9C00-939DAB2A233A}.Debug|x64.Build.0 = Debug|Any CPU
61+
{8E80FDD8-29AE-4F86-9C00-939DAB2A233A}.Debug|x86.ActiveCfg = Debug|Any CPU
62+
{8E80FDD8-29AE-4F86-9C00-939DAB2A233A}.Debug|x86.Build.0 = Debug|Any CPU
63+
{8E80FDD8-29AE-4F86-9C00-939DAB2A233A}.Release|Any CPU.ActiveCfg = Release|Any CPU
64+
{8E80FDD8-29AE-4F86-9C00-939DAB2A233A}.Release|Any CPU.Build.0 = Release|Any CPU
65+
{8E80FDD8-29AE-4F86-9C00-939DAB2A233A}.Release|x64.ActiveCfg = Release|Any CPU
66+
{8E80FDD8-29AE-4F86-9C00-939DAB2A233A}.Release|x64.Build.0 = Release|Any CPU
67+
{8E80FDD8-29AE-4F86-9C00-939DAB2A233A}.Release|x86.ActiveCfg = Release|Any CPU
68+
{8E80FDD8-29AE-4F86-9C00-939DAB2A233A}.Release|x86.Build.0 = Release|Any CPU
5669
EndGlobalSection
5770
GlobalSection(SolutionProperties) = preSolution
5871
HideSolutionNode = FALSE

app/UpscaleProgram.cs

+164
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,164 @@
1+
namespace tensorflow.keras {
2+
using System;
3+
using System.Collections.Generic;
4+
using System.Diagnostics;
5+
using System.Drawing;
6+
using System.Drawing.Imaging;
7+
using System.Linq;
8+
9+
using LostTech.Gradient;
10+
using LostTech.TensorFlow;
11+
12+
using numpy;
13+
14+
using tensorflow.keras.callbacks;
15+
using tensorflow.keras.layers;
16+
using tensorflow.keras.losses;
17+
using tensorflow.keras.optimizers;
18+
19+
class UpscaleProgram {
20+
static void Main(string[] args) {
21+
GradientEngine.UseEnvironmentFromVariable();
22+
TensorFlowSetup.Instance.EnsureInitialized();
23+
24+
// this allows SIREN to oversaturate channels without adding to the loss
25+
var clampToValidChannelRange = PythonFunctionContainer.Of<Tensor, Tensor>(ClampToValidChannelValueRange);
26+
var siren = new Sequential(new object[] {
27+
new GaussianNoise(stddev: 1f/(128*1024)),
28+
new Siren(2, Enumerable.Repeat(256, 5).ToArray()),
29+
new Dense(units: 4, activation: clampToValidChannelRange),
30+
new GaussianNoise(stddev: 1f/128),
31+
});
32+
33+
siren.compile(
34+
// too slow to converge
35+
//optimizer: new SGD(momentum: 0.5),
36+
// lowered learning rate to avoid destabilization
37+
optimizer: new Adam(learning_rate: 0.00032),
38+
loss: new MeanSquaredError());
39+
40+
foreach (string imagePath in args) {
41+
using var original = new Bitmap(imagePath);
42+
byte[,,] image = ToBytesHWC(original);
43+
int height = image.GetLength(0);
44+
int width = image.GetLength(1);
45+
int channels = image.GetLength(2);
46+
Debug.Assert(channels == 4);
47+
48+
var imageSamples = PrepareImage(image);
49+
50+
var coords = SirenTests.Coord(height, width).ToNumPyArray()
51+
.reshape(new[] { width * height, 2 });
52+
53+
var upscaleCoords = SirenTests.Coord(height * 2, width * 2).ToNumPyArray();
54+
55+
var improved = new ImprovedCallback();
56+
improved.OnLossImproved += (sender, eventArgs) => {
57+
if (eventArgs.Epoch < 10) return;
58+
ndarray<float> upscaled = siren.predict(
59+
upscaleCoords.reshape(new[] { height * width * 4, 2 }),
60+
batch_size: 1024);
61+
upscaled = (ndarray<float>)upscaled.reshape(new[] { height * 2, width * 2, channels });
62+
using var bitmap = ToImage(RestoreImage(upscaled));
63+
bitmap.Save("sample4X.png", ImageFormat.Png);
64+
65+
siren.save_weights("sample.weights");
66+
67+
Console.WriteLine();
68+
Console.WriteLine("saved!");
69+
};
70+
71+
siren.fit(coords, imageSamples, epochs: 100, batchSize: 64, stepsPerEpoch: 200,
72+
shuffleMode: TrainingShuffleMode.Batch,
73+
callbacks: new ICallback[] { improved });
74+
}
75+
}
76+
77+
class ImprovedCallback : Callback {
78+
double bestLoss = double.PositiveInfinity;
79+
public override void on_epoch_end(int epoch, IDictionary<string, dynamic> logs) {
80+
base.on_epoch_end(epoch, logs);
81+
if (logs["loss"] < this.bestLoss) {
82+
this.bestLoss = logs["loss"];
83+
this.OnLossImproved?.Invoke(this, new EpochEndEventArgs {
84+
Epoch = epoch,
85+
Logs = logs,
86+
});
87+
}
88+
}
89+
90+
public event EventHandler<EpochEndEventArgs> OnLossImproved;
91+
}
92+
93+
class EpochEndEventArgs : EventArgs {
94+
public int Epoch { get; set; }
95+
public IDictionary<string, dynamic> Logs { get; set; }
96+
}
97+
98+
static ndarray<float> PrepareImage(byte[,,] image) {
99+
int height = image.GetLength(0);
100+
int width = image.GetLength(1);
101+
int channels = image.GetLength(2);
102+
103+
var normalized = SirenTests.NormalizeChannelValue(image.ToNumPyArray());
104+
var flattened = normalized.reshape(new[] { height * width, channels }).astype(np.float32_fn);
105+
return (ndarray<float>)flattened;
106+
}
107+
108+
static Tensor ClampToValidChannelValueRange(Tensor input)
109+
=> tf.clip_by_value(input,
110+
clip_value_min: SirenTests.NormalizeChannelValue(-0.01f),
111+
clip_value_max: SirenTests.NormalizeChannelValue(255.01f));
112+
113+
static unsafe byte[,,] RestoreImage(ndarray<float> learnedImage) {
114+
(int height, int width, int channels) = (ValueTuple<int, int, int>)learnedImage.shape;
115+
var bytes = (learnedImage * 128f + 128f).clip(0, 255).astype(np.uint8_fn).tobytes();
116+
Debug.Assert(bytes.Length == height * width * channels);
117+
byte[,,] result = new byte[height, width, channels];
118+
fixed (byte* dest = result)
119+
fixed (byte* source = bytes)
120+
Buffer.MemoryCopy(source: source, destination: dest, bytes.Length, bytes.Length);
121+
return result;
122+
}
123+
124+
static unsafe Bitmap ToImage(byte[,,] bytesHWC) {
125+
if (bytesHWC.GetLength(2) != 4)
126+
throw new NotSupportedException();
127+
var bitmap = new Bitmap(bytesHWC.GetLength(1), bytesHWC.GetLength(0));
128+
int rowBytes = bitmap.Width * 4;
129+
var rect = new Rectangle(default, new Size(bitmap.Width, bitmap.Height));
130+
var data = bitmap.LockBits(rect, ImageLockMode.WriteOnly, PixelFormat.Format32bppArgb);
131+
try {
132+
fixed (byte* source = bytesHWC) {
133+
for (int y = 0; y < bitmap.Height; y++) {
134+
var dest = data.Scan0 + data.Stride * y;
135+
Buffer.MemoryCopy(&source[rowBytes * y], destination: (byte*)dest, rowBytes, rowBytes);
136+
}
137+
}
138+
} finally {
139+
bitmap.UnlockBits(data);
140+
}
141+
142+
return bitmap;
143+
}
144+
145+
static unsafe byte[,,] ToBytesHWC(Bitmap bitmap) {
146+
byte[,,] result = new byte[bitmap.Height, bitmap.Width, 4];
147+
int rowBytes = bitmap.Width * 4;
148+
var rect = new Rectangle(default, new Size(bitmap.Width, bitmap.Height));
149+
var data = bitmap.LockBits(rect, ImageLockMode.ReadOnly, PixelFormat.Format32bppArgb);
150+
try {
151+
fixed (byte* dest = result) {
152+
for (int y = 0; y < bitmap.Height; y++) {
153+
var source = data.Scan0 + data.Stride * y;
154+
Buffer.MemoryCopy((byte*)source, destination: &dest[rowBytes * y], rowBytes, rowBytes);
155+
}
156+
}
157+
} finally {
158+
bitmap.UnlockBits(data);
159+
}
160+
161+
return result;
162+
}
163+
}
164+
}

app/app.csproj

+26
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
<Project Sdk="Microsoft.NET.Sdk">
2+
3+
<PropertyGroup>
4+
<OutputType>Exe</OutputType>
5+
<TargetFramework>netcoreapp3.1</TargetFramework>
6+
<AssemblyName>siren</AssemblyName>
7+
<RootNamespace>tensorflow.keras</RootNamespace>
8+
</PropertyGroup>
9+
10+
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|AnyCPU'">
11+
<AllowUnsafeBlocks>true</AllowUnsafeBlocks>
12+
</PropertyGroup>
13+
14+
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|AnyCPU'">
15+
<AllowUnsafeBlocks>true</AllowUnsafeBlocks>
16+
</PropertyGroup>
17+
18+
<ItemGroup>
19+
<PackageReference Include="System.Drawing.Common" Version="4.7.0" />
20+
</ItemGroup>
21+
22+
<ItemGroup>
23+
<ProjectReference Include="..\test\Siren.Tests.csproj" />
24+
</ItemGroup>
25+
26+
</Project>

src/Adder.cs

-7
This file was deleted.

src/Siren.cs

+67
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
namespace tensorflow.keras {
2+
using System;
3+
using System.Collections.Generic;
4+
using System.Linq;
5+
6+
using LostTech.Gradient.ManualWrappers;
7+
8+
using tensorflow.keras.layers;
9+
10+
public class Siren : Model {
11+
readonly Dense[] innerLayers;
12+
public float FrequencyScale { get; }
13+
14+
public Siren(int inputSize, int[] innerSizes, float frequencyScale = 30.0f) {
15+
if (inputSize <= 0)
16+
throw new ArgumentOutOfRangeException(nameof(inputSize));
17+
if (innerSizes is null || innerSizes.Length == 0)
18+
throw new ArgumentNullException(nameof(innerSizes));
19+
if (innerSizes.Any(size => size < 0))
20+
throw new ArgumentOutOfRangeException(nameof(innerSizes));
21+
if (float.IsInfinity(frequencyScale) || float.IsNaN(frequencyScale)
22+
|| Math.Abs(frequencyScale) <= 4*float.Epsilon)
23+
throw new ArgumentOutOfRangeException(nameof(frequencyScale));
24+
25+
this.FrequencyScale = frequencyScale;
26+
27+
this.innerLayers = new Dense[innerSizes.Length];
28+
29+
int currentInputSize = inputSize;
30+
for (int innerIndex = 0; innerIndex < innerSizes.Length; innerIndex++) {
31+
double weightLimits = innerIndex > 0
32+
? Math.Sqrt(6.0f / currentInputSize) / this.FrequencyScale
33+
: 1.0f / inputSize;
34+
this.innerLayers[innerIndex] = new Dense(innerSizes[innerIndex],
35+
kernel_initializer: new initializers.uniform(minval: -weightLimits, maxval: +weightLimits)
36+
);
37+
this.Track(this.innerLayers[innerIndex]);
38+
39+
currentInputSize = innerSizes[innerIndex];
40+
}
41+
}
42+
43+
Tensor CallImpl(IGraphNodeBase input, object? mask) {
44+
if (mask != null)
45+
throw new NotImplementedException("mask");
46+
var result = (Tensor)input;
47+
foreach (var innerLayer in this.innerLayers)
48+
result = tf.sin(innerLayer.__call__(result) * this.FrequencyScale);
49+
return result;
50+
}
51+
52+
public override Tensor call(IGraphNodeBase inputs, IGraphNodeBase training, IGraphNodeBase mask)
53+
=> this.CallImpl(inputs, mask);
54+
55+
public override Tensor call(IGraphNodeBase inputs, bool training, IGraphNodeBase? mask = null)
56+
=> this.CallImpl(inputs, mask);
57+
58+
public override Tensor call(IGraphNodeBase inputs, IGraphNodeBase? training = null, IEnumerable<IGraphNodeBase>? mask = null)
59+
=> this.CallImpl(inputs, mask);
60+
61+
public override TensorShape compute_output_shape(TensorShape input_shape) {
62+
var outputShape = input_shape.as_list();
63+
outputShape[^1] = this.innerLayers[^1].units;
64+
return new TensorShape(outputShape);
65+
}
66+
}
67+
}

src/ClassLib.csproj src/Siren.csproj

+6-1
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
<Project Sdk="Microsoft.NET.Sdk">
22

33
<PropertyGroup>
4+
<AssemblyName>LostTech.TensorFlow.Siren</AssemblyName>
5+
<RootNamespace>tensorflow.keras</RootNamespace>
46
<TargetFramework>netstandard2.0</TargetFramework>
57
<LangVersion>8.0</LangVersion>
8+
<VersionPrefix>0.0.1</VersionPrefix>
69
<Nullable>enable</Nullable>
710

811
<!-- Package stuff -->
@@ -17,10 +20,12 @@
1720
</PropertyGroup>
1821

1922
<ItemGroup>
20-
<None Include="..\LICENSE" Pack="true" PackagePath=""/>
23+
<None Include="..\LICENSE" Pack="true" PackagePath="" />
2124
</ItemGroup>
2225

2326
<ItemGroup>
27+
<PackageReference Include="LostTech.Python.Runtime" Version="3.0.2-b1224" />
28+
<PackageReference Include="LostTech.TensorFlow" Version="1.15.0-RC1" />
2429
<!-- The following is recommended for public projects -->
2530
<PackageReference Include="Microsoft.SourceLink.GitHub" Version="1.0.0" PrivateAssets="All" />
2631
</ItemGroup>

src/Subtracter.cs

-7
This file was deleted.

test/AdderTests.cs

-14
This file was deleted.

0 commit comments

Comments
 (0)