Skip to content

Commit c15cc5e

Browse files
authored
[Target] Remove deprecated parameters from target (apache#12416)
* remove depricated parameters in target * lint * fix cpp tests fix * remove more configs in test files * address comments * fix error * fix hexagon * fix micro tutorial * fix integration tests * fix hexagon * lint * fix unittest * fix readme * fix assert executor in target * address comments * fix tutorials * fix hexagon target * fix tutorial * fix for tutorials * hexagon
1 parent 8174d08 commit c15cc5e

File tree

25 files changed

+80
-261
lines changed

25 files changed

+80
-261
lines changed

apps/hexagon_launcher/README.md

+2-2
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ mod, params = relay.frontend.from_tflite(
118118
tflite_model, shape_dict=shape_dict, dtype_dict=dtype_dict
119119
)
120120
121-
target = tvm.target.hexagon('v68', link_params=True)
121+
target = tvm.target.hexagon('v68')
122122
with tvm.transform.PassContext(opt_level=3):
123123
lib = relay.build(mod, tvm.target.Target(target, host=target), params=params, mod_name="default")
124124
@@ -172,7 +172,7 @@ A sample output JSON from running the Inception V3 model may look like
172172

173173
When using AoT, the `target` needs to be `llvm`:
174174
```
175-
aot_target = "llvm -keys=hexagon -link-params=0 -mattr=+hvxv69,+hvx-length128b,+hvx-qfloat,-hvx-ieee-fp -mcpu=hexagonv69 -mtriple=hexagon"
175+
aot_target = "llvm -keys=hexagon -mattr=+hvxv69,+hvx-length128b,+hvx-qfloat,-hvx-ieee-fp -mcpu=hexagonv69 -mtriple=hexagon"
176176
aot_host_target = aot_target
177177
```
178178

apps/howto_deploy/prepare_test_libs.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def prepare_test_libs(base_path):
3333
fadd_dylib.export_library(dylib_path)
3434

3535
# Compile library in system library mode
36-
fadd_syslib = tvm.build(s, [A, B], "llvm --system-lib", name="addonesys")
36+
fadd_syslib = tvm.build(s, [A, B], "llvm", name="addonesys")
3737
syslib_path = os.path.join(base_path, "test_addone_sys.o")
3838
fadd_syslib.save(syslib_path)
3939

apps/sgx/src/build_model.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,12 @@ def main():
3939
)
4040

4141
with tvm.transform.PassContext(opt_level=3):
42-
graph, lib, params = relay.build(net, "llvm --system-lib", params=params)
42+
graph, lib, params = relay.build(
43+
net,
44+
"llvm",
45+
params=params,
46+
runtime=tvm.relay.backend.Runtime("cpp", {"system-lib": True}),
47+
)
4348

4449
build_dir = osp.abspath(sys.argv[1])
4550
if not osp.isdir(build_dir):

apps/wasm-standalone/wasm-graph/tools/build_graph_lib.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -72,10 +72,15 @@ def build_graph_lib(opt_level):
7272
shape_dict = {input_name: img_data.shape}
7373

7474
mod, params = relay.frontend.from_onnx(onnx_model, shape_dict)
75-
target = "llvm -mtriple=wasm32-unknown-unknown -mattr=+simd128 --system-lib"
75+
target = "llvm -mtriple=wasm32-unknown-unknown -mattr=+simd128"
7676

7777
with tvm.transform.PassContext(opt_level=opt_level):
78-
factory = relay.build(mod, target=target, params=params)
78+
factory = relay.build(
79+
mod,
80+
target=target,
81+
params=params,
82+
runtime=tvm.relay.backend.Runtime("cpp", {"system-lib": True}),
83+
)
7984

