Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
156 changes: 95 additions & 61 deletions Online/inference/08-LSTM+CRF/mindspore_sequence_labeling.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,42 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 1,
"id": "f01741d9",
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages/numpy/core/getlimits.py:549: UserWarning: The value of the smallest subnormal for <class 'numpy.float64'> type is zero.\n",
" setattr(self, word, getattr(machar, word).flat[0])\n",
"/home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages/numpy/core/getlimits.py:89: UserWarning: The value of the smallest subnormal for <class 'numpy.float64'> type is zero.\n",
" return self._float_to_str(self.smallest_subnormal)\n",
"/home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages/numpy/core/getlimits.py:549: UserWarning: The value of the smallest subnormal for <class 'numpy.float32'> type is zero.\n",
" setattr(self, word, getattr(machar, word).flat[0])\n",
"/home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages/numpy/core/getlimits.py:89: UserWarning: The value of the smallest subnormal for <class 'numpy.float32'> type is zero.\n",
" return self._float_to_str(self.smallest_subnormal)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"2.6.0\n"
]
}
],
"source": [
"import mindspore\n",
"mindspore.set_context(max_device_memory=\"2GB\", mode=mindspore.GRAPH_MODE, device_target=\"Ascend\", jit_config={\"jit_level\":\"O2\"}, ascend_config={\"precision_mode\":\"allow_mix_precision\"})"
"import mindspore as ms\n",
"import mindspore.nn as nn\n",
"import mindspore.ops as ops\n",
"import mindspore.numpy as mnp\n",
"from mindspore.common.initializer import initializer, Uniform\n",
"import mindspore.mint as mint\n",
"mindspore.set_context(max_device_memory=\"2GB\", mode=mindspore.PYNATIVE_MODE, device_target=\"Ascend\", jit_config={\"jit_level\":\"O0\"}, ascend_config={\"precision_mode\":\"allow_mix_precision\"})\n",
"print(mindspore.__version__)\n"
]
},
{
Expand Down Expand Up @@ -124,7 +153,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 2,
"id": "d09936ad-6b20-4423-9f57-8e14917e61d4",
"metadata": {},
"outputs": [],
Expand All @@ -142,7 +171,7 @@
" score = start_trans[tags[0]]\n",
" # score += 第一次发射概率\n",
" # shape: (batch_size,)\n",
" score += emissions[0, mnp.arange(batch_size), tags[0]]\n",
" score += emissions[0, mint.arange(batch_size), tags[0]]\n",
"\n",
" for i in range(1, seq_length):\n",
" # 标签由i-1转移至i的转移概率(当mask == 1时有效)\n",
Expand All @@ -151,11 +180,11 @@
"\n",
" # 预测tags[i]的发射概率(当mask == 1时有效)\n",
" # shape: (batch_size,)\n",
" score += emissions[i, mnp.arange(batch_size), tags[i]] * mask[i]\n",
" score += emissions[i, mint.arange(batch_size), tags[i]] * mask[i]\n",
"\n",
" # 结束转移\n",
" # shape: (batch_size,)\n",
" last_tags = tags[seq_ends, mnp.arange(batch_size)]\n",
" last_tags = tags[seq_ends, mint.arange(batch_size)]\n",
" # score += 结束转移概率\n",
" # shape: (batch_size,)\n",
" score += end_trans[last_tags]\n",
Expand Down Expand Up @@ -185,7 +214,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 3,
"id": "d9a0ef6a-1c3a-400e-9053-e0659e8f9e7e",
"metadata": {},
"outputs": [],
Expand Down Expand Up @@ -217,18 +246,18 @@
"\n",
" # 对score_i做log_sum_exp运算,用于下一个Token的score计算\n",
" # shape: (batch_size, num_tags)\n",
" next_score = ops.logsumexp(next_score, axis=1)\n",
" next_score = mint.logsumexp(next_score, dim=1)\n",
"\n",
" # 当mask == 1时,score才会变化\n",
" # shape: (batch_size, num_tags)\n",
" score = mnp.where(mask[i].expand_dims(1), next_score, score)\n",
" score = mint.where(mask[i].expand_dims(1), next_score, score)\n",
"\n",
" # 最后加结束转移概率\n",
" # shape: (batch_size, num_tags)\n",
" score += end_trans\n",
" # 对所有可能的路径得分求log_sum_exp\n",
" # shape: (batch_size,)\n",
" return ops.logsumexp(score, axis=1)"
" return mint.logsumexp(score, dim=1)"
]
},
{
Expand All @@ -251,7 +280,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 4,
"id": "7c286935-74eb-413a-8e47-b8fb5264e6b6",
"metadata": {},
"outputs": [],
Expand All @@ -262,20 +291,22 @@
"\n",
" seq_length = mask.shape[0]\n",
"\n",
" score = start_trans + emissions[0]\n",
" score = (start_trans + emissions[0]).astype(ms.float32)\n",
" history = ()\n",
"\n",
" mask_bool = mask.astype(ms.uint8)\n",
" for i in range(1, seq_length):\n",
" broadcast_score = score.expand_dims(2)\n",
" broadcast_emission = emissions[i].expand_dims(1)\n",
" next_score = broadcast_score + trans + broadcast_emission\n",
"\n",
" # 求当前Token对应score取值最大的标签,并保存\n",
" #indices = next_score.argmax(axis=1)\n",
" indices = next_score.argmax(axis=1)\n",
" history += (indices,)\n",
"\n",
" next_score = next_score.max(axis=1)\n",
" score = mnp.where(mask[i].expand_dims(1), next_score, score)\n",
" score = mint.where(mask_bool[i].expand_dims(1), next_score, score)\n",
"\n",
"\n",
" score += end_trans\n",
"\n",
Expand Down Expand Up @@ -321,37 +352,15 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 5,
"id": "120a39c7-c89d-4cd3-8a75-da27d2853f0e",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"[WARNING] ME(97316:281472950050832,MainProcess):2024-09-06-10:50:02.452.307 [mindspore/run_check/_check_version.py:357] MindSpore version 2.2.14 and Ascend AI software package (Ascend Data Center Solution)version 7.0 does not match, the version of software package expect one of ['7.1']. Please refer to the match info on: https://www.mindspore.cn/install\n",
"/home/ma-user/anaconda3/envs/MindSpore/lib/python3.9/site-packages/numpy/core/getlimits.py:549: UserWarning: The value of the smallest subnormal for <class 'numpy.float64'> type is zero.\n",
" setattr(self, word, getattr(machar, word).flat[0])\n",
"/home/ma-user/anaconda3/envs/MindSpore/lib/python3.9/site-packages/numpy/core/getlimits.py:549: UserWarning: The value of the smallest subnormal for <class 'numpy.float32'> type is zero.\n",
" setattr(self, word, getattr(machar, word).flat[0])\n",
"[WARNING] ME(97316:281472950050832,MainProcess):2024-09-06-10:50:04.855.681 [mindspore/run_check/_check_version.py:375] MindSpore version 2.2.14 and \"te\" wheel package version 7.0 does not match. For details, refer to the installation guidelines: https://www.mindspore.cn/install\n",
"[WARNING] ME(97316:281472950050832,MainProcess):2024-09-06-10:50:04.861.443 [mindspore/run_check/_check_version.py:382] MindSpore version 2.2.14 and \"hccl\" wheel package version 7.0 does not match. For details, refer to the installation guidelines: https://www.mindspore.cn/install\n",
"[WARNING] ME(97316:281472950050832,MainProcess):2024-09-06-10:50:04.862.319 [mindspore/run_check/_check_version.py:396] Please pay attention to the above warning, countdown: 3\n",
"[WARNING] ME(97316:281472950050832,MainProcess):2024-09-06-10:50:05.864.215 [mindspore/run_check/_check_version.py:396] Please pay attention to the above warning, countdown: 2\n",
"[WARNING] ME(97316:281472950050832,MainProcess):2024-09-06-10:50:06.866.175 [mindspore/run_check/_check_version.py:396] Please pay attention to the above warning, countdown: 1\n"
]
}
],
"outputs": [],
"source": [
"import mindspore as ms\n",
"import mindspore.nn as nn\n",
"import mindspore.ops as ops\n",
"import mindspore.numpy as mnp\n",
"from mindspore.common.initializer import initializer, Uniform\n",
"\n",
"def sequence_mask(seq_length, max_length, batch_first=False):\n",
" \"\"\"根据序列实际长度和最大长度生成mask矩阵\"\"\"\n",
" range_vector = mnp.arange(0, max_length, 1, seq_length.dtype)\n",
" range_vector = mint.arange(0, max_length, 1, dtype=seq_length.dtype)\n",
" result = range_vector < seq_length.view(seq_length.shape + (1,))\n",
" if batch_first:\n",
" return result.astype(ms.int64)\n",
Expand Down Expand Up @@ -385,7 +394,7 @@
" max_length, batch_size = tags.shape\n",
"\n",
" if seq_length is None:\n",
" seq_length = mnp.full((batch_size,), max_length, ms.int64)\n",
" seq_length = mint.full((batch_size,), max_length, ms.int64)\n",
"\n",
" mask = sequence_mask(seq_length, max_length)\n",
"\n",
Expand All @@ -412,7 +421,7 @@
" batch_size, max_length = emissions.shape[:2]\n",
"\n",
" if seq_length is None:\n",
" seq_length = mnp.full((batch_size,), max_length, ms.int64)\n",
" seq_length = mint.full((batch_size,), max_length, ms.int64)\n",
"\n",
" mask = sequence_mask(seq_length, max_length)\n",
"\n",
Expand All @@ -437,7 +446,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 6,
"id": "c07555e3-f2a2-4c25-beff-5a78491ab2d1",
"metadata": {},
"outputs": [],
Expand Down Expand Up @@ -469,7 +478,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 7,
"id": "fa53f535-34bc-49a3-b769-e85d3a184cc0",
"metadata": {},
"outputs": [],
Expand Down Expand Up @@ -497,7 +506,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 8,
"id": "5c255358-938b-4e88-a810-2fe4e616a044",
"metadata": {},
"outputs": [
Expand All @@ -507,7 +516,7 @@
"21"
]
},
"execution_count": 7,
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -526,7 +535,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 9,
"id": "e6d3f2bb-0ea7-457c-8a74-a8e0ea78a109",
"metadata": {},
"outputs": [],
Expand All @@ -544,7 +553,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 10,
"id": "7899b6e6-95a4-4ffe-ba49-b27cebe8b306",
"metadata": {},
"outputs": [],
Expand All @@ -569,7 +578,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 11,
"id": "0e984dcb-6f89-4520-940f-620cdb997529",
"metadata": {},
"outputs": [
Expand All @@ -579,7 +588,7 @@
"((2, 11), (2, 11), (2,))"
]
},
"execution_count": 10,
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -599,10 +608,33 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 12,
"id": "65304086-de03-4edc-b3a2-3ca5e43c4795",
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Downloading data from https://cdn.modelers.cn/lfs/b8/d8/c8dd957c842f8c2c4c07ee2e77dee1fca7abe0e8305fdc6b301453aae00b?response-content-disposition=attachment%3B+filename%3D%22lstm-crf.ckpt%22&AWSAccessKeyId=HAZQA0Q6AQL2GHX4TKTL&Expires=1761968039&Signature=eI%2FfqpDCWGkcGche5vPhitwqZik%3D (19 kB)\n",
"\n",
"file_sizes: 100%|██████████████████████████| 19.8k/19.8k [00:00<00:00, 10.8MB/s]\n",
"Successfully downloaded file to ./lstm-crf.ckpt\n"
]
},
{
"data": {
"text/plain": [
"Tensor(shape=[2, 3], dtype=Float32, value=\n",
"[[ 3.28305664e+01, 3.76408653e+01, 3.22400093e+01],\n",
" [ 2.89628849e+01, 2.69458580e+01, 3.40970802e+01]])"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from mindspore import load_param_into_net, load_checkpoint\n",
"from download import download\n",
Expand All @@ -628,7 +660,7 @@
},
{
"cell_type": "code",
"execution_count": 20,
"execution_count": 13,
"id": "ed5ff78a-099c-41f8-ac1e-8f9fbfab7659",
"metadata": {},
"outputs": [
Expand All @@ -638,7 +670,7 @@
"[[0, 1, 1, 1, 2, 2, 2, 2, 2, 0, 1], [0, 1, 2, 2, 2, 2, 2, 2, 2]]"
]
},
"execution_count": 20,
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -658,7 +690,7 @@
},
{
"cell_type": "code",
"execution_count": 21,
"execution_count": 14,
"id": "df505983-9c1f-4e09-8a67-1cd78ce69fd0",
"metadata": {},
"outputs": [],
Expand All @@ -674,9 +706,11 @@
},
{
"cell_type": "code",
"execution_count": 22,
"execution_count": 15,
"id": "e30d43c7-7c09-445e-94a4-33fccfaeaf11",
"metadata": {},
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
Expand All @@ -685,7 +719,7 @@
" ['B', 'I', 'O', 'O', 'O', 'O', 'O', 'O', 'O']]"
]
},
"execution_count": 22,
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -697,9 +731,9 @@
],
"metadata": {
"kernelspec": {
"display_name": "MindSpore",
"display_name": "jupyter",
"language": "python",
"name": "mindspore"
"name": "jupyter"
},
"language_info": {
"codemirror_mode": {
Expand All @@ -711,7 +745,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.18"
"version": "3.9.23"
},
"vscode": {
"interpreter": {
Expand Down