diff --git a/hls4ml/backends/__init__.py b/hls4ml/backends/__init__.py index 4a48f072cd..2e95491485 100644 --- a/hls4ml/backends/__init__.py +++ b/hls4ml/backends/__init__.py @@ -1,6 +1,7 @@ from hls4ml.backends.backend import Backend, get_available_backends, get_backend, register_backend # noqa: F401 from hls4ml.backends.fpga.fpga_backend import FPGABackend # noqa: F401 from hls4ml.backends.oneapi.oneapi_backend import OneAPIBackend +from hls4ml.backends.oneapi_accelerator.oneapi_accelerator_backend import OneAPIAcceleratorBackend from hls4ml.backends.quartus.quartus_backend import QuartusBackend from hls4ml.backends.symbolic.symbolic_backend import SymbolicExpressionBackend from hls4ml.backends.vivado.vivado_backend import VivadoBackend @@ -18,3 +19,4 @@ register_backend('Catapult', CatapultBackend) register_backend('SymbolicExpression', SymbolicExpressionBackend) register_backend('oneAPI', OneAPIBackend) +register_backend('oneAPIAccelerator', OneAPIAcceleratorBackend) # Can only be registered after oneAPI diff --git a/hls4ml/backends/oneapi/oneapi_backend.py b/hls4ml/backends/oneapi/oneapi_backend.py index c727238db6..81e701c5f5 100644 --- a/hls4ml/backends/oneapi/oneapi_backend.py +++ b/hls4ml/backends/oneapi/oneapi_backend.py @@ -17,8 +17,8 @@ class OneAPIBackend(FPGABackend): - def __init__(self): - super().__init__('oneAPI') + def __init__(self, name='oneAPI'): # the default name should be used in most cases + super().__init__(name) self._register_layer_attributes() self._register_flows() diff --git a/hls4ml/backends/oneapi/oneapi_template.py b/hls4ml/backends/oneapi/oneapi_template.py index c86b8f7ea3..48668688e2 100644 --- a/hls4ml/backends/oneapi/oneapi_template.py +++ b/hls4ml/backends/oneapi/oneapi_template.py @@ -52,8 +52,8 @@ def _default_function_params(self, layer): params = self._default_params(layer) params['name'] = layer.name params['config'] = f'config{layer.index}' - params['input_pipe'] = layer.get_input_variable().pipe_name - params['output_pipe'] = layer.get_output_variable().pipe_name + params['input_pipe'] = layer.get_input_variable(layer.inputs[0]).pipe_name + params['output_pipe'] = layer.get_output_variable(layer.outputs[0]).pipe_name return params diff --git a/hls4ml/backends/oneapi_accelerator/__init__.py b/hls4ml/backends/oneapi_accelerator/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/hls4ml/backends/oneapi_accelerator/oneapi_accelerator_backend.py b/hls4ml/backends/oneapi_accelerator/oneapi_accelerator_backend.py new file mode 100644 index 0000000000..0c5a5f51c4 --- /dev/null +++ b/hls4ml/backends/oneapi_accelerator/oneapi_accelerator_backend.py @@ -0,0 +1,67 @@ +from hls4ml.backends import OneAPIBackend +from hls4ml.model.flow import register_flow + + +class OneAPIAcceleratorBackend(OneAPIBackend): + """ + This is the backend to run oneAPI code on an accelerator using the oneAPI framework. + """ + + def __init__(self): + super().__init__(name='oneAPIAccelerator') + + def _register_flows(self): + writer_passes = ['make_stamp', 'oneapiaccelerator:write_hls'] + self._writer_flow = register_flow('write', writer_passes, requires=['oneapi:ip'], backend=self.name) + + oneapi_types = [ + 'oneapiaccelerator:transform_types', + 'oneapi:register_bram_weights', + 'oneapi:apply_resource_strategy', + 'oneapi:apply_winograd_kernel_transformation', + ] + oneapi_types_flow = register_flow('specific_types', oneapi_types, requires=['oneapi:init_layers'], backend=self.name) + + streaming_passes = [ + 'oneapi:clone_output', + 'oneapiaccelerator:extract_sideband', + 'oneapiaccelerator:merge_sideband', + ] + streaming_flow = register_flow('streaming', streaming_passes, requires=['oneapi:init_layers'], backend=self.name) + + template_flow = register_flow( + 'apply_templates', self._get_layer_templates, requires=['oneapi:init_layers'], backend=self.name + ) + + accel_flow_requirements = [ + 'optimize', + 'oneapi:init_layers', + streaming_flow, + 'oneapi:quantization', + 'oneapi:optimize', + oneapi_types_flow, + template_flow, + ] + + accel_flow_requirements = list(filter(None, accel_flow_requirements)) + self._default_flow = register_flow('accel', None, requires=accel_flow_requirements, backend=self.name) + + def create_initial_config( + self, part, clock_period=5, hyperopt_handshake=False, io_type='io_parallel', write_tar=False, **_ + ): + """Create initial configuration of the oneAPI backend. + + Args: + part (str): The path to the board support package to be used. Can add : + clock_period (int, optional): The clock period in ns. Defaults to 5. + hyperopt_handshake (bool, optional): Should hyper-optimized handshaking be used? Defaults to False + io_type (str, optional): Type of implementation used. One of + 'io_parallel' or 'io_stream'. Defaults to 'io_parallel'. + write_tar (bool, optional): If True, compresses the output directory into a .tar.gz file. Defaults to False. + + Returns: + dict: initial configuration. + """ + config = super().create_initial_config(part, clock_period, hyperopt_handshake, io_type, write_tar, **_) + config['UseOneAPIBSP'] = True + return config diff --git a/hls4ml/backends/oneapi_accelerator/oneapi_accelerator_layers.py b/hls4ml/backends/oneapi_accelerator/oneapi_accelerator_layers.py new file mode 100644 index 0000000000..710ce94927 --- /dev/null +++ b/hls4ml/backends/oneapi_accelerator/oneapi_accelerator_layers.py @@ -0,0 +1,45 @@ +from hls4ml.model.attributes import Attribute, TypeAttribute +from hls4ml.model.layers import Layer, register_layer +from hls4ml.model.types import IntegerPrecisionType + +SIDEBAND_SHAPE = 2 + + +class SidebandExtraction(Layer): + """This layer extract the sideband and sends it to a different strem""" + + _expected_attributes = [Attribute('n_in'), TypeAttribute('sideband_t', description='The type of the sidbands')] + + def initialize(self): + inp = self.get_input_variable() + self.set_attr('n_in', inp.size()) + + # I think the order of these must be as stated because they each set the result_t type. + # We want the second one to be the actual result_t. + self.add_output_variable( + SIDEBAND_SHAPE, + out_name='sideband', + var_name='sideband_out', + type_name='sideband_t', + precision=IntegerPrecisionType(1, False), + ) + self.set_attr('sideband_t', self.get_attr('sideband').type) # need to manually set this, unlike result_t + self.add_output_variable(inp.shape, precision=inp.type.precision) + + +class SidebandMerging(Layer): + """This layer gets the sideband from a different input and merges it""" + + _expected_attributes = [ + Attribute('n_in'), + ] + + def initialize(self): + inp = self.get_input_variable() + self.set_attr('n_in', inp.size()) + self.add_output_variable(inp.shape, precision=inp.type.precision) + + +# register the layers +register_layer('SidebandExtraction', SidebandExtraction) +register_layer('SidebandMerging', SidebandMerging) diff --git a/hls4ml/backends/oneapi_accelerator/oneapi_accelerator_types.py b/hls4ml/backends/oneapi_accelerator/oneapi_accelerator_types.py new file mode 100644 index 0000000000..27ce510f0c --- /dev/null +++ b/hls4ml/backends/oneapi_accelerator/oneapi_accelerator_types.py @@ -0,0 +1,35 @@ +from hls4ml.backends.fpga.fpga_types import VariableDefinition +from hls4ml.backends.oneapi.oneapi_types import AggregratedArrayVariableConverter + + +# region InterfaceMemberVariable +class OneAPIAcceleratorInterfaceVariableDefinition(VariableDefinition): + def definition_cpp(self, name_suffix='', as_reference=False): + if self.pragma and not isinstance(self.pragma, tuple): + return f'[[{self.pragma}]] {self.type.name} {self.name}{name_suffix}' + else: + return f'{self.type.name} {self.name}{name_suffix}' + + # Updated pipe min size to be 32 for simulation. + def declare_cpp(self, pipe_min_size=32, indent=''): + # Updated to use streaming beat for restartable streaming kernel. + # Streaming beat is a wrapper type of the actual type with sideband control signals. + # Syntax: using BeatT = sycl::ext::intel::experimental::StreamingBeat; + streaming_beat_t = f"{self.pipe_name}BeatT" + lines = ( + f"{indent}class {self.pipe_id};\n" + f"{indent}using {streaming_beat_t} = " + f"sycl::ext::intel::experimental::StreamingBeat<{self.type.name}, true, true>;\n" + f"{indent}using {self.pipe_name} = sycl::ext::intel::experimental::pipe<" + f"{self.pipe_id}, {streaming_beat_t}, {pipe_min_size}, HostPipePropertiesT>;\n" + ) + return lines + + +class OneAPIAcceleratorInterfaceVariableConverter(AggregratedArrayVariableConverter): + def __init__(self, type_converter): + super().__init__( + type_converter=type_converter, + prefix='OneAPIAccelerator', + definition_cls=OneAPIAcceleratorInterfaceVariableDefinition, + ) diff --git a/hls4ml/backends/oneapi_accelerator/passes/__init__.py b/hls4ml/backends/oneapi_accelerator/passes/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/hls4ml/backends/oneapi_accelerator/passes/sideband_templates.py b/hls4ml/backends/oneapi_accelerator/passes/sideband_templates.py new file mode 100644 index 0000000000..a417db1c9f --- /dev/null +++ b/hls4ml/backends/oneapi_accelerator/passes/sideband_templates.py @@ -0,0 +1,71 @@ +"""The sideband handling templates are needed for oneAPI accelerator when using io_stream. +They are not used in io_paralle. +""" + +from hls4ml.backends.oneapi.oneapi_template import StreamFunctionCallTemplate, TaskSequenceTemplate +from hls4ml.backends.oneapi_accelerator.oneapi_accelerator_layers import SidebandExtraction, SidebandMerging +from hls4ml.backends.template import FunctionCallTemplate, LayerConfigTemplate + +sideband_config_template = """struct config{index} : nnet::sideband_config {{ + static constexpr unsigned n_in = {n_in}; +}};\n""" +sideband_stream_function_template = '{name}.async();' +sideband_extract_task_sequence_template = ( + 'task_sequence> {name};' +) +sideband_merge_task_sequence_template = ( + 'task_sequence> {name};' +) +sideband_include_list = ['nnet_utils/nnet_stream_beat.h'] + + +class SidebandConfigTemplate(LayerConfigTemplate): + def __init__(self): + super().__init__((SidebandExtraction, SidebandMerging)) + self.template = sideband_config_template + + def format(self, node): + params = self._default_config_params(node) + return self.template.format(**params) + + +class SidebandFunctionTemplate(FunctionCallTemplate): + """Only used to add the include list""" + + def __init__(self): + super().__init__((SidebandExtraction, SidebandMerging), include_header=sideband_include_list) + + def format(self, node): + return '' + + +class SidebandStreamFunctionTemplate(StreamFunctionCallTemplate): + def __init__(self): + super().__init__((SidebandExtraction, SidebandMerging)) + self.template = sideband_stream_function_template + + def format(self, node): + params = self._default_function_params(node) + return self.template.format(**params) + + +class SidebandExtractionTaskSequenceTemplate(TaskSequenceTemplate): + def __init__(self): + super().__init__(SidebandExtraction) + self.template = sideband_extract_task_sequence_template + + def format(self, node): + params = self._default_function_params(node) + params['skip_pipe'] = node.get_output_variable('sideband').pipe_name + return self.template.format(**params) + + +class SidebandMergeTaskSequenceTemplate(TaskSequenceTemplate): + def __init__(self): + super().__init__(SidebandMerging) + self.template = sideband_merge_task_sequence_template + + def format(self, node): + params = self._default_function_params(node) + params['skip_pipe'] = node.get_input_variable('sideband').pipe_name + return self.template.format(**params) diff --git a/hls4ml/backends/oneapi_accelerator/passes/sidebands.py b/hls4ml/backends/oneapi_accelerator/passes/sidebands.py new file mode 100644 index 0000000000..37ac4f8616 --- /dev/null +++ b/hls4ml/backends/oneapi_accelerator/passes/sidebands.py @@ -0,0 +1,83 @@ +""" +This file contains optimizers to add layers to extract and merge the sidebands. This is useful +for the accelerator flow when using io_stream. + +Warning: current version only works for network with single inputs and outputs. + +""" + +import warnings +from collections import OrderedDict + +from hls4ml.backends.oneapi_accelerator.oneapi_accelerator_layers import SidebandExtraction, SidebandMerging +from hls4ml.model.layers import Input +from hls4ml.model.optimizer import OptimizerPass + + +class ExtractSideband(OptimizerPass): + """Add a layer after the input to extract the sideband signals.""" + + def match(self, node): + if not (isinstance(node, Input) and node.model.config.get_config_value('IOType') == 'io_stream'): + return False + # now check that not already converted + output_nodes = node.get_output_nodes() + if len(output_nodes) == 1 and isinstance(output_nodes[0], SidebandExtraction): + # already transformed + return False + return True + + def transform(self, model, node): + if len(model.inputs) > 1: + warnings.warn('Current sideband extraction scheme only tested on models with one input', stacklevel=1) + + attributes = {'input_shape': node.get_attr('input_shape')} + new_node = model.make_node( + SidebandExtraction, + f'{node.name}_extract_sb', + attributes, + inputs=[node.outputs[0]], + outputs=[f'{node.name}_extract_sb', 'sideband'], + ) + model.insert_node(new_node) + return True + + +class MergeSideband(OptimizerPass): + """Add a layer after the last layer to merge the sideband signals.""" + + def match(self, node): + if node.model.config.get_config_value('IOType') == 'io_stream' and not isinstance(node, SidebandMerging): + for node_out in node.outputs: + if node_out in node.model.outputs: # if the node output is a model output + return True + return False + + def transform(self, model, node): + if len(model.outputs) > 1: + warnings.warn('Current sideband extraction scheme only tested on models with one output', stacklevel=1) + + attributes = {} + + inputs = [out for out in node.outputs if out in model.outputs] + + if len(inputs) != 1: + raise RuntimeError('Unsupported number of outputs found') + + inputs.append('sideband') + + new_name = f'{node.name}_merge_sb' + new_node = model.make_node(SidebandMerging, new_name, attributes, inputs=inputs) + + # note that model.insert_node fails here because of the two input nodes, so using a custom version below + model.outputs[0] = new_name + + new_graph = OrderedDict() + for k, v in model.graph.items(): + new_graph[k] = v + if k == node.name: + new_graph[new_node.name] = new_node + + model.graph = new_graph + + return True diff --git a/hls4ml/backends/oneapi_accelerator/passes/transform_types.py b/hls4ml/backends/oneapi_accelerator/passes/transform_types.py new file mode 100644 index 0000000000..8947bacef9 --- /dev/null +++ b/hls4ml/backends/oneapi_accelerator/passes/transform_types.py @@ -0,0 +1,62 @@ +from hls4ml.backends.oneapi.oneapi_types import ( + OneAPIACTypeConverter, + OneAPIArrayVariableConverter, + OneAPIHLSTypeConverter, + OneAPIInplaceArrayVariableConverter, + OneAPIInplaceStreamVariableConverter, + OneAPIStaticWeightVariableConverter, + OneAPIStreamVariableConverter, +) +from hls4ml.backends.oneapi_accelerator.oneapi_accelerator_types import ( + OneAPIAcceleratorInterfaceVariableConverter, +) +from hls4ml.model.optimizer import GlobalOptimizerPass +from hls4ml.model.types import InplaceTensorVariable + + +class TransformTypes(GlobalOptimizerPass): + def __init__(self): + self.type_converter = OneAPIHLSTypeConverter(precision_converter=OneAPIACTypeConverter()) + self.array_var_converter = OneAPIArrayVariableConverter(type_converter=self.type_converter) + self.inplace_array_var_converter = OneAPIInplaceArrayVariableConverter(type_converter=self.type_converter) + self.interface_var_converter = OneAPIAcceleratorInterfaceVariableConverter(type_converter=self.type_converter) + self.stream_var_converter = OneAPIStreamVariableConverter(type_converter=self.type_converter) + self.inplace_stream_var_converter = OneAPIInplaceStreamVariableConverter(type_converter=self.type_converter) + self.weight_var_converter = OneAPIStaticWeightVariableConverter(type_converter=self.type_converter) + + def transform(self, model, node): + io_type = node.model.config.get_config_value('IOType') + + for out_name, var in node.variables.items(): + if io_type == 'io_stream': + if out_name in node.model.inputs: + new_var = self.interface_var_converter.convert(var, pragma='stream') + elif out_name in node.model.outputs: + new_var = self.interface_var_converter.convert(var, pragma='stream') + elif isinstance(var, InplaceTensorVariable): + new_var = self.inplace_stream_var_converter.convert(var, pragma='stream') + else: + new_var = self.stream_var_converter.convert(var, pragma='stream') + elif io_type == 'io_parallel': + if out_name in node.model.inputs: + new_var = self.interface_var_converter.convert(var, pragma='intel::fpga_register') + elif out_name in node.model.outputs: + new_var = self.interface_var_converter.convert(var, pragma='intel::fpga_register') + elif isinstance(var, InplaceTensorVariable): + new_var = self.inplace_array_var_converter.convert(var, pragma='') + else: + new_var = self.array_var_converter.convert(var, pragma='intel::fpga_register') + else: + raise Exception(f'Unknown IOType {io_type} in {node.name} ({node.class_name})') + + node.set_attr(out_name, new_var) + if new_var.type.name in node.attributes: + node.set_attr(new_var.type.name, new_var.type) # this is for variables that are not result_t + + for w_name, weight in node.weights.items(): + new_weight = self.weight_var_converter.convert(weight) + node.set_attr(w_name, new_weight) + + for t_name, type in node.types.items(): + new_type = self.type_converter.convert(type) + node.set_attr(t_name, new_type) diff --git a/hls4ml/model/graph.py b/hls4ml/model/graph.py index e3c293dd46..dc636d39d4 100644 --- a/hls4ml/model/graph.py +++ b/hls4ml/model/graph.py @@ -577,7 +577,7 @@ def make_node(self, kind, name, attributes, inputs, outputs=None, initialize=Tru self.output_vars[o] = out_var return node - def insert_node(self, node, before=None, input_idx=0): + def insert_node(self, node, before=None, input_idx=-1): """Insert a new node into the model graph. The node to be inserted should be created with `make_node()` function. The optional @@ -587,7 +587,8 @@ def insert_node(self, node, before=None, input_idx=0): node (Layer): Node to insert before (Layer, optional): The next node in sequence before which a new node should be inserted. - input_idx (int, optional): If the next node takes multiple inputs, the input index + input_idx (int, optional): If the next node takes multiple inputs, the input index; + The default (-1) means match by name Raises: Exception: If an attempt to insert a node with multiple inputs is made or if `before` does not specify a correct node in sequence. @@ -603,19 +604,28 @@ def insert_node(self, node, before=None, input_idx=0): if overlap: next_nodes.append(x) - if before is None: - next_node = next((x for x in self.graph.values() if x.inputs and x.inputs[0] in prev_node.outputs), None) - else: - if before not in next_nodes: - raise Exception( - 'Cannot insert a node {} before {} (candidates: {}).'.format( - node.name, before.name, ','.join([n.name for n in next_nodes]) + if before is not None: + if not isinstance(before, (tuple, list)): + before = [before] + + # check that before is in next_nodes + for bf in before: + if bf not in next_nodes: + raise RuntimeError( + 'Cannot insert a node {} before {} (candidates: {}).'.format( + node.name, before.name, ','.join([n.name for n in next_nodes]) + ) ) - ) - next_node = before + # only put before as next_nodes + next_nodes = before - if next_node is not None: - next_node.inputs[input_idx] = node.outputs[0] + if next_nodes: + repl = {old_name: new_name for old_name, new_name in zip(prev_node.outputs, node.outputs)} + for next_node in next_nodes: + if input_idx >= 0: + next_node.inputs[input_idx] = node.outputs[0] + else: + next_node.inputs = [repl[val] if val in repl else val for val in next_node.inputs] else: self.outputs = [node.outputs[0] if name == prev_node.outputs[0] else name for name in self.outputs] @@ -693,6 +703,11 @@ def replace_node(self, old_node, new_node): repl = {old_name: new_name for old_name, new_name in zip(old_node.outputs, new_node.outputs)} repl.update({old_name: new_name for old_name, new_name in zip(old_node.inputs, new_node.inputs)}) + for old_output in old_node.outputs: + if old_output in self.outputs: + new_output = repl[old_output] + self.outputs = [new_output if name == old_output else name for name in self.outputs] + for node in self.graph.values(): for i, n in enumerate(node.inputs): if n in repl: @@ -703,11 +718,6 @@ def replace_node(self, old_node, new_node): self.graph = OrderedDict((new_node.name, new_node) if k == old_node.name else (k, v) for k, v in self.graph.items()) - old_name = old_node.name - if old_name in self.outputs: - new_name = new_node.name - self.outputs = [new_name if name == old_name else name for name in self.outputs] - def split_node(self, old_node, new_node1, new_node2): """Replace an existing node in the graph with two nodes in sequence. @@ -728,6 +738,11 @@ def split_node(self, old_node, new_node1, new_node2): repl = {old_name: new_name for old_name, new_name in zip(old_node.outputs, new_node2.outputs)} repl.update({old_name: new_name for old_name, new_name in zip(old_node.inputs, new_node1.inputs)}) + for old_output in old_node.outputs: + if old_output in self.outputs: + new_output = repl[old_output] + self.outputs = [new_output if name == old_output else name for name in self.outputs] + for node in self.graph.values(): for i, n in enumerate(node.inputs): if n in repl: @@ -745,9 +760,6 @@ def split_node(self, old_node, new_node1, new_node2): new_graph[key] = value self.graph = new_graph - if old_node.name in self.outputs: - self.outputs = [new_node2.name if name == old_node.name else name for name in self.outputs] - def next_layer(self): self.index += 1 return self.index diff --git a/hls4ml/templates/oneapi/CMakeLists.txt b/hls4ml/templates/oneapi/CMakeLists.txt index 5bce2aaf84..0837a2976b 100644 --- a/hls4ml/templates/oneapi/CMakeLists.txt +++ b/hls4ml/templates/oneapi/CMakeLists.txt @@ -39,14 +39,18 @@ set(LIBRARY_NAME myproject-${LIB_STAMP}) # specific part number (E.g. "10AS066N3F40E2SG") to generate a standalone IP. if(NOT DEFINED FPGA_DEVICE) set(FPGA_DEVICE "Agilex7") + set(BSP_FLAG "") endif() +# Set the target to a BSP if we target an actual accelerator board. +# hls-fpga-machine-learning insert oneapi_bsp_cmake_flag + # Use cmake -DUSER_FPGA_FLAGS= to set extra flags for FPGA backend # compilation. set(USER_FPGA_FLAGS -Wno-unused-label;${USER_FPGA_FLAGS}) # Use cmake -DUSER_FLAGS= to set extra flags for general compilation. -set(USER_FLAGS -Wno-unused-label -fconstexpr-steps=134217728 ${USER_FLAGS}) +set(USER_FLAGS -Wno-unused-label -fconstexpr-steps=134217728 ${USER_FLAGS} ${BSP_FLAG}) # Use cmake -DUSER_INCLUDE_PATHS= to set extra paths for general # compilation. diff --git a/hls4ml/templates/oneapi/firmware/myproject.h b/hls4ml/templates/oneapi/firmware/myproject.h index 082ae5dc8c..ec3bf146f5 100644 --- a/hls4ml/templates/oneapi/firmware/myproject.h +++ b/hls4ml/templates/oneapi/firmware/myproject.h @@ -5,7 +5,7 @@ // This file defines the interface to the kernel -// currently this is fixed +// this is for both the internal pipes and the interface for the HLS (ip) flow using PipeProps = decltype(sycl::ext::oneapi::experimental::properties(sycl::ext::intel::experimental::ready_latency<0>)); // Need to declare the input and output pipes diff --git a/hls4ml/templates/oneapi/firmware/nnet_utils/nnet_data_movement.h b/hls4ml/templates/oneapi/firmware/nnet_utils/nnet_data_movement.h new file mode 100644 index 0000000000..04395b2bda --- /dev/null +++ b/hls4ml/templates/oneapi/firmware/nnet_utils/nnet_data_movement.h @@ -0,0 +1,191 @@ +#ifndef NNET_DATA_MOVEMENT_H +#define NNET_DATA_MOVEMENT_H + +#include +#include + +// This file defines the methods to transfer the data to the kernel. In the HLS flow, +// these are really part of the testbench. However, in the accelerator (BSP) flow, they are +// actual kernels that are deployed in hardware. + +namespace nnet { + +////////////////////////////////////////////////////////////////////////////// +// These are the simple, testbench and bridge versions for the HLS flow +////////////////////////////////////////////////////////////////////////////// +template void convert_data(sycl::queue &q, srcType *src) { + constexpr auto dstTypeSize = std::tuple_size::value_type>{}; + for (size_t i = 0; i < SIZE / dstTypeSize; i++) { + typename ExtractPipeType::value_type ctype; + for (size_t j = 0; j < dstTypeSize; j++) { + ctype[j] = src[i * dstTypeSize + j]; + } + dest_pipe::write(q, ctype); + } +} + +template void convert_data_back(sycl::queue &q, dstType *dst) { + constexpr auto srcTypeSize = std::tuple_size::value_type>{}; + for (size_t i = 0; i < SIZE / srcTypeSize; i++) { + auto ctype = src_pipe::read(q); + for (size_t j = 0; j < srcTypeSize; j++) { + dst[i * srcTypeSize + j] = ctype[j].to_double(); + } + } +} + +////////////////////////////////////////////////////////////////////////////// +// The ones below can be used both in testbenches and in the accelerator flow +////////////////////////////////////////////////////////////////////////////// +#if !defined(IS_BSP) +// Definition for buffer locations for Avalon MM host. +inline constexpr unsigned kInputBufferLocation = 0; +inline constexpr unsigned kOutputBufferLocation = 1; +#endif + +// Implementation of a direct memory access kernel. Move data from source, convert, +// and send to the sink. Adaptive to SYCL HLS and hardware acceleration flow. +template struct DMA_convert_data { +#if !defined(IS_BSP) + // When targeting a device family, we instantiate an Avalon Memory Mapped Host for + // data transaction between host and the DMA kernel during emulation and simulation. + sycl::ext::oneapi::experimental::annotated_arg< + src_T *, + decltype(sycl::ext::oneapi::experimental::properties{ + sycl::ext::intel::experimental::latency<0>, sycl::ext::intel::experimental::dwidth<16>, + sycl::ext::intel::experimental::buffer_location, + sycl::ext::intel::experimental::read_write_mode_read, sycl::ext::intel::experimental::wait_request_requested})> +#else + // When targeting oneAPI BSP, we can use USM pointer to access host memory. + src_T *const +#endif + src; + size_t num_iteration; + + [[intel::kernel_args_restrict]] void operator()() const { + +#if defined(IS_BSP) + // Access data using host pointer. + sycl::ext::intel::host_ptr src_ptr(src); +#else + // Host allocation is not supported when targeting an FPGA family or part number. + src_T *src_ptr(src); +#endif + // First, extract the PipeDataT from the pipe + using PipeDataType = typename nnet::ExtractPipeType::value_type; + // Then, extract the DataT from StreamingBeat + using DstDataType = typename nnet::ExtractDataType::value_type; + constexpr auto dstTypeSize = std::tuple_size{}; + + [[intel::fpga_register]] typename nnet::ExtractPipeType::value_type packet; + + // Keep sending data to the input layer and keep the kernels running. + for (size_t i = 0; i < num_iteration; i++) { + #pragma unroll + for (size_t j = 0; j < dstTypeSize; j++) { + packet.data[j] = src_ptr[i * dstTypeSize + j]; + } + packet.sop = (i == 0); + // Assert end-of-packet signal after the last iteration. + // All down-stream kernels will stop seeing eop. + packet.eop = (i == (num_iteration - 1)); + dest_pipe::write(packet); + } + } +}; + +// Symmetrical to the DMA_convert_data above, this DMA drains the output pipe and +// writes result to memory. +template struct DMA_convert_data_back { +#if !defined(IS_BSP) + // Without BSP, instantiate an Avalon Memory Mapped Host to write to host. + sycl::ext::oneapi::experimental::annotated_arg< + dst_T *, + decltype(sycl::ext::oneapi::experimental::properties{ + sycl::ext::intel::experimental::latency<0>, sycl::ext::intel::experimental::dwidth<16>, + sycl::ext::intel::experimental::buffer_location, + sycl::ext::intel::experimental::read_write_mode_write, sycl::ext::intel::experimental::wait_request_requested})> +#else + // USM pointer, otherwise. + dst_T *const +#endif + dst; + size_t num_iteration; + + [[intel::kernel_args_restrict]] void operator()() const { +#if defined(IS_BSP) + sycl::ext::intel::host_ptr dst_ptr(dst); +#else + dst_T *dst_ptr(dst); +#endif + // First, extract the PipeDataT from the pipe + using PipeDataType = typename nnet::ExtractPipeType::value_type; + // Then, extract the DataT from StreamingBeat + using SrcDataType = typename nnet::ExtractDataType::value_type; + constexpr auto srcTypeSize = std::tuple_size{}; + + [[intel::fpga_register]] typename nnet::ExtractPipeType::value_type packet; + + // Drain the output pipe and write result to memory. + for (size_t i = 0; i < num_iteration; i++) { + packet = src_pipe::read(); + #pragma unroll + for (size_t j = 0; j < srcTypeSize; j++) { + dst_ptr[i * srcTypeSize + j] = static_cast(packet.data[j].to_double()); + } + } + } +}; + +////////////////////////////////////////////////////////////////////////////// +// These are versions to convert data for the accelerator bridge (using BSP) +////////////////////////////////////////////////////////////////////////////// +template void DMA_bridge_convert_data(sycl::queue &q, srcType *src) { + // First, extract the PipeDataT from the pipe + using PipeDataType = typename nnet::ExtractPipeType::value_type; + // Then, extract the DataT from StreamingBeat + using DstDataType = typename nnet::ExtractDataType::value_type; + constexpr auto dstTypeSize = std::tuple_size{}; + + constexpr size_t num_iterations = SIZE / dstTypeSize; + + // Allocate host memory + srcType *vals = sycl::malloc_host(SIZE, q); + if (vals == nullptr) { + std::cerr << "ERROR: host allocation failed for input\n"; + return; + } + // copy to host memory + for (size_t i = 0; i < SIZE; i++) { + vals[i] = src[i]; + } + q.single_task(DMA_convert_data{vals, num_iterations}); +} + +template void DMA_bridge_convert_data_back(sycl::queue &q, dstType *dst) { + // First, extract the PipeDataT from the pipe + using PipeDataType = typename nnet::ExtractPipeType::value_type; + // Then, extract the DataT from StreamingBeat + using SrcDataType = typename nnet::ExtractDataType::value_type; + constexpr auto srcTypeSize = std::tuple_size{}; + + constexpr size_t num_iterations = SIZE / srcTypeSize; + + // Allocate host memory + dstType *outputs = sycl::malloc_host(SIZE, q); + if (outputs == nullptr) { + std::cerr << "ERROR: host allocation failed for output\n"; + return; + } + + q.single_task(DMA_convert_data_back{outputs, num_iterations}).wait(); + + // copy the data back + for (size_t j = 0; j < SIZE; j++) { + dst[j] = outputs[j]; + } +} + +} // namespace nnet + +#endif diff --git a/hls4ml/templates/oneapi/firmware/nnet_utils/nnet_dense_stream.h b/hls4ml/templates/oneapi/firmware/nnet_utils/nnet_dense_stream.h index 92c9adc3bb..f162b4736c 100644 --- a/hls4ml/templates/oneapi/firmware/nnet_utils/nnet_dense_stream.h +++ b/hls4ml/templates/oneapi/firmware/nnet_utils/nnet_dense_stream.h @@ -7,7 +7,9 @@ namespace nnet { -// Note: DataPack logic removed, at least in the initial version +// Note: DataPack logic removed, at least in the initial version. +// The data should be sent to the dense layer in parallel, in one stream transaction. +// Note that this means flatten is not a noop in oneAPI streaming. template void dense_resource_stream(typename CONFIG_T::weight_t weights, typename CONFIG_T::bias_t biases) { diff --git a/hls4ml/templates/oneapi/firmware/nnet_utils/nnet_helpers.h b/hls4ml/templates/oneapi/firmware/nnet_utils/nnet_helpers.h index c7af2e7a68..e5b451655a 100644 --- a/hls4ml/templates/oneapi/firmware/nnet_utils/nnet_helpers.h +++ b/hls4ml/templates/oneapi/firmware/nnet_utils/nnet_helpers.h @@ -12,27 +12,6 @@ namespace nnet { -template void convert_data(sycl::queue &q, srcType *src) { - constexpr auto dstTypeSize = std::tuple_size::value_type>{}; - for (size_t i = 0; i < SIZE / dstTypeSize; i++) { - typename ExtractPipeType::value_type ctype; - for (size_t j = 0; j < dstTypeSize; j++) { - ctype[j] = src[i * dstTypeSize + j]; - } - dest_pipe::write(q, ctype); - } -} - -template void convert_data_back(sycl::queue &q, dstType *dst) { - constexpr auto srcTypeSize = std::tuple_size::value_type>{}; - for (size_t i = 0; i < SIZE / srcTypeSize; i++) { - auto ctype = src_pipe::read(q); - for (size_t j = 0; j < srcTypeSize; j++) { - dst[i * srcTypeSize + j] = ctype[j].to_double(); - } - } -} - extern bool trace_enabled; extern std::map *trace_outputs; extern size_t trace_type_size; diff --git a/hls4ml/templates/oneapi/firmware/nnet_utils/nnet_stream_beat.h b/hls4ml/templates/oneapi/firmware/nnet_utils/nnet_stream_beat.h new file mode 100644 index 0000000000..1b41f008c7 --- /dev/null +++ b/hls4ml/templates/oneapi/firmware/nnet_utils/nnet_stream_beat.h @@ -0,0 +1,77 @@ +#ifndef NNET_STREAM_BEAT_H +#define NNET_STREAM_BEAT_H + +// These are functions just for streaming in accelerator mode. They convert from using packets +// to not using packets, and visa versa. + +namespace nnet { + +struct sideband_config { + static const unsigned n_in = 10; +}; + +// ************************************************* +// Remove sideband and passes it to end via skip pipe +// ************************************************* +template +[[intel::use_stall_enable_clusters]] void extract_sideband_stream() { + + [[intel::fpga_register]] typename ExtractPipeType::value_type out_data; + + bool sop = false; + bool eop = false; + +LinearActLoop: + [[intel::initiation_interval(1)]] while (!eop) { + for (int i = 0; i < CONFIG_T::n_in / std::tuple_size::value_type>{}; i++) { + auto in_data = data_pipe::read(); + + LinearPackLoop: + #pragma unroll + for (int j = 0; j < std::tuple_size::value_type>{}; j++) { + out_data[j] = in_data.data[j]; + } + + res_pipe::write(out_data); + + if (i == 0) { + sop = in_data.sop; + } + eop = in_data.eop; + } + typename nnet::ExtractPipeType::value_type skip_data; // this is a two-element array, {sop, eop}. + skip_data[0] = sop; + skip_data[1] = eop; + skip_pipe::write(skip_data); + } +} + +// ************************************************* +// Recieves sideband via skip pipe, and makees it sideband +// ************************************************* + +template +[[intel::use_stall_enable_clusters]] void merge_sideband_stream() { + using ResT = typename ExtractDataType::value_type>::value_type; + [[intel::fpga_register]] typename ExtractPipeType::value_type out_data; + + constexpr auto num_transfers = CONFIG_T::n_in / std::tuple_size{}; + + auto skip_data = skip_pipe::read(); + +LinearActLoop: + [[intel::initiation_interval(1)]] for (int i = 0; i < num_transfers; i++) { + auto in_data = data_pipe::read(); + + LinearPackLoop: + #pragma unroll + for (int j = 0; j < std::tuple_size{}; j++) { + out_data.data[j] = in_data[j]; + } + out_data.sop = (i == 0) ? static_cast(skip_data[0]) : false; + out_data.eop = (i == num_transfers - 1) ? static_cast(skip_data[1]) : false; + res_pipe::write(out_data); + } +} +} // namespace nnet +#endif diff --git a/hls4ml/templates/oneapi/firmware/nnet_utils/nnet_types.h b/hls4ml/templates/oneapi/firmware/nnet_utils/nnet_types.h index 8cf883c1d5..a35bba17bf 100644 --- a/hls4ml/templates/oneapi/firmware/nnet_utils/nnet_types.h +++ b/hls4ml/templates/oneapi/firmware/nnet_utils/nnet_types.h @@ -8,6 +8,8 @@ #include #include +#include // Streaming Beat and pipe properties. + namespace nnet { // Define the pipe type that we use @@ -34,6 +36,15 @@ struct ExtractPipeType struct ExtractDataType { typedef T value_type; }; + +// Specialization on oneAPI StreamingBeat type. +template +struct ExtractDataType> { + typedef DataT value_type; +}; + /* * HLS Shift Register Implementation * To verify a shift register is used in hardware, go to report.html > Area Analysis of System diff --git a/hls4ml/templates/oneapi/myproject_bridge.cpp b/hls4ml/templates/oneapi/myproject_bridge.cpp index ddad1d054b..fa73db7c2a 100644 --- a/hls4ml/templates/oneapi/myproject_bridge.cpp +++ b/hls4ml/templates/oneapi/myproject_bridge.cpp @@ -2,7 +2,7 @@ #define MYPROJECT_BRIDGE_H_ #include "firmware/myproject.h" -#include "firmware/nnet_utils/nnet_helpers.h" +#include "firmware/nnet_utils/nnet_data_movement.h" #include #include diff --git a/hls4ml/templates/oneapi/myproject_test.cpp b/hls4ml/templates/oneapi/myproject_test.cpp index 82fb60d2f8..dc570dcb07 100644 --- a/hls4ml/templates/oneapi/myproject_test.cpp +++ b/hls4ml/templates/oneapi/myproject_test.cpp @@ -7,6 +7,7 @@ #include #include "firmware/myproject.h" +#include "firmware/nnet_utils/nnet_data_movement.h" #include "firmware/parameters.h" #include diff --git a/hls4ml/templates/oneapi_accelerator/firmware/myproject.h b/hls4ml/templates/oneapi_accelerator/firmware/myproject.h new file mode 100644 index 0000000000..ef0d458a4d --- /dev/null +++ b/hls4ml/templates/oneapi_accelerator/firmware/myproject.h @@ -0,0 +1,29 @@ +#ifndef MYPROJECT_H_ +#define MYPROJECT_H_ + +#include "defines.h" + +// This file defines the interface to the kernel + +// currently this is for the internal pipes, not the interface, in accelerator flow +using PipeProps = decltype(sycl::ext::oneapi::experimental::properties(sycl::ext::intel::experimental::ready_latency<0>)); + +// Pipe properties for host pipes. Host pipes connect to the data source DMA and sink DMA. +// They are connected to the first and the last layer to stream data into and out from the kernel. +using HostPipePropertiesT = decltype(sycl::ext::oneapi::experimental::properties( + sycl::ext::intel::experimental::ready_latency<0>, sycl::ext::intel::experimental::bits_per_symbol<16>, + sycl::ext::intel::experimental::uses_valid, sycl::ext::intel::experimental::first_symbol_in_high_order_bits, + sycl::ext::intel::experimental::protocol_avalon_streaming_uses_ready)); + +// Need to declare the input and output pipes + +// hls-fpga-machine-learning insert inputs +// hls-fpga-machine-learning insert outputs + +class MyProjectID; + +struct MyProject { + SYCL_EXTERNAL void operator()() const; +}; + +#endif diff --git a/hls4ml/templates/oneapi_accelerator/myproject_test.cpp b/hls4ml/templates/oneapi_accelerator/myproject_test.cpp new file mode 100644 index 0000000000..f90129db4c --- /dev/null +++ b/hls4ml/templates/oneapi_accelerator/myproject_test.cpp @@ -0,0 +1,199 @@ +#include +#include +#include +#include +#include +#include +#include +#include + +#include "firmware/myproject.h" +#include "firmware/nnet_utils/nnet_data_movement.h" +#include "firmware/parameters.h" + +#include + +#if (__INTEL_CLANG_COMPILER < 20250000) +#include +#endif + +#include "exception_handler.hpp" +// hls-fpga-machine-learning insert bram + +#define CHECKPOINT 5000 + +// Functions that reads input and prediction data from files. +// Returns `true` if files are read successfully and not empty. +// Returns `false` otherwise. +bool prepare_data_from_file(std::string &fin_path, std::string &fpr_path, std::vector> &inputs, + std::vector> &predictions) { + // load input data from text file + std::ifstream fin(fin_path.c_str()); + // load predictions from text file + std::ifstream fpr(fpr_path.c_str()); + + std::string iline; + std::string pline; + + if (fin.is_open() && fpr.is_open()) { + size_t num_iterations = 0; + + // Prepare input data from file. Load predictions from file. + for (; std::getline(fin, iline) && std::getline(fpr, pline); num_iterations++) { + if (num_iterations % CHECKPOINT == 0) { + std::cout << "Processing input " << num_iterations << std::endl; + } + + std::vector in; + std::vector pr; + float current; + + std::stringstream ssin(iline); + while (ssin >> current) { + in.push_back(current); + } + + std::stringstream sspred(pline); + while (sspred >> current) { + pr.push_back(current); + } + + std::copy(pr.cbegin(), pr.cend(), predictions.back().begin()); + std::copy(in.cbegin(), in.cend(), inputs.back().begin()); + } + fin.close(); + fpr.close(); + if (inputs.empty()) + return false; + else + return true; + } else { + return false; + } +} + +int main(int argc, char **argv) { + +#if FPGA_SIMULATOR +#define NUM_ITERATIONS 5 + auto selector = sycl::ext::intel::fpga_simulator_selector_v; +#elif FPGA_HARDWARE +#define NUM_ITERATIONS 100 + auto selector = sycl::ext::intel::fpga_selector_v; +#else // #if FPGA_EMULATOR +#define NUM_ITERATIONS 10 + auto selector = sycl::ext::intel::fpga_emulator_selector_v; +#endif + + sycl::queue q(selector, fpga_tools::exception_handler, sycl::property::queue::enable_profiling{}); + + auto device = q.get_device(); + + // make sure the device supports USM host allocations + if (!device.has(sycl::aspect::usm_host_allocations)) { + std::cerr << "This design must either target a board that supports USM " + "Host/Shared allocations, or IP Component Authoring. " + << std::endl; + std::terminate(); + } + + std::cout << "Running on device: " << device.get_info().c_str() << std::endl; + + std::string INPUT_FILE = "tb_data/tb_input_features.dat"; + std::string PRED_FILE = "tb_data/tb_output_predictions.dat"; + std::string RESULTS_LOG = "tb_data/results.log"; + std::ofstream fout(RESULTS_LOG); + + // Allocate vectors on stack to hold data from files temporarily. + std::vector> inputs; + std::vector> predictions; + bool file_valid = prepare_data_from_file(INPUT_FILE, PRED_FILE, inputs, predictions); + unsigned int num_iterations; + if (file_valid) { + num_iterations = inputs.size(); + } else { + num_iterations = NUM_ITERATIONS; + } + + // hls-fpga-machine-learning insert runtime contant + + try { + // Allocate host memory if BSP is in use. + float *vals = sycl::malloc_host(kInputSz, q); + if (vals == nullptr) { + std::cerr << "ERROR: host allocation failed for input\n"; + fout.close(); + return 1; + } + float *outputs = sycl::malloc_host(kOutputSz, q); + if (outputs == nullptr) { + std::cerr << "ERROR: host allocation failed for output\n"; + fout.close(); + return 1; + } + + if (file_valid) { + // Start always-run streaming kernel here, instead of inside a loop. + q.single_task(MyProject{}); + + // hls-fpga-machine-learning insert data + + // hls-fpga-machine-learning convert output + + // Print output from kernel and from prediction file. + for (int i = 0; i < num_iterations; i++) { + for (int j = 0; j < kOutLayerSize; j++) { + fout << outputs[i * kOutLayerSize + j] << " "; + } + fout << std::endl; + if (i % CHECKPOINT == 0) { + std::cout << "Predictions" << std::endl; + // hls-fpga-machine-learning insert predictions + for (auto predval : predictions[i]) { + std::cout << predval << " "; + } + std::cout << std::endl; + std::cout << "Quantized predictions" << std::endl; + // hls-fpga-machine-learning insert quantized + for (int j = 0; j < kOutLayerSize; j++) { + std::cout << outputs[i * kOutLayerSize + j] << " "; + } + std::cout << std::endl; + } + } + } else { + std::cout << "INFO: Unable to open input/predictions file, using default input with " << num_iterations + << " invocations." << std::endl; + q.single_task(MyProject{}); + // hls-fpga-machine-learning insert top-level-function + // hls-fpga-machine-learning insert zero + // hls-fpga-machine-learning convert output + for (int i = 0; i < num_iterations; i++) { + for (int j = 0; j < kOutLayerSize; j++) { + std::cout << outputs[i * kOutLayerSize + j] << " "; + fout << outputs[i * kOutLayerSize + j] << " "; + } + std::cout << std::endl; + fout << std::endl; + } + } + sycl::free(vals, q); + sycl::free(outputs, q); + fout.close(); + std::cout << "INFO: Saved inference results to file: " << RESULTS_LOG << std::endl; + } catch (sycl::exception const &e) { + // Catches exceptions in the host code. + std::cerr << "Caught a SYCL host exception:\n" << e.what() << "\n"; + + // Most likely the runtime couldn't find FPGA hardware! + if (e.code().value() == CL_DEVICE_NOT_FOUND) { + std::cerr << "If you are targeting an FPGA, please ensure that your " + "system has a correctly configured FPGA board.\n"; + std::cerr << "Run sys_check in the oneAPI root directory to verify.\n"; + std::cerr << "If you are targeting the FPGA emulator, compile with " + "-DFPGA_EMULATOR.\n"; + } + std::terminate(); + } + return 0; +} diff --git a/hls4ml/writer/__init__.py b/hls4ml/writer/__init__.py index 8de19fe1d2..b8b066f036 100644 --- a/hls4ml/writer/__init__.py +++ b/hls4ml/writer/__init__.py @@ -1,4 +1,5 @@ from hls4ml.writer.catapult_writer import CatapultWriter +from hls4ml.writer.oneapi_accelerator_writer import OneAPIAcceleratorWriter from hls4ml.writer.oneapi_writer import OneAPIWriter from hls4ml.writer.quartus_writer import QuartusWriter from hls4ml.writer.symbolic_writer import SymbolicExpressionWriter @@ -12,5 +13,6 @@ register_writer('Vitis', VitisWriter) register_writer('Quartus', QuartusWriter) register_writer('oneAPI', OneAPIWriter) +register_writer('oneAPIAccelerator', OneAPIAcceleratorWriter) register_writer('Catapult', CatapultWriter) register_writer('SymbolicExpression', SymbolicExpressionWriter) diff --git a/hls4ml/writer/oneapi_accelerator_writer.py b/hls4ml/writer/oneapi_accelerator_writer.py new file mode 100644 index 0000000000..b5cdc3ed3b --- /dev/null +++ b/hls4ml/writer/oneapi_accelerator_writer.py @@ -0,0 +1,406 @@ +import os +from shutil import copyfile + +from hls4ml.utils.string_utils import convert_to_pascal_case +from hls4ml.writer.oneapi_writer import OneAPIWriter + +config_filename = 'hls4ml_config.yml' + + +class OneAPIAcceleratorWriter(OneAPIWriter): + + def write_project_cpp(self, model): + """Write the main architecture source file (myproject.cpp) + + Args: + model (ModelGraph): the hls4ml model. + """ + project_name = model.config.get_project_name() + + filedir = os.path.dirname(os.path.abspath(__file__)) + with ( + open(os.path.join(filedir, '../templates/oneapi/firmware/myproject.cpp')) as f, + open(f'{model.config.get_output_dir()}/src/firmware/{project_name}.cpp', 'w') as fout, + ): + model_inputs = model.get_input_variables() + model_outputs = model.get_output_variables() + model_brams = [var for var in model.get_weight_variables() if var.storage.lower() == 'bram'] + + if len(model_brams) != 0: + raise NotImplementedError("Weights on the interface is currently not supported") + + io_type = model.config.get_config_value('IOType') + indent = ' ' + + for line in f.readlines(): + # Add headers to weights and biases + if 'myproject' in line: + newline = line.replace('myproject', project_name) + elif 'MyProject' in line: + newline = line.replace('MyProject', convert_to_pascal_case(project_name)) + + # oneAPI pipes need to be declared and passed as template parameters + elif '// hls-fpga-machine-learning insert inter-task pipes' in line: + newline = line + if io_type == 'io_stream': + for layer in model.get_layers(): + vars = layer.get_variables() + for var in vars: + if var not in model_inputs and var not in model_outputs: + newline += var.declare_cpp() + + # Read in inputs + elif '// hls-fpga-machine-learning read in' in line: + newline = line + if io_type == 'io_parallel': + restartable_kernel_loop = f"bool keep_going = true;\n\n" f"{indent}while (keep_going) {{\n" + newline += indent + restartable_kernel_loop + for inp in model_inputs: + newline += indent * 2 + f'auto {inp.name}_beat = {inp.pipe_name}::read();\n' + newline += indent * 2 + f'auto {inp.name} = {inp.name}_beat.data;\n' + # for streaming we don't need to read it in + + # Insert weights + elif '// hls-fpga-machine-learning insert weights' in line: + newline = line + for layer in model.get_layers(): + for w in layer.get_weights(): + if w not in model_brams: + newline += f'#include "weights/{w.name}.h"\n' + + # Insert task sequences + elif '// hls-fpga-machine-learning declare task sequences' in line: + if io_type == 'io_stream': # only need this for io_stream + newline = line + for layer in model.get_layers(): + ts = layer.get_attr('tast_sequence_cpp') + if ts: + newline += ' ' + ts + '\n' + else: + newline = indent + line + + # Neural net instantiation + elif '// hls-fpga-machine-learning insert layers' in line: + if io_type == 'io_parallel': + newline = indent + line + '\n' + else: + newline = line + '\n' + for layer in model.get_layers(): + if io_type != 'io_stream': + vars = layer.get_variables() + for var in vars: + if var not in model_inputs: + def_cpp = var.definition_cpp() + if def_cpp is not None: + newline += indent * 2 + def_cpp + ';\n' + func = ( + layer.get_attr('function_cpp') + if io_type == 'io_parallel' + else layer.get_attr('stream_function_cpp') + ) + if func: + newline += (indent * 2 if io_type == 'io_parallel' else indent) + func + '\n' + if model.config.trace_output and layer.get_attr('trace', False): + newline += '#ifndef HLS_SYNTHESIS\n' + for var in vars: + newline += ' nnet::save_layer_output<{}>({}, "{}", {});\n'.format( + var.type.name, var.name, layer.name, var.size_cpp() + ) + newline += '#endif\n' + + # Write the output + elif '// hls-fpga-machine-learning return' in line: + newline = line + if io_type == 'io_parallel': + newline = indent + newline + for out in model_outputs: + out_beat = f"{out.name}_beat" + newline += ( + indent * 2 + f'typename nnet::ExtractPipeType<{out.pipe_name}>::value_type {out_beat};\n' + ) + newline += indent * 2 + f'{out_beat}.data = {out.name};\n' + newline += indent * 2 + f'{out.pipe_name}::write({out_beat});\n' + newline += indent * 2 + '// stops the kernel when the last input seen.\n' + newline += indent * 2 + f'keep_going = !{model_inputs[0].name}_beat.eop;\n' + newline += f"{indent}}}\n" + # don't need to add anything in io_stream + + # Just copy line + else: + newline = line + + fout.write(newline) + + def write_project_header(self, model): + """Write the main architecture header file (myproject.h) + + Args: + model (ModelGraph): the hls4ml model. + """ + + project_name = model.config.get_project_name() + + filedir = os.path.dirname(os.path.abspath(__file__)) + with ( + open(os.path.join(filedir, '../templates/oneapi_accelerator/firmware/myproject.h')) as f, + open(f'{model.config.get_output_dir()}/src/firmware/{project_name}.h', 'w') as fout, + ): + model_inputs = model.get_input_variables() + model_outputs = model.get_output_variables() + # model_brams = [var for var in model.get_weight_variables() if var.storage.lower() == 'bram'] + + # io_parallel and io_stream instantiate the top-level function differently (io_stream not yet supported) + # io_type = model.config.get_config_value('IOType') + # indent = ' ' + # brams_str = ', \n'.join([indent + b.definition_cpp(as_reference=False) for b in model_brams]) + + for line in f.readlines(): + if 'MYPROJECT' in line: + newline = line.replace('MYPROJECT', format(project_name.upper())) + + elif 'myproject' in line: + newline = line.replace('myproject', project_name) + + elif 'MyProject' in line: + newline = line.replace('MyProject', convert_to_pascal_case(project_name)) + + # Declarations for the inputs. May need modification when io_stream is supported + elif '// hls-fpga-machine-learning insert inputs' in line: + newline = line + for inp in model_inputs: + newline += inp.declare_cpp() + + # and declareations for the outputs + elif '// hls-fpga-machine-learning insert outputs' in line: + newline = line + for out in model_outputs: + newline += out.declare_cpp() + + # Simply copy line, if no inserts are required + else: + newline = line + + fout.write(newline) + + def write_test_bench(self, model): + """Write the testbench + + Args: + model (ModelGraph): the hls4ml model. + """ + # TODO - This function only works with one model input + # (NOT one data point - it works as expected with multiple data points) + + # copy the exception handler + filedir = os.path.dirname(os.path.abspath(__file__)) + srcpath = os.path.join(filedir, '../templates/oneapi/exception_handler.hpp') + dstpath = f'{model.config.get_output_dir()}/src/exception_handler.hpp' + copyfile(srcpath, dstpath) + + project_name = model.config.get_project_name() + model_inputs = model.get_input_variables() + model_outputs = model.get_output_variables() + model_brams = [var for var in model.get_weight_variables() if var.storage.lower() == 'bram'] + + if len(model_brams) != 0: + raise NotImplementedError("Weights on the interface is currently not supported") + + if len(model_inputs) != 1 or len(model_outputs) != 1: + print("The testbench supports only single input arrays and single output arrays.") + print("Please modify it before using it.") + + if not os.path.exists(f'{model.config.get_output_dir()}/tb_data/'): + os.mkdir(f'{model.config.get_output_dir()}/tb_data/') + + input_data = model.config.get_config_value('InputData') + output_predictions = model.config.get_config_value('OutputPredictions') + + if input_data: + if input_data[-3:] == "dat": + copyfile(input_data, f'{model.config.get_output_dir()}/tb_data/tb_input_features.dat') + else: + self.__make_dat_file(input_data, f'{model.config.get_output_dir()}/tb_data/tb_input_features.dat') + + if output_predictions: + if output_predictions[-3:] == "dat": + copyfile(output_predictions, f'{model.config.get_output_dir()}/tb_data/tb_output_predictions.dat') + else: + self.__make_dat_file( + output_predictions, f'{model.config.get_output_dir()}/tb_data/tb_output_predictions.dat' + ) + + with ( + open(os.path.join(filedir, '../templates/oneapi_accelerator/myproject_test.cpp')) as f, + open(f'{model.config.get_output_dir()}/src/{project_name}_test.cpp', 'w') as fout, + ): + for line in f.readlines(): + indent = ' ' * (len(line) - len(line.lstrip(' '))) + + if 'myproject' in line: + newline = line.replace('myproject', project_name) + elif 'MyProject' in line: + newline = line.replace('MyProject', convert_to_pascal_case(project_name)) + + elif '// hls-fpga-machine-learning insert bram' in line: + newline = line + for bram in model_brams: + newline += f'#include \"firmware/weights/{bram.name}.h\"\n' + elif '// hls-fpga-machine-learning insert runtime contant' in line: + newline = line + insert_constant_lines = ( + f'{indent}const size_t kInputSz = {model_inputs[0].size_cpp()} * num_iterations;\n' + f'{indent}const size_t kOutputSz = {model_outputs[0].size_cpp()} * num_iterations;\n' + f'{indent}const size_t kInputLayerSize = {model_inputs[0].size_cpp()};\n' + f'{indent}const size_t kOutLayerSize = {model_outputs[0].size_cpp()};\n' + ) + newline += insert_constant_lines + elif '// hls-fpga-machine-learning insert zero' in line: + newline = line + inp = model_inputs[0] + insert_zero_lines = ( + f'{indent}for (int j = 0 ; j < kInputSz; j++)\n' + f'{indent} vals[j] = 0.0;\n' + f'{indent}q.single_task(nnet::DMA_convert_data{{vals, num_iterations}});\n' + ) + newline += insert_zero_lines + elif '// hls-fpga-machine-learning insert data' in line: + newline = line + inp = model_inputs[0] + insert_data_lines = ( + f'{indent}for (int i = 0; i < num_iterations; i++)\n' + f'{indent} for (int j = 0 ; j < kInputLayerSize; j++)\n' + f'{indent} vals[i * kInputLayerSize + j] = inputs[i][j]; \n' + f'{indent}q.single_task(nnet::DMA_convert_data{{vals, num_iterations}});\n' + ) + newline += insert_data_lines + elif '// hls-fpga-machine-learning convert output' in line: + newline = line + out = model_outputs[0] + newline += f'{indent}q.single_task(nnet::DMA_convert_data_back<{out.pipe_name}, float>' + newline += '{outputs, num_iterations}).wait();\n' + else: + newline = line + + fout.write(newline) + + def write_bridge(self, model): + """Write the Python-C++ bridge (myproject_bridge.cpp) + + Args: + model (ModelGraph): the hls4ml model. + """ + project_name = model.config.get_project_name() + stamp = model.config.get_config_value('Stamp') + model_inputs = model.get_input_variables() + model_outputs = model.get_output_variables() + model_brams = [var for var in model.get_weight_variables() if var.storage.lower() == 'bram'] + # model brambs aren't actually supported yet + + # io_type = model.config.get_config_value('IOType') + indent = ' ' + + filedir = os.path.dirname(os.path.abspath(__file__)) + with ( + open(os.path.join(filedir, '../templates/oneapi/myproject_bridge.cpp')) as f, + open(f'{model.config.get_output_dir()}/src/{project_name}_bridge.cpp', 'w') as fout, + ): + for line in f.readlines(): + if 'MYPROJECT' in line: + newline = line.replace('MYPROJECT', format(project_name.upper())) + + elif 'myproject' in line: + newline = line.replace('myproject', format(project_name)) + + elif 'MyProject' in line: + newline = line.replace('MyProject', convert_to_pascal_case(project_name)) + + elif '// hls-fpga-machine-learning insert bram' in line: + newline = line + for bram in model_brams: + newline += f'#include \"firmware/weights/{bram.name}.h\"\n' + + elif '// hls-fpga-machine-learning insert class def' in line: + dtype = line.split('#', 1)[1].strip() + newline = f'class {convert_to_pascal_case(project_name)}Class{dtype.capitalize()}_{stamp};\n' + + elif '// hls-fpga-machine-learning insert header' in line: + dtype = line.split('#', 1)[1].strip() + inputs_str = ', '.join([f'{dtype} {i.name}[{i.size_cpp()}]' for i in model_inputs]) + outputs_str = ', '.join([f'{dtype} {o.name}[{o.size_cpp()}]' for o in model_outputs]) + + newline = '' + newline += indent + inputs_str + ',\n' + newline += indent + outputs_str + '\n' + + elif '// hls-fpga-machine-learning insert wrapper' in line: + dtype = line.split('#', 1)[1].strip() + newline = '' + for i in model_inputs: + newline += ( + indent + f'nnet::DMA_bridge_convert_data<{dtype}, {i.pipe_name}, {i.size_cpp()}>(q, {i.name});\n' + ) + + newline += ( + indent + + f'q.single_task<{convert_to_pascal_case(project_name)}Class{dtype.capitalize()}_{stamp}>' + + f'({convert_to_pascal_case(project_name)}{{}});\n' + ) + + for o in model_outputs: + newline += ( + indent + + f'nnet::DMA_bridge_convert_data_back<{o.pipe_name}, {dtype}, {o.size_cpp()}>(q, {o.name});\n' + ) + newline += '\n' + newline += indent + 'q.wait();\n' + + elif '// hls-fpga-machine-learning insert trace_outputs' in line: + newline = '' + for layer in model.get_layers(): + func = layer.get_attr('function_cpp') + if func and model.config.trace_output and layer.get_attr('trace', False): + vars = layer.get_variables() + for var in vars: + newline += ( + indent + + 'nnet::trace_outputs->insert(std::pair(' + + f'"{layer.name}", (void *) malloc({var.size_cpp()} * element_size)));\n' + ) + + else: + newline = line + fout.write(newline) + + def write_build_script(self, model): + """Write the build scripts (Makefile, build_lib.sh) + + Args: + model (ModelGraph): the hls4ml model. + """ + + # Makefile + filedir = os.path.dirname(os.path.abspath(__file__)) + device = model.config.get_config_value('Part') + period = model.config.get_config_value('ClockPeriod') + hyper = model.config.get_config_value('HyperoptHandshake') + with ( + open(os.path.join(filedir, '../templates/oneapi/CMakeLists.txt')) as f, + open(f'{model.config.get_output_dir()}/CMakeLists.txt', 'w') as fout, + ): + for line in f.readlines(): + line = line.replace('myproject', model.config.get_project_name()) + line = line.replace('mystamp', model.config.get_config_value('Stamp')) + + if 'set(FPGA_DEVICE' in line: + line = f' set(FPGA_DEVICE "{device}")\n' + + if model.config.get_config_value('UseOneAPIBSP'): + if 'hls-fpga-machine-learning insert oneapi_bsp_cmake_flag' in line: + line = 'set(BSP_FLAG "-DIS_BSP")' + + if 'set(USER_FPGA_FLAGS' in line: + line += f'set(USER_FPGA_FLAGS -Xsclock={period}ns; ${{USER_FPGA_FLAGS}})\n' + if not hyper: + line += 'set(USER_FPGA_FLAGS -Xsoptimize=latency; ${USER_FPGA_FLAGS})\n' + + fout.write(line)