8085
# Save the model artifacts to obj_file
8186
obj_file = os.path.join(out_dir, "graph.o")
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
# Keep a valid schedule for demonstraction. This is used to prevent flasky errors in CI.
2-
{"i": [["[\"matmul_add\", 1024, 1024, 1024, \"float32\"]", "llvm -keys=cpu -link-params=0", [18, 64, 64, 0, 0, 0, 0, 0]], [[], [["SP", 2, 0, 1024, [2, 1, 4], 1], ["SP", 2, 4, 1024, [1, 1, 8], 1], ["SP", 2, 8, 1024, [4], 1], ["RE", 2, [0, 4, 1, 5, 8, 2, 6, 9, 3, 7]], ["FSP", 4, 0, 0, 2], ["FSP", 4, 3, 1, 2], ["RE", 4, [0, 3, 1, 4, 2, 5]], ["CA", 2, 4, 3], ["FU", 4, [0, 1]], ["AN", 4, 0, 3], ["PR", 2, 0, "auto_unroll_max_step$8"], ["AN", 2, 9, 2], ["AN", 4, 4, 2]]]], "r": [[0.0044742], 0, 0.335558, 1607112214], "v": "v0.3"}
2+
{"i": [["[\"matmul_add\", 1024, 1024, 1024, \"float32\"]", "llvm -keys=cpu", [18, 64, 64, 0, 0, 0, 0, 0]], [[], [["SP", 2, 0, 1024, [2, 1, 4], 1], ["SP", 2, 4, 1024, [1, 1, 8], 1], ["SP", 2, 8, 1024, [4], 1], ["RE", 2, [0, 4, 1, 5, 8, 2, 6, 9, 3, 7]], ["FSP", 4, 0, 0, 2], ["FSP", 4, 3, 1, 2], ["RE", 4, [0, 3, 1, 4, 2, 5]], ["CA", 2, 4, 3], ["FU", 4, [0, 1]], ["AN", 4, 0, 3], ["PR", 2, 0, "auto_unroll_max_step$8"], ["AN", 2, 9, 2], ["AN", 4, 4, 2]]]], "r": [[0.0044742], 0, 0.335558, 1607112214], "v": "v0.3"}

gallery/how_to/tune_with_autoscheduler/ci_logs/resnet-50-NHWC-B1-llvm.json

+26-26
Large diffs are not rendered by default.
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
# Keep a valid schedule for demonstraction. This is used to prevent flasky errors in CI.
2-
{"i": [["[\"sparse_dense\", 512, 512, 512, [9831, 16, 1], [9831], [33], \"float32\"]", "llvm -keys=cpu -link-params=0", [6, 64, 64, 0, 0, 0, 0, 0], "", 1, ["sparse_dense_bsr_512_512_512_16_1_0.60_W_data", "sparse_dense_bsr_512_512_512_16_1_0.60_W_indices", "sparse_dense_bsr_512_512_512_16_1_0.60_W_indptr"]], [[], [["CI", 8], ["CI", 6], ["SP", 5, 0, 512, [1, 8], 1], ["FSP", 9, 0, 2, 1], ["SP", 5, 3, 32, [32], 1], ["FSP", 9, 2, 4, 1], ["RE", 5, [0, 3, 1, 4, 6, 2, 5, 7]], ["RE", 9, [0, 2, 1, 3]], ["CA", 5, 9, 1], ["CI", 4], ["FU", 9, [0, 1]], ["AN", 9, 0, 3], ["PR", 5, 0, "auto_unroll_max_step$0"], ["AN", 9, 2, 2]]]], "r": [[0.000957008], 0, 0.605709, 1614689820], "v": "v0.6"}
2+
{"i": [["[\"sparse_dense\", 512, 512, 512, [9831, 16, 1], [9831], [33], \"float32\"]", "llvm -keys=cpu", [6, 64, 64, 0, 0, 0, 0, 0], "", 1, ["sparse_dense_bsr_512_512_512_16_1_0.60_W_data", "sparse_dense_bsr_512_512_512_16_1_0.60_W_indices", "sparse_dense_bsr_512_512_512_16_1_0.60_W_indptr"]], [[], [["CI", 8], ["CI", 6], ["SP", 5, 0, 512, [1, 8], 1], ["FSP", 9, 0, 2, 1], ["SP", 5, 3, 32, [32], 1], ["FSP", 9, 2, 4, 1], ["RE", 5, [0, 3, 1, 4, 6, 2, 5, 7]], ["RE", 9, [0, 2, 1, 3]], ["CA", 5, 9, 1], ["CI", 4], ["FU", 9, [0, 1]], ["AN", 9, 0, 3], ["PR", 5, 0, "auto_unroll_max_step$0"], ["AN", 9, 2, 2]]]], "r": [[0.000957008], 0, 0.605709, 1614689820], "v": "v0.6"}

