|
460 | 460 | " def __init__(self, model=None, optimizers=None):\n",
|
461 | 461 | " ...\n",
|
462 | 462 | "\n",
|
463 |
| - " def step(self, batched_data):\n", |
| 463 | + " def step(self, batch):\n", |
464 | 464 | " ...\n",
|
465 | 465 | " return loss\n",
|
466 | 466 | "````\n",
|
|
506 | 506 | "\n",
|
507 | 507 | "````python\n",
|
508 | 508 | "with pm.Model() as model:\n",
|
| 509 | + " data = pm.Data(\"data\", ...)\n", |
509 | 510 | " x = pm.Normal(\"x\", 0, 1)\n",
|
510 | 511 | " y = pm.Normal(\"y\", x, 1, observed=data)\n",
|
511 | 512 | "\n",
|
|
530 | 531 | "````"
|
531 | 532 | ]
|
532 | 533 | },
|
| 534 | + { |
| 535 | + "cell_type": "markdown", |
| 536 | + "id": "7f97c341-e9bb-4301-b452-d006d6408cec", |
| 537 | + "metadata": {}, |
| 538 | + "source": [ |
| 539 | + "### Reworking Minibatch\n", |
| 540 | + "\n", |
| 541 | + "Another small change we should consider is moving `pm.Minibatch` out of the model. Max already has a [proposal](https://github.com/pymc-devs/pymc/issues/7496) that I think can be adopted with only a few changes.\n", |
| 542 | + "\n", |
| 543 | + "I think where before we explicitly minibatch the data, instead we have dataloaders that stream in updates to the model.\n", |
| 544 | + "\n", |
| 545 | + "````python\n", |
| 546 | + "with pm.Model() as model:\n", |
| 547 | + " data = pm.Data(\"data\", None)\n", |
| 548 | + " x = pm.Normal(\"x\", 0, 1)\n", |
| 549 | + " y = pm.Normal(\"y\", x, 1, observed=data)\n", |
| 550 | + "\n", |
| 551 | + "dataloader = pm.Dataloader(np.random.normal(10_000, 2), batch_size=64)\n", |
| 552 | + "\n", |
| 553 | + "with model:\n", |
| 554 | + " trainer = Trainer(method=ADVI(), dataloader=dataloader)\n", |
| 555 | + " trainer.fit(n=10_000)\n", |
| 556 | + "````\n", |
| 557 | + "\n", |
| 558 | + "Importantly, the model doesn't need to know about the dataloader. We will need to tweak the inference object, but it's not so bad.\n", |
| 559 | + "\n", |
| 560 | + "````python\n", |
| 561 | + "class ADVI(Inference):\n", |
| 562 | + " def step(self, batch):\n", |
| 563 | + " self.model.set_data(\"data\", batch)\n", |
| 564 | + " ...\n", |
| 565 | + "````" |
| 566 | + ] |
| 567 | + }, |
533 | 568 | {
|
534 | 569 | "cell_type": "code",
|
535 | 570 | "execution_count": null,
|
536 |
| - "id": "50fcb3a1-4467-4ace-acdd-666e4f342984", |
| 571 | + "id": "220ba769-fb8f-47a7-82b6-ab6ca13ad61e", |
537 | 572 | "metadata": {},
|
538 | 573 | "outputs": [],
|
539 | 574 | "source": []
|
|
0 commit comments