ANN module #1

Open · Pperezhogin wants to merge 1 commit into base: dev/m2lines

Conversation

Pperezhogin

This PR adds the capability to run inference with an ANN that uses ReLU activation functions.

Usage:

use MOM_ANN,           only : ANN_init, ANN_apply, ANN_end, ANN_CS

type(ANN_CS) :: ann_instance !< ANN instance
call ANN_init(ann_instance, 'path/to/ann.nc') ! Read the ANN parameters from a netCDF file
call ANN_apply(x, y, ann_instance)            ! Evaluate the ANN on input x, writing the prediction to y
call ANN_end(ann_instance)                    ! Deallocate the ANN
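
For reference, here is a minimal NumPy sketch (not part of the PR) of the computation ANN_apply is expected to perform: a chain of affine layers with a ReLU between consecutive layers but not after the last one. The treatment of input_norms and output_norms is an assumption inferred from the test pair written by the export function further below, and ann_apply_sketch is a hypothetical name used only for illustration.

import numpy as np

def ann_apply_sketch(x, weights, biases, input_norms, output_norms):
    # weights[i] is the matrix A{i} of shape (n_in, n_out) as stored in the
    # netCDF file (the transpose of the PyTorch weight); biases[i] is b{i}.
    h = np.asarray(x) / input_norms
    for i, (A, b) in enumerate(zip(weights, biases)):
        h = h @ A + b
        if i < len(weights) - 1:
            h = np.maximum(h, 0.0)  # ReLU on all but the last layer
    return h * output_norms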

This module is compatible only with ANNs of this specific form. A PyTorch implementation of such an ANN, together with a function that exports it to netCDF, is shown below:

import os

import torch
import torch.nn as nn
from torch.nn import functional as functional
import xarray as xr

class ANN(nn.Module):
    def __init__(self, layer_sizes=[3, 17, 27, 5]):
        super().__init__()
        
        self.layer_sizes = layer_sizes

        layers = []
        for i in range(len(layer_sizes)-1):
            layers.append(nn.Linear(layer_sizes[i], layer_sizes[i+1]))
        
        self.layers = nn.Sequential(*layers)
    
    def forward(self, x):
        for i in range(len(self.layers)):
            x = self.layers[i](x)
            if i < len(self.layers)-1:
                x = functional.relu(x)
        return x

def export_ANN(ann, input_norms, output_norms, filename='ANN_test.nc'):
    ds = xr.Dataset()
    ds['num_layers'] = xr.DataArray(len(ann.layer_sizes)).expand_dims('dummy_dimension')
    ds['layer_sizes'] = xr.DataArray(ann.layer_sizes, dims=['nlayers'])
    ds = ds.astype('int32') # MOM6 reads only int32 numbers
    
    for i in range(len(ann.layers)):
        # Naming convention for weights and dimensions
        matrix = f'A{i}'
        bias = f'b{i}'
        ncol = f'ncol{i}'
        nrow = f'nrow{i}'
        layer = ann.layers[i]
        
        # Transposed, because torch is row-major, while Fortran is column-major
        ds[matrix] = xr.DataArray(layer.weight.data.T, dims=[ncol, nrow])
        ds[bias] = xr.DataArray(layer.bias.data, dims=[nrow])
    
    # Save true answer for random vector for testing
    x0 = torch.randn(ann.layer_sizes[0])
    y0 = ann(x0 / input_norms) * output_norms
    nrow = f'nrow{len(ann.layers)-1}'
    
    ds['x_test'] = xr.DataArray(x0.data, dims=['ncol0'])
    ds['y_test'] = xr.DataArray(y0.data, dims=[nrow])
    
    ds['input_norms']  = xr.DataArray(input_norms.data, dims=['ncol0'])
    ds['output_norms'] = xr.DataArray(output_norms.data, dims=[nrow])

    
    # print('x_test = ', ds['x_test'].data)
    # print('y_test = ', ds['y_test'].data)
    
    if os.path.exists(filename):
        # Pause before overwriting an existing file: press Enter to continue, Ctrl-C to abort
        print(f'Rewrite {filename} ?')
        input()
        os.remove(filename)
        print(f'{filename} is rewritten')
    
    ds.to_netcdf(filename)

Usage for mapping 3 input features to 1 output feature:

ann = ANN([3,64,64,1])
input_norms=torch.ones(3)
output_norms=torch.ones(1)
export_ANN(ann, input_norms, output_norms, 'ANN_file.nc')
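
As a sanity check (not part of the PR), the exported file can be opened with xarray and the stored test pair reproduced with plain NumPy; the variable names follow the convention used by export_ANN above, and the tolerance is an arbitrary choice for single precision.

import numpy as np
import xarray as xr

ds = xr.open_dataset('ANN_file.nc')
n_linear = int(ds['num_layers'][0]) - 1        # number of Linear layers

h = ds['x_test'].values / ds['input_norms'].values
for i in range(n_linear):
    h = h @ ds[f'A{i}'].values + ds[f'b{i}'].values  # A{i} is stored as (n_in, n_out)
    if i < n_linear - 1:
        h = np.maximum(h, 0.0)                       # ReLU on all but the last layer
y = h * ds['output_norms'].values

print(np.allclose(y, ds['y_test'].values, atol=1e-5))  # should print True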
