@@ -377,16 +377,19 @@ def compute_state_entropy(
377377 A tensor containing the state entropy for `obs`.
378378 """
379379 assert obs .shape [1 :] == all_obs .shape [1 :]
380+ batch_size = 500
380381 with th .no_grad ():
381382 non_batch_dimensions = tuple (range (2 , len (obs .shape ) + 1 ))
382- distances_tensor = th .linalg .vector_norm (
383- obs [:, None ] - all_obs [None , :],
384- dim = non_batch_dimensions ,
385- ord = 2 ,
386- )
387-
388- # Note that we take the k+1'th value because the closest neighbor to
389- # a point is itself, which we want to skip.
390- assert distances_tensor .shape [- 1 ] > k
391- knn_dists = th .kthvalue (distances_tensor , k = k + 1 , dim = 1 ).values
392- return knn_dists
383+ dists = []
384+ for idx in range (len (all_obs ) // batch_size + 1 ):
385+ start = idx * batch_size
386+ end = (idx + 1 ) * batch_size
387+ distances_tensor = th .linalg .vector_norm (
388+ obs [:, None ] - all_obs [None , start :end ],
389+ dim = non_batch_dimensions ,
390+ ord = 2 ,
391+ )
392+ dists .append (distances_tensor )
393+ dists = th .cat (dists , dim = 1 )
394+ knn_dists = th .kthvalue (dists , k = k + 1 , dim = 1 ).values
395+ return knn_dists
0 commit comments