|
60 | 60 | "import os\n",
|
61 | 61 | "import shutil\n",
|
62 | 62 | "from tempfile import TemporaryDirectory\n",
|
63 |
| - "import scrapbook as sb\n", |
64 | 63 | "import pprint\n",
|
| 64 | + "import scrapbook as sb\n", |
| 65 | + "import sys\n", |
65 | 66 | "import time\n",
|
| 67 | + "import torch\n", |
| 68 | + "\n", |
| 69 | + "nlp_path = os.path.abspath(\"../../\")\n", |
| 70 | + "if nlp_path not in sys.path:\n", |
| 71 | + " sys.path.insert(0, nlp_path)\n", |
66 | 72 | "\n",
|
67 | 73 | "from utils_nlp.dataset.cnndm import CNNDMSummarizationDatasetOrg\n",
|
68 | 74 | "from utils_nlp.models.transformers.abstractive_summarization_seq2seq import S2SAbsSumProcessor, S2SAbstractiveSummarizer\n",
|
69 | 75 | "from utils_nlp.eval import compute_rouge_python\n",
|
70 | 76 | "\n",
|
| 77 | + "from utils_nlp.models.transformers.datasets import SummarizationDataset\n", |
| 78 | + "from utils_nlp.dataset.cnndm import detokenize\n", |
| 79 | + "\n", |
71 | 80 | "start_time = time.time()"
|
72 | 81 | ]
|
73 | 82 | },
|
|
82 | 91 | "outputs": [],
|
83 | 92 | "source": [
|
84 | 93 | "# model parameters\n",
|
85 |
| - "MODEL_NAME = \"unilm-large-cased\"\n", |
| 94 | + "MODEL_NAME = \"unilm-base-cased\"\n", |
86 | 95 | "MAX_SEQ_LENGTH = 768\n",
|
87 | 96 | "MAX_SOURCE_SEQ_LENGTH = 640\n",
|
88 | 97 | "MAX_TARGET_SEQ_LENGTH = 128\n",
|
89 | 98 | "\n",
|
| 99 | + "# number of GPUs to use; set NUM_GPUS to 0 to run on CPU\n", |
| 100 | + "NUM_GPUS = torch.cuda.device_count()\n", |
| 101 | + "\n", |
90 | 102 | "# fine-tuning parameters\n",
|
91 | 103 | "TRAIN_PER_GPU_BATCH_SIZE = 1\n",
|
92 | 104 | "GRADIENT_ACCUMULATION_STEPS = 2\n",
|
|
101 | 113 | " WARMUP_STEPS = 5\n",
|
102 | 114 | " MAX_STEPS = 50\n",
|
103 | 115 | " BEAM_SIZE = 3\n",
|
| 116 | + " if NUM_GPUS == 0:\n", |
| 117 | + " TOP_N = 5\n", |
| 118 | + " MAX_STEPS = 10\n", |
104 | 119 | "\n",
|
105 | 120 | "# inference parameters\n",
|
106 | 121 | "TEST_PER_GPU_BATCH_SIZE = 12\n",
|
|
220 | 235 | "source": [
|
221 | 236 | "abs_summarizer.fit(\n",
|
222 | 237 | " train_dataset=train_dataset,\n",
|
| 238 | + " num_gpus=NUM_GPUS,\n", |
223 | 239 | " per_gpu_batch_size=TRAIN_PER_GPU_BATCH_SIZE,\n",
|
224 | 240 | " gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,\n",
|
225 | 241 | " learning_rate=LEARNING_RATE,\n",
|
|
240 | 256 | {
|
241 | 257 | "cell_type": "code",
|
242 | 258 | "execution_count": null,
|
243 |
| - "metadata": {}, |
| 259 | + "metadata": { |
| 260 | + "scrolled": true |
| 261 | + }, |
244 | 262 | "outputs": [],
|
245 | 263 | "source": [
|
246 |
| - "res = abs_summarizer.predict(\n", |
| 264 | + "predictions = abs_summarizer.predict(\n", |
247 | 265 | " test_dataset=test_dataset,\n",
|
| 266 | + " num_gpus=NUM_GPUS,\n", |
248 | 267 | " per_gpu_batch_size=TEST_PER_GPU_BATCH_SIZE,\n",
|
249 | 268 | " beam_size=BEAM_SIZE,\n",
|
250 | 269 | " forbid_ignore_word=FORBID_IGNORE_WORD,\n",
|
|
258 | 277 | "metadata": {},
|
259 | 278 | "outputs": [],
|
260 | 279 | "source": [
|
261 |
| - "for r in res[:5]:\n", |
| 280 | + "for r in predictions[:TOP_N]:\n", |
262 | 281 | " print(r)"
|
263 | 282 | ]
|
264 | 283 | },
|
| 284 | + { |
| 285 | + "cell_type": "code", |
| 286 | + "execution_count": null, |
| 287 | + "metadata": {}, |
| 288 | + "outputs": [], |
| 289 | + "source": [ |
| 290 | + "test_ds.get_source()[0]" |
| 291 | + ] |
| 292 | + }, |
| 293 | + { |
| 294 | + "cell_type": "code", |
| 295 | + "execution_count": null, |
| 296 | + "metadata": {}, |
| 297 | + "outputs": [], |
| 298 | + "source": [ |
| 299 | + "test_ds.get_target()[0]" |
| 300 | + ] |
| 301 | + }, |
| 302 | + { |
| 303 | + "cell_type": "code", |
| 304 | + "execution_count": null, |
| 305 | + "metadata": {}, |
| 306 | + "outputs": [], |
| 307 | + "source": [ |
| 308 | + "predictions[0]" |
| 309 | + ] |
| 310 | + }, |
265 | 311 | {
|
266 | 312 | "cell_type": "code",
|
267 | 313 | "execution_count": null,
|
268 | 314 | "metadata": {},
|
269 | 315 | "outputs": [],
|
270 | 316 | "source": [
|
271 | 317 | "with open(OUTPUT_FILE, 'w', encoding=\"utf-8\") as f:\n",
|
272 |
| - " for line in res:\n", |
| 318 | + " for line in predictions:\n", |
273 | 319 | " f.write(line + '\\n')"
|
274 | 320 | ]
|
275 | 321 | },
|
| 322 | + { |
| 323 | + "cell_type": "markdown", |
| 324 | + "metadata": {}, |
| 325 | + "source": [ |
| 326 | + "## Prediction on a single input sample" |
| 327 | + ] |
| 328 | + }, |
| 329 | + { |
| 330 | + "cell_type": "code", |
| 331 | + "execution_count": null, |
| 332 | + "metadata": {}, |
| 333 | + "outputs": [], |
| 334 | + "source": [ |
| 335 | + "source = \"\"\"\n", |
| 336 | + "But under the new rule, set to be announced in the next 48 hours, Border Patrol agents would immediately return anyone to Mexico — without any detainment and without any due process — who attempts to cross the southwestern border between the legal ports of entry. The person would not be held for any length of time in an American facility.\n", |
| 337 | + "\n", |
| 338 | + "Although they advised that details could change before the announcement, administration officials said the measure was needed to avert what they fear could be a systemwide outbreak of the coronavirus inside detention facilities along the border. Such an outbreak could spread quickly through the immigrant population and could infect large numbers of Border Patrol agents, leaving the southwestern border defenses weakened, the officials argued.\n", |
| 339 | + "The Trump administration plans to immediately turn back all asylum seekers and other foreigners attempting to enter the United States from Mexico illegally, saying the nation cannot risk allowing the coronavirus to spread through detention facilities and Border Patrol agents, four administration officials said.\n", |
| 340 | + "The administration officials said the ports of entry would remain open to American citizens, green-card holders and foreigners with proper documentation. Some foreigners would be blocked, including Europeans currently subject to earlier travel restrictions imposed by the administration. The points of entry will also be open to commercial traffic.\"\"\"" |
| 341 | + ] |
| 342 | + }, |
| 343 | + { |
| 344 | + "cell_type": "code", |
| 345 | + "execution_count": null, |
| 346 | + "metadata": {}, |
| 347 | + "outputs": [], |
| 348 | + "source": [ |
| 349 | + "single_test_ds = SummarizationDataset(\n", |
| 350 | + " None, source=[source], source_preprocessing=[detokenize],\n", |
| 351 | + ")\n", |
| 352 | + "single_test_dataset = processor.s2s_dataset_from_sum_ds(single_test_ds, train_mode=False)" |
| 353 | + ] |
| 354 | + }, |
| 355 | + { |
| 356 | + "cell_type": "code", |
| 357 | + "execution_count": null, |
| 358 | + "metadata": {}, |
| 359 | + "outputs": [], |
| 360 | + "source": [ |
| 361 | + "single_prediction = abs_summarizer.predict(\n", |
| 362 | + " test_dataset=single_test_dataset,\n", |
| 363 | + " num_gpus=NUM_GPUS,\n", |
| 364 | + " per_gpu_batch_size=1,\n", |
| 365 | + " beam_size=BEAM_SIZE,\n", |
| 366 | + " forbid_ignore_word=FORBID_IGNORE_WORD,\n", |
| 367 | + " fp16=FP16\n", |
| 368 | + ")" |
| 369 | + ] |
| 370 | + }, |
| 371 | + { |
| 372 | + "cell_type": "code", |
| 373 | + "execution_count": null, |
| 374 | + "metadata": {}, |
| 375 | + "outputs": [], |
| 376 | + "source": [ |
| 377 | + "single_prediction[0]" |
| 378 | + ] |
| 379 | + }, |
276 | 380 | {
|
277 | 381 | "cell_type": "markdown",
|
278 | 382 | "metadata": {},
|
|
297 | 401 | "metadata": {},
|
298 | 402 | "outputs": [],
|
299 | 403 | "source": [
|
300 |
| - "rouge_scores = compute_rouge_python(cand=res, ref=test_ds.get_target())\n", |
| 404 | + "rouge_scores = compute_rouge_python(cand=predictions, ref=test_ds.get_target())\n", |
301 | 405 | "pprint.pprint(rouge_scores)"
|
302 | 406 | ]
|
303 | 407 | },
|
|
358 | 462 | "metadata": {},
|
359 | 463 | "outputs": [],
|
360 | 464 | "source": [
|
361 |
| - "print(\"Total notebook runningn time {}\".format(time.time() - start_time))" |
| 465 | + "print(\"Total notebook running time {}\".format(time.time() - start_time))" |
362 | 466 | ]
|
363 | 467 | },
|
364 | 468 | {
|
|
375 | 479 | }
|
376 | 480 | ],
|
377 | 481 | "metadata": {
|
| 482 | + "celltoolbar": "Tags", |
378 | 483 | "kernelspec": {
|
379 |
| - "display_name": "nlp_gpu", |
| 484 | + "display_name": "Python (nlp_gpu)", |
380 | 485 | "language": "python",
|
381 | 486 | "name": "nlp_gpu"
|
382 | 487 | },
|
|
0 commit comments