Skip to content

Commit a870c7c

Browse files
committed
Add minibatch proposal
1 parent ea2c917 commit a870c7c

File tree

1 file changed

+37
-2
lines changed

1 file changed

+37
-2
lines changed

VI_Overview.ipynb

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -460,7 +460,7 @@
460460
" def __init__(self, model=None, optimizers=None):\n",
461461
" ...\n",
462462
"\n",
463-
" def step(self, batched_data):\n",
463+
" def step(self, batch):\n",
464464
" ...\n",
465465
" return loss\n",
466466
"````\n",
@@ -506,6 +506,7 @@
506506
"\n",
507507
"````python\n",
508508
"with pm.Model() as model:\n",
509+
" data = pm.Data(\"data\", ...)\n",
509510
" x = pm.Normal(\"x\", 0, 1)\n",
510511
" y = pm.Normal(\"y\", x, 1, observed=data)\n",
511512
"\n",
@@ -530,10 +531,44 @@
530531
"````"
531532
]
532533
},
534+
{
535+
"cell_type": "markdown",
536+
"id": "7f97c341-e9bb-4301-b452-d006d6408cec",
537+
"metadata": {},
538+
"source": [
539+
"### Reworking Minibatch\n",
540+
"\n",
541+
"Another small change we should consider is moving `pm.Minibatch` out of the model. Max already has a [proposal](https://github.com/pymc-devs/pymc/issues/7496) that I think can be adopted with only a few changes.\n",
542+
"\n",
543+
"I think where before we explicitly minibatch the data, instead we have dataloaders that stream in updates to the model.\n",
544+
"\n",
545+
"````python\n",
546+
"with pm.Model() as model:\n",
547+
" data = pm.Data(\"data\", None)\n",
548+
" x = pm.Normal(\"x\", 0, 1)\n",
549+
" y = pm.Normal(\"y\", x, 1, observed=data)\n",
550+
"\n",
551+
"dataloader = pm.Dataloader(np.random.normal(10_000, 2), batch_size=64)\n",
552+
"\n",
553+
"with model:\n",
554+
" trainer = Trainer(method=ADVI(), dataloader=dataloader)\n",
555+
" trainer.fit(n=10_000)\n",
556+
"````\n",
557+
"\n",
558+
"Importantly, the model doesn't need to know about the dataloader. We will need to tweak the inference object, but it's not so bad.\n",
559+
"\n",
560+
"````python\n",
561+
"class ADVI(Inference):\n",
562+
" def step(self, batch):\n",
563+
" self.model.set_data(\"data\", batch)\n",
564+
" ...\n",
565+
"````"
566+
]
567+
},
533568
{
534569
"cell_type": "code",
535570
"execution_count": null,
536-
"id": "50fcb3a1-4467-4ace-acdd-666e4f342984",
571+
"id": "220ba769-fb8f-47a7-82b6-ab6ca13ad61e",
537572
"metadata": {},
538573
"outputs": [],
539574
"source": []

0 commit comments

Comments
 (0)