gallery/how_to/tune_with_autotvm/tune_relay_x86.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -298,7 +298,7 @@ def tune_and_evaluate(tuning_opt):
298298
#
299299
# Evaluation of the network been tuned on graph level:
300300
# Compile...
301-
# Config for target=llvm -keys=cpu -link-params=0, workload=('dense_nopack.x86', ('TENSOR', (1, 512), 'float32'), ('TENSOR', (1000, 512), 'float32'), None, 'float32') is missing in ApplyGraphBest context. A fallback configuration is used, which may bring great performance regression.
302-
# Config for target=llvm -keys=cpu -link-params=0, workload=('dense_pack.x86', ('TENSOR', (1, 512), 'float32'), ('TENSOR', (1000, 512), 'float32'), None, 'float32') is missing in ApplyGraphBest context. A fallback configuration is used, which may bring great performance regression.
301+
# Config for target=llvm -keys=cpu, workload=('dense_nopack.x86', ('TENSOR', (1, 512), 'float32'), ('TENSOR', (1000, 512), 'float32'), None, 'float32') is missing in ApplyGraphBest context. A fallback configuration is used, which may bring great performance regression.
302+
# Config for target=llvm -keys=cpu, workload=('dense_pack.x86', ('TENSOR', (1, 512), 'float32'), ('TENSOR', (1000, 512), 'float32'), None, 'float32') is missing in ApplyGraphBest context. A fallback configuration is used, which may bring great performance regression.
303303
# Evaluate inference time cost...
304304
# Mean inference time (std dev): 3.16 ms (0.03 ms)

gallery/how_to/work_with_microtvm/micro_tvmc.sh

+2-2
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ wget https://github.com/tensorflow/tflite-micro/raw/main/tensorflow/lite/micro/e
9999
#
100100
# bash
101101
tvmc compile magic_wand.tflite \
102-
--target='c -keys=cpu -link-params=0 -model=host' \
102+
--target='c -keys=cpu -model=host' \
103103
--runtime=crt \
104104
--runtime-crt-system-lib 1 \
105105
--executor='graph' \
@@ -111,7 +111,7 @@ tvmc compile magic_wand.tflite \
111111
# bash
112112
# This will generate a ``model.tar`` file which contains TVM compiler output files. To run this command for
113113
# a different Zephyr device, you need to update ``target``. For instance, for ``nrf5340dk_nrf5340_cpuapp`` board
114-
# the target is ``--target='c -keys=cpu -link-params=0 -model=nrf5340dk'``.
114+
# the target is ``--target='c -keys=cpu -model=nrf5340dk'``.
115115
#
116116

117117

gallery/tutorial/auto_scheduler_matmul_x86.py

-2
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,6 @@
4444
testing.utils.install_request_hook(depth=3)
4545
# sphinx_gallery_end_ignore
4646

47-
import os
48-
4947
import numpy as np
5048
import tvm
5149
from tvm import te, auto_scheduler

python/tvm/contrib/hexagon/pytest_plugin.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -245,7 +245,7 @@ def terminate_rpc_servers():
245245

246246
aot_host_target = tvm.testing.parameter(
247247
"c",
248-
"llvm -keys=hexagon -link-params=0 "
248+
"llvm -keys=hexagon "
249249
"-mattr=+hvxv68,+hvx-length128b,+hvx-qfloat,-hvx-ieee-fp "
250250
"-mcpu=hexagonv68 -mtriple=hexagon",
251251
)

python/tvm/relay/build_module.py

