@@ -79,16 +79,24 @@ def optimal_transformation(
7979 lambdas = preprocessing .PowerTransformer (
8080 method , standardize = False ).fit (data ).lambdas_ .astype (dtype )
8181
82- logging .info ('Optimal lambda was: %s' , lambdas )
82+ logging .info ('Optimal lambda was: %s, %s ' , lambdas , lambdas . dtype )
8383
8484 if dimension == 1 :
8585 # Make it a scalar, so we don't end up with batch_shape = [1] in the
8686 # bijector.
8787 lambdas = lambdas .item ()
8888 if method == 'yeo-johnson' :
89- warp = tfsb .YeoJohnson (lambdas )
89+ # Cast the default values of `rho` and `shift` to the same dtype as `data`
90+ # to avoid dtype mismatch errors.
91+ warp = tfsb .YeoJohnson (
92+ lambdas , rho = np .asarray (2.0 , dtype = dtype ), shift = np .asarray (1.0 , dtype )
93+ )
9094 elif method == 'box-cox' :
91- warp = tfsb .YeoJohnson (lambdas , shift = .0 )
95+ # Cast the default values of `rho` and `shift` to the same dtype as `data`
96+ # to avoid dtype mismatch errors.
97+ warp = tfsb .YeoJohnson (
98+ lambdas , rho = np .asarray (2.0 , dtype ), shift = np .asarray (0.0 , dtype )
99+ )
92100 else :
93101 raise ValueError (f'Unknown method: { method } ' )
94102
0 commit comments