@@ -623,6 +623,8 @@ def total_bound(b, i_1, i_2, dic, n=None):
     return (attn_1, bound, bound_2, out, out_2, out_3)
 
 
+# %%
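+# Load the trained ABCAB8_1H model; force="load" presumably reuses the cached
+# checkpoint rather than retraining.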
+runtime_model_1, model_1 = train_or_load_model(ABCAB8_1H, force="load")
 # %%
 optimiser = torch.optim.AdamW(
     model_1.parameters(), lr=5e-3, betas=(0.9, 0.999), weight_decay=1.0
@@ -656,10 +658,25 @@ def total_bound(b, i_1, i_2, dic, n=None):
 optimiser = torch.optim.AdamW(
     model_1.parameters(), lr=1e-3, betas=(0.9, 0.999), weight_decay=1.0
 )
+# %%
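+# Fine-tune to raise the certified bound: loss_bound(model_1, 2)[2] presumably
+# returns per-sequence bound values in [0, 1], with NaN where the bound does not
+# apply, so minimising 1 - mean pushes the average bound up. `counter` is assumed
+# to be initialised in an earlier cell.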
+a = loss_bound(model_1, 2)[2]
+loss = 1 - a[~torch.isnan(a)].mean()
+while loss > 0.5:
+    print(a[~torch.isnan(a)].min())
+    print(a[~torch.isnan(a)].mean())
+    print(a[~torch.isnan(a)].max())
+    loss.backward()
+    optimiser.step()
+    optimiser.zero_grad()
+    a = loss_bound(model_1, 2)[2]
+    loss = 1 - a[~torch.isnan(a)].mean()
+    counter += 1
+    print(counter)
 
+# %%
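+# Same loop as above, but driving up the worst-case (min) bound entry rather than the mean.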
 a = loss_bound(model_1, 2)[2]
 loss = 1 - a[~torch.isnan(a)].min()
-while loss > 0.1:
+while loss > 0.5:
     print(a[~torch.isnan(a)].min())
     print(a[~torch.isnan(a)].mean())
     print(a[~torch.isnan(a)].max())
@@ -670,8 +687,6 @@ def total_bound(b, i_1, i_2, dic, n=None):
     loss = 1 - a[~torch.isnan(a)].min()
     counter += 1
     print(counter)
-
-
 # %%
 valid = (
     ein.array(
@@ -681,6 +696,7 @@ def total_bound(b, i_1, i_2, dic, n=None):
     .bool()
     .to(device)
 )
+# %%
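+# Fresh optimiser for the next phase: a very large step (lr=1) with no weight
+# decay (assumed intentional for the loop that follows).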
 optimiser = torch.optim.AdamW(
     model_1.parameters(), lr=1, betas=(0.9, 0.999), weight_decay=0
 )
@@ -714,6 +730,9 @@ def total_bound(b, i_1, i_2, dic, n=None):
     print(r[valid].max())
 
 # %%
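+# Plot all of the model's weight matrices as subplots; add_mean presumably controls
+# which axes are mean-centred (axis 0's mean attributed to the tok->pos component).
+# Note this cell takes `model`, assumed to be defined in an earlier cell.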
+ModelMatrixLoggingOptions.all(
+    use_subplots=True, add_mean={-1: None, 0: "tok_to_pos", 1: None}
+).plot_matrices_from_model(model)
 '''
 def least_attention_2(a, b, i_1, i_2, j):
 