-78
Original file line numberDiff line numberDiff line change
@@ -274,69 +274,6 @@ def _build_module_no_factory(mod, target=None, target_host=None, params=None, mo
274274
return _build_module_no_factory_impl(mod, target, target_host, params, mod_name)
275275

276276

277-
def _reconstruct_from_deprecated_options(deprecated_params_target):
278-
executor = None
279-
runtime = None
280-
281-
deprecated_executor = None
282-
deprecated_executor_args = {}
283-
if "executor" in deprecated_params_target.attrs:
284-
_deprecated_target_param_warning("Executor", "executor")
285-
deprecated_executor = deprecated_params_target.attrs.get("executor", "graph")
286-
if "interface-api" in deprecated_params_target.attrs:
287-
_deprecated_target_sub_param_warning("Executor", "interface-api")
288-
deprecated_executor_args.update(
289-
{"interface-api": deprecated_params_target.attrs["interface-api"]}
290-
)
291-
if "unpacked-api" in deprecated_params_target.attrs:
292-
_deprecated_target_sub_param_warning("Executor", "unpacked-api")
293-
deprecated_executor_args.update(
294-
{"unpacked-api": deprecated_params_target.attrs["unpacked-api"]}
295-
)
296-
if (
297-
"link-params" in deprecated_params_target.attrs
298-
and deprecated_params_target.attrs["link-params"]
299-
):
300-
_deprecated_target_sub_param_warning("Executor", "link-params")
301-
if deprecated_executor != "aot":
302-
deprecated_executor_args.update(
303-
{"link-params": deprecated_params_target.attrs["link-params"]}
304-
)
305-
if deprecated_executor or deprecated_executor_args:
306-
executor = Executor(deprecated_executor or "graph", deprecated_executor_args)
307-
308-
deprecated_runtime = None
309-
deprecated_runtime_args = {}
310-
if "runtime" in deprecated_params_target.attrs:
311-
_deprecated_target_param_warning("Runtime", "runtime")
312-
deprecated_runtime = deprecated_params_target.attrs.get("runtime", "cpp")
313-
if deprecated_runtime == "c":
314-
deprecated_runtime = "crt"
315-
if "system-lib" in deprecated_params_target.attrs:
316-
_deprecated_target_sub_param_warning("Runtime", "system-lib")
317-
deprecated_runtime_args.update({"system-lib": deprecated_params_target.attrs["system-lib"]})
318-
if deprecated_runtime or deprecated_runtime_args:
319-
runtime = Runtime(deprecated_runtime or "cpp", deprecated_runtime_args)
320-
321-
return executor, runtime
322-
323-
324-
def _deprecated_target_param_warning(registry, param):
325-
warnings.warn(
326-
f"Please use {registry} (tvm.relay.backend.{registry}) "
327-
f"instead of deprecated Target parameter -{param}",
328-
DeprecationWarning,
329-
)
330-
331-
332-
def _deprecated_target_sub_param_warning(registry, param):
333-
warnings.warn(
334-
f"Please use {registry} (tvm.relay.backend.{registry}) parameter {param} "
335-
f"instead of deprecated Target parameter -{param}",
336-
DeprecationWarning,
337-
)
338-
339-
340277
def build(
341278
ir_mod,
342279
target=None,
@@ -415,17 +352,6 @@ def build(
415352
assert len(raw_targets) > 0
416353
target_host = raw_targets[0].host
417354

418-
# All of this logic is to raise deprecation warnings for various parameters
419-
# TODO(Mousius) Remove these after some time
420-
deprecated_params_target = target_host or list(raw_targets)[0]
421-
deprecated_executor, deprecated_runtime = _reconstruct_from_deprecated_options(
422-
deprecated_params_target
423-
)
424-
if deprecated_executor:
425-
executor = deprecated_executor
426-
if deprecated_runtime:
427-
runtime = deprecated_runtime
428-
429355
# If current dispatch context is fallback context (the default root context),
430356
# then load pre-tuned parameters from TopHub
431357
if isinstance(autotvm.DispatchContext.current, autotvm.FallbackContext):
@@ -756,9 +682,5 @@ def create_executor(kind="debug", mod=None, device=None, target="llvm", params=N
756682
if kind == "vm":
757683
return VMExecutor(mod, device, raw_targets)
758684
if kind == "aot":
759-
# The AOT requires the executor as a target attribute.
760-
# (The compilation paths for the other executors currently do not always provide this
761-
# attribute, hence the above generic assert is more forgiving).
762-
assert "executor" in raw_targets[0].attrs
763685
return AotExecutor(mod, device, raw_targets)
764686
raise RuntimeError("unknown execution strategy: {0}".format(kind))

python/tvm/target/target.py

+1-18
Original file line numberDiff line numberDiff line change
@@ -636,8 +636,6 @@ def hexagon(cpu_ver="v66", **kwargs):
636636
Whether to use QFloat HVX instructions.
637637
use_ieee_fp : bool (default: False)
638638
Whether to use IEEE HVX instructions
639-
link_params : bool (default: False)
640-
Whether to link graph parameters into the LLVM module.
641639
642640
Note: Floating point support in HVX requires LLVM 14+.
643641
"""
@@ -671,7 +669,6 @@ def get_arch_version(cpu_ver):
671669
"llvm_options": None,
672670
"use_qfloat": arch_version >= 68,
673671
"use_ieee_fp": False,
674-
"link_params": False,
675672
}
676673
config.update(kwargs)
677674

@@ -738,24 +735,10 @@ def create_llvm_options(cpu_ver, config): # pylint: disable=unused-argument
738735
args = [s.replace("=", "@") for s in llvm_options.split()]
739736
return "--llvm-options=" + ",".join(args)
740737

741-
# TVM target attributes string
742-
def create_tvm_options(cpu_ver, config): # pylint: disable=unused-argument
743-
"""Create TVM target features string."""
744-
745-
features = {
746-
"link_params": "link-params",
747-
}
748-
opts = ""
749-
for k in config:
750-
if k in features:
751-
opts += " --" + features[k] + "=" + str(config[k])
752-
return opts
753-
754738
target_str = create_llvm_target(cpu_ver, config)
755739
llvm_str = create_llvm_options(cpu_ver, config)
756-
tvm_str = create_tvm_options(cpu_ver, config)
757740

758-
args_list = target_str.split() + llvm_str.split() + tvm_str.split()
741+
args_list = target_str.split() + llvm_str.split()
759742

760743
return Target(" ".join(["hexagon"] + args_list))
761744

src/target/target_kind.cc

+6-32
Original file line numberDiff line numberDiff line change
@@ -264,12 +264,7 @@ TVM_REGISTER_TARGET_KIND("llvm", kDLCPU)
264264
.add_attr_option<String>("mtriple")
265265
.add_attr_option<String>("mfloat-abi")
266266
.add_attr_option<String>("mabi")
267-
.add_attr_option<Bool>("system-lib")
268-
.add_attr_option<String>("runtime")
269267
.add_attr_option<Integer>("num-cores")
270-
.add_attr_option<Bool>("link-params", Bool(false))
271-
.add_attr_option<Bool>("unpacked-api")
272-
.add_attr_option<String>("interface-api")
273268
// Fast math flags, see https://llvm.org/docs/LangRef.html#fast-math-flags
274269
.add_attr_option<Bool>("fast-math") // implies all the below
275270
.add_attr_option<Bool>("fast-math-nnan")
@@ -310,23 +305,16 @@ TVM_REGISTER_TARGET_KIND("llvm", kDLCPU)
310305
// Hence the type is "uint".
311306

312307
TVM_REGISTER_TARGET_KIND("c", kDLCPU)
313-
.add_attr_option<Bool>("system-lib")
314-
.add_attr_option<Bool>("link-params", Bool(false))
315-
.add_attr_option<String>("runtime")
316308
.add_attr_option<String>("mcpu")
317309
.add_attr_option<String>("march")
318-
.add_attr_option<String>("executor")
319310
.add_attr_option<Integer>("workspace-byte-alignment")
320311
.add_attr_option<Integer>("constants-byte-alignment")
321-
.add_attr_option<Bool>("unpacked-api")
322-
.add_attr_option<String>("interface-api")
323312
.set_default_keys({"cpu"})
324313
.set_target_parser(tvm::target::parsers::cpu::ParseTarget);
325314

326315
TVM_REGISTER_TARGET_KIND("cuda", kDLCUDA)
327316
.add_attr_option<String>("mcpu")
328317
.add_attr_option<String>("arch")
329-
.add_attr_option<Bool>("system-lib")
330318
.add_attr_option<Integer>("max_shared_memory_per_block")
331319
.add_attr_option<Integer>("max_threads_per_block")
332320
.add_attr_option<Integer>("thread_warp_size", Integer(32))
@@ -338,7 +326,6 @@ TVM_REGISTER_TARGET_KIND("cuda", kDLCUDA)
338326
TVM_REGISTER_TARGET_KIND("nvptx", kDLCUDA)
339327
.add_attr_option<String>("mcpu")
340328
.add_attr_option<String>("mtriple")
341-
.add_attr_option<Bool>("system-lib")
342329
.add_attr_option<Integer>("max_num_threads", Integer(1024))
343330
.add_attr_option<Integer>("thread_warp_size", Integer(32))
344331
.set_default_keys({"cuda", "gpu"})
@@ -348,7 +335,6 @@ TVM_REGISTER_TARGET_KIND("rocm", kDLROCM)
348335
.add_attr_option<String>("mcpu")
349336
.add_attr_option<String>("mtriple")
350337
.add_attr_option<Array<String>>("mattr")
351-
.add_attr_option<Bool>("system-lib")
352338
// TODO(masahi): Support querying from a target device
353339
// On RDNA cards, thread_warp_size should be 32
354340
.add_attr_option<Integer>("max_num_threads", Integer(256))
@@ -359,7 +345,6 @@ TVM_REGISTER_TARGET_KIND("rocm", kDLROCM)
359345
.set_target_parser(UpdateROCmAttrs);
360346

361347
TVM_REGISTER_TARGET_KIND("opencl", kDLOpenCL)
362-
.add_attr_option<Bool>("system-lib")
363348
.add_attr_option<Integer>("max_num_threads", Integer(256))
364349
.add_attr_option<Integer>("thread_warp_size", Integer(1))
365350
.add_attr_option<Integer>("texture_spatial_limit", Integer(16384))
@@ -370,15 +355,13 @@ TVM_REGISTER_TARGET_KIND("opencl", kDLOpenCL)
370355
// information about this limitation can be found here:
371356
// https://developer.apple.com/documentation/metal/buffers/about_argument_buffers?language=objc
372357
TVM_REGISTER_TARGET_KIND("metal", kDLMetal)
373-
.add_attr_option<Bool>("system-lib")
374358
.add_attr_option<Integer>("max_num_threads", Integer(256))
375359
.add_attr_option<Integer>("thread_warp_size", Integer(16))
376360
.add_attr_option<Integer>("max_function_args", Integer(31))
377361
.set_default_keys({"metal", "gpu"});
378362

379363
TVM_REGISTER_TARGET_KIND("vulkan", kDLVulkan)
380364
.add_attr_option<Array<String>>("mattr")
381-
.add_attr_option<Bool>("system-lib")
382365
// Feature support
383366
.add_attr_option<Bool>("supports_float16")
384367
.add_attr_option<Bool>("supports_float32", Bool(true))
@@ -417,39 +400,30 @@ TVM_REGISTER_TARGET_KIND("vulkan", kDLVulkan)
417400
.set_default_keys({"vulkan", "gpu"});
418401

419402
TVM_REGISTER_TARGET_KIND("webgpu", kDLWebGPU)
420-
.add_attr_option<Bool>("system-lib")
421403
.add_attr_option<Integer>("max_num_threads", Integer(256))
422404
.set_default_keys({"webgpu", "gpu"});
423405

424-
TVM_REGISTER_TARGET_KIND("sdaccel", kDLOpenCL)
425-
.add_attr_option<Bool>("system-lib")
406+
TVM_REGISTER_TARGET_KIND("sdaccel", kDLOpenCL) // line break
426407
.set_default_keys({"sdaccel", "hls"});
427408

428-
TVM_REGISTER_TARGET_KIND("aocl", kDLAOCL)
429-
.add_attr_option<Bool>("system-lib")
409+
TVM_REGISTER_TARGET_KIND("aocl", kDLAOCL) // line break
430410
.set_default_keys({"aocl", "hls"});
431411

432-
TVM_REGISTER_TARGET_KIND("aocl_sw_emu", kDLAOCL)
433-
.add_attr_option<Bool>("system-lib")
412+
TVM_REGISTER_TARGET_KIND("aocl_sw_emu", kDLAOCL) // line break
434413
.set_default_keys({"aocl", "hls"});
435414

436415
TVM_REGISTER_TARGET_KIND("hexagon", kDLHexagon)
437416
.add_attr_option<Array<String>>("mattr")
438417
.add_attr_option<String>("mcpu")
439418
.add_attr_option<String>("mtriple")
440-
.add_attr_option<Bool>("system-lib")
441-
.add_attr_option<Bool>("link-params", Bool(false))
442419
.add_attr_option<Array<String>>("llvm-options")
443420
.set_default_keys({"hexagon"});
444421

445-
TVM_REGISTER_TARGET_KIND("stackvm", kDLCPU) // line break
446-
.add_attr_option<Bool>("system-lib");
422+
TVM_REGISTER_TARGET_KIND("stackvm", kDLCPU);
447423

448-
TVM_REGISTER_TARGET_KIND("ext_dev", kDLExtDev) // line break
449-
.add_attr_option<Bool>("system-lib");
424+
TVM_REGISTER_TARGET_KIND("ext_dev", kDLExtDev);
450425

451-
TVM_REGISTER_TARGET_KIND("hybrid", kDLCPU) // line break
452-
.add_attr_option<Bool>("system-lib");
426+
TVM_REGISTER_TARGET_KIND("hybrid", kDLCPU);
453427

454428
TVM_REGISTER_TARGET_KIND("composite", kDLCPU) // line break
455429
.add_attr_option<Array<Target>>("devices");

0 commit comments

Comments
 (0)