diff --git a/Source/MLX/Random.swift b/Source/MLX/Random.swift index a42fcfaf..2e7a67e6 100644 --- a/Source/MLX/Random.swift +++ b/Source/MLX/Random.swift @@ -3,6 +3,87 @@ import Cmlx import Foundation +/// Collection of functions related to random number generation. +/// +/// Following [JAX’s PRNG design](https://jax.readthedocs.io/en/latest/jep/263-prng.html) +/// we use a splittable version of Threefry, which is a counter-based PRNG. +/// +/// Random sampling functions in MLX use an implicit global PRNG state by default. +/// However, all functions take an optional key keyword argument for when more fine-grained +/// control or explicit state management is needed. Callers can also arrange for `Task` local +/// random state -- useful in multithreaded situations. +/// +/// For example, you can generate random numbers with: +/// +/// ```swift +/// for _ in 0 ..< 3 { +/// print(MLXRandom.uniform()) +/// } +/// ``` +/// +/// which will print a sequence of unique pseudo random numbers. Alternatively you can explicitly set the key: +/// +/// ```swift +/// let key = MLXRandom.key(0) +/// for _ in 0 ..< 3 { +/// print(MLXRandom.uniform(key: key)) +/// } +/// ``` +/// +/// which will yield the same pseudo random number at each iteration as the key doesn't change. +/// +/// To get a new random number for each call you would ``split(key:stream:)`` the key: +/// +/// ```swift +/// var key = MLXRandom.key(0) +/// for _ in 0 ..< 3 { +/// let (a, b) = MLXRandom.split(key: key) +/// +/// // use b to generate a different value each time +/// print(MLXRandom.uniform(key: b)) +/// +/// // new random state is a +/// key = a +/// } +/// ``` +/// +/// This will generate the same sequence of numbers each time (same starting key) but +/// different values for each call. +/// +/// As a convenience you can use ``RandomState`` to manage the key splitting and: +/// +/// ```swift +/// let state = RandomState(seed: 0) +/// for _ in 0 ..< 3 { +/// print(MLXRandom.uniform(key: state)) +/// } +/// ``` +/// +/// Finally, if you need to control random state in deeply nested calls to `MLXRandom` or you need +/// thread-safe random state for multi-threaded evaluation you can use ``withRandomState(_:body:)-6i2p1``: +/// +/// ```swift +/// await withTaskGroup { group in +/// for i in 0 ..< 10 { +/// group.addTask { +/// let state = MLXRandom.RandomState(seed: UInt64(i)) +/// return withRandomState(state) { +/// var t: Float = 0.0 +/// for _ in 0 ..< 100 { +/// t += uniform(0 ..< 1, [10, 10]).sum().item(Float.self) +/// } +/// return t +/// } +/// } +/// } +/// +/// for await v in group { +/// ... +/// } +/// } +/// ``` +/// +/// Each task will have separate ``RandomState`` that will be used implicitly (if no other key is passed in). public enum MLXRandom { /// Seed the global PRNG. @@ -67,12 +148,13 @@ public enum MLXRandom { /// let array = MLXRandom.uniform(0.5 ..< 1, [50], key: key) /// ``` public static func uniform( - _ range: Range, _ shape: [Int] = [], type: T.Type = Float.self, key: MLXArray? = nil, + _ range: Range, _ shape: [Int] = [], type: T.Type = Float.self, + key: RandomStateOrKey? = nil, stream: StreamOrDevice = .default ) -> MLXArray where T: HasDType, T: BinaryFloatingPoint { let lb = MLXArray(range.lowerBound) let ub = MLXArray(range.upperBound) - let key = key ?? globalState.next() + let key = resolve(key: key) var result = mlx_array_new() mlx_random_uniform( @@ -92,11 +174,11 @@ public enum MLXRandom { /// ``` public static func uniform( _ range: Range = 0 ..< 1, _ shape: [Int] = [], type: T.Type = Float.self, - key: MLXArray? = nil, stream: StreamOrDevice = .default + key: RandomStateOrKey? = nil, stream: StreamOrDevice = .default ) -> MLXArray where T: HasDType, T: BinaryFloatingPoint { let lb = MLXArray(range.lowerBound) let ub = MLXArray(range.upperBound) - let key = key ?? globalState.next() + let key = resolve(key: key) var result = mlx_array_new() mlx_random_uniform( @@ -122,11 +204,11 @@ public enum MLXRandom { /// ``` public static func uniform( low: ScalarOrArray, high: ScalarOrArray, _ shape: [Int]? = nil, type: T.Type = Float.self, - key: MLXArray? = nil, stream: StreamOrDevice = .default + key: RandomStateOrKey? = nil, stream: StreamOrDevice = .default ) -> MLXArray where T: HasDType, T: BinaryFloatingPoint { let (low, high) = toArrays(low, high) let shape = shape ?? low.shape - let key = key ?? globalState.next() + let key = resolve(key: key) var result = mlx_array_new() mlx_random_uniform( @@ -152,11 +234,11 @@ public enum MLXRandom { /// ``` public static func uniform( low: ScalarOrArray, high: ScalarOrArray, _ shape: [Int]? = nil, dtype: DType = .float32, - key: MLXArray? = nil, stream: StreamOrDevice = .default + key: RandomStateOrKey? = nil, stream: StreamOrDevice = .default ) -> MLXArray { let (low, high) = toArrays(low, high) let shape = shape ?? low.shape - let key = key ?? globalState.next() + let key = resolve(key: key) var result = mlx_array_new() mlx_random_uniform( @@ -190,10 +272,10 @@ public enum MLXRandom { /// - key: PRNG key public static func normal( _ shape: [Int] = [], type: T.Type = Float.self, loc: Float = 0, scale: Float = 1, - key: MLXArray? = nil, + key: RandomStateOrKey? = nil, stream: StreamOrDevice = .default ) -> MLXArray where T: HasDType, T: BinaryFloatingPoint { - let key = key ?? globalState.next() + let key = resolve(key: key) var result = mlx_array_new() mlx_random_normal( @@ -225,10 +307,10 @@ public enum MLXRandom { /// - key: PRNG key public static func normal( _ shape: [Int] = [], dtype: DType = .float32, loc: Float = 0, scale: Float = 1, - key: MLXArray? = nil, + key: RandomStateOrKey? = nil, stream: StreamOrDevice = .default ) -> MLXArray { - let key = key ?? globalState.next() + let key = resolve(key: key) var result = mlx_array_new() mlx_random_normal( @@ -255,9 +337,9 @@ public enum MLXRandom { /// - key: PRNG key public static func multivariateNormal( mean: MLXArray, covariance: MLXArray, shape: [Int] = [], dtype: DType = .float32, - key: MLXArray? = nil, stream: StreamOrDevice = .default + key: RandomStateOrKey? = nil, stream: StreamOrDevice = .default ) -> MLXArray { - let key = key ?? globalState.next() + let key = resolve(key: key) var result = mlx_array_new() mlx_random_multivariate_normal( @@ -283,12 +365,12 @@ public enum MLXRandom { /// let array = MLXRandom.randInt(Int32(0) ..< 100, [50], key: key) /// ``` public static func randInt( - _ range: Range, _ shape: [Int] = [], key: MLXArray? = nil, + _ range: Range, _ shape: [Int] = [], key: RandomStateOrKey? = nil, stream: StreamOrDevice = .default ) -> MLXArray where T: HasDType, T: BinaryInteger { let lb = MLXArray(range.lowerBound) let ub = MLXArray(range.upperBound) - let key = key ?? globalState.next() + let key = resolve(key: key) var result = mlx_array_new() mlx_random_randint( @@ -312,12 +394,13 @@ public enum MLXRandom { /// let array = MLXRandom.randInt(low: [0, 10], high: [10, 100], key: key) /// ``` public static func randInt( - low: ScalarOrArray, high: ScalarOrArray, _ shape: [Int]? = nil, key: MLXArray? = nil, + low: ScalarOrArray, high: ScalarOrArray, _ shape: [Int]? = nil, + key: RandomStateOrKey? = nil, stream: StreamOrDevice = .default ) -> MLXArray { let (low, high) = toArrays(low, high) let shape = shape ?? low.shape - let key = key ?? globalState.next() + let key = resolve(key: key) var result = mlx_array_new() mlx_random_randint( @@ -343,11 +426,11 @@ public enum MLXRandom { /// ``` public static func randInt( low: ScalarOrArray, high: ScalarOrArray, _ shape: [Int]? = nil, type: T.Type, - key: MLXArray? = nil, stream: StreamOrDevice = .default + key: RandomStateOrKey? = nil, stream: StreamOrDevice = .default ) -> MLXArray where T: HasDType, T: BinaryInteger { let (low, high) = toArrays(low, high) let shape = shape ?? low.shape - let key = key ?? globalState.next() + let key = resolve(key: key) var result = mlx_array_new() mlx_random_randint( @@ -372,12 +455,12 @@ public enum MLXRandom { /// let array = MLXRandom.bernoulli([50, 2], key: key) /// ``` public static func bernoulli( - _ shape: [Int] = [], key: MLXArray? = nil, stream: StreamOrDevice = .default + _ shape: [Int] = [], key: RandomStateOrKey? = nil, stream: StreamOrDevice = .default ) -> MLXArray { let p = MLXArray(0.5) - let key = key ?? globalState.next() + let key = resolve(key: key) var result = mlx_array_new() mlx_random_bernoulli(&result, p.ctx, shape.asInt32, shape.count, key.ctx, stream.ctx) @@ -403,12 +486,12 @@ public enum MLXRandom { /// let array = MLXRandom.bernoulli(MLXArray(convert: [0.1, 0.5, 0.8]), key: key) /// ``` public static func bernoulli( - _ p: ScalarOrArray, _ shape: [Int]? = nil, key: MLXArray? = nil, + _ p: ScalarOrArray, _ shape: [Int]? = nil, key: RandomStateOrKey? = nil, stream: StreamOrDevice = .default ) -> MLXArray { let p = p.asMLXArray(dtype: .float32) let shape = shape ?? p.shape - let key = key ?? globalState.next() + let key = resolve(key: key) var result = mlx_array_new() mlx_random_bernoulli(&result, p.ctx, shape.asInt32, shape.count, key.ctx, stream.ctx) @@ -434,12 +517,13 @@ public enum MLXRandom { /// ### See also /// - [JAX Documentation](https://jax.readthedocs.io/en/latest/_modules/jax/_src/random.html#truncated_normal) public static func truncatedNormal( - _ range: Range, _ shape: [Int] = [], type: T.Type = Float.self, key: MLXArray? = nil, + _ range: Range, _ shape: [Int] = [], type: T.Type = Float.self, + key: RandomStateOrKey? = nil, stream: StreamOrDevice = .default ) -> MLXArray where T: HasDType, T: BinaryFloatingPoint { let lb = MLXArray(range.lowerBound) let ub = MLXArray(range.upperBound) - let key = key ?? globalState.next() + let key = resolve(key: key) var result = mlx_array_new() mlx_random_truncated_normal( @@ -458,12 +542,13 @@ public enum MLXRandom { /// let array = MLXRandom.truncatedNormal(0.5 ..< 1, [50], key: key) /// ``` public static func truncatedNormal( - _ range: Range, _ shape: [Int] = [], type: T.Type = Float.self, key: MLXArray? = nil, + _ range: Range, _ shape: [Int] = [], type: T.Type = Float.self, + key: RandomStateOrKey? = nil, stream: StreamOrDevice = .default ) -> MLXArray where T: HasDType, T: BinaryFloatingPoint { let lb = MLXArray(range.lowerBound) let ub = MLXArray(range.upperBound) - let key = key ?? globalState.next() + let key = resolve(key: key) var result = mlx_array_new() mlx_random_truncated_normal( @@ -488,11 +573,11 @@ public enum MLXRandom { /// ``` public static func truncatedNormal( low: ScalarOrArray, high: ScalarOrArray, _ shape: [Int]? = nil, type: T.Type = Float.self, - key: MLXArray? = nil, stream: StreamOrDevice = .default + key: RandomStateOrKey? = nil, stream: StreamOrDevice = .default ) -> MLXArray where T: HasDType, T: BinaryFloatingPoint { let (low, high) = toArrays(low, high) let shape = shape ?? low.shape - let key = key ?? globalState.next() + let key = resolve(key: key) var result = mlx_array_new() mlx_random_truncated_normal( @@ -517,11 +602,11 @@ public enum MLXRandom { /// ``` public static func truncatedNormal( low: ScalarOrArray, high: ScalarOrArray, _ shape: [Int]? = nil, dtype: DType = .float32, - key: MLXArray? = nil, stream: StreamOrDevice = .default + key: RandomStateOrKey? = nil, stream: StreamOrDevice = .default ) -> MLXArray { let (low, high) = toArrays(low, high) let shape = shape ?? low.shape - let key = key ?? globalState.next() + let key = resolve(key: key) var result = mlx_array_new() mlx_random_truncated_normal( @@ -547,10 +632,10 @@ public enum MLXRandom { /// let array = MLXRandom.gumbel([10, 5], key: key) /// ``` public static func gumbel( - _ shape: [Int] = [], type: T.Type = Float.self, key: MLXArray? = nil, + _ shape: [Int] = [], type: T.Type = Float.self, key: RandomStateOrKey? = nil, stream: StreamOrDevice = .default ) -> MLXArray where T: HasDType, T: BinaryFloatingPoint { - let key = key ?? globalState.next() + let key = resolve(key: key) var result = mlx_array_new() mlx_random_gumbel( @@ -574,10 +659,10 @@ public enum MLXRandom { /// let array = MLXRandom.gumbel([10, 5], key: key) /// ``` public static func gumbel( - _ shape: [Int] = [], dtype: DType = .float32, key: MLXArray? = nil, + _ shape: [Int] = [], dtype: DType = .float32, key: RandomStateOrKey? = nil, stream: StreamOrDevice = .default ) -> MLXArray { - let key = key ?? globalState.next() + let key = resolve(key: key) var result = mlx_array_new() mlx_random_gumbel(&result, shape.asInt32, shape.count, dtype.cmlxDtype, key.ctx, stream.ctx) @@ -604,10 +689,10 @@ public enum MLXRandom { /// - Parameters: /// - logits: The *unnormalized* categorical distribution(s). public static func categorical( - _ logits: MLXArray, axis: Int = -1, shape: [Int]? = nil, key: MLXArray? = nil, + _ logits: MLXArray, axis: Int = -1, shape: [Int]? = nil, key: RandomStateOrKey? = nil, stream: StreamOrDevice = .default ) -> MLXArray { - let key = key ?? globalState.next() + let key = resolve(key: key) if let shape { var result = mlx_array_new() @@ -640,10 +725,10 @@ public enum MLXRandom { /// - Parameters: /// - logits: The *unnormalized* categorical distribution(s). public static func categorical( - _ logits: MLXArray, axis: Int = -1, count: Int, key: MLXArray? = nil, + _ logits: MLXArray, axis: Int = -1, count: Int, key: RandomStateOrKey? = nil, stream: StreamOrDevice = .default ) -> MLXArray { - let key = key ?? globalState.next() + let key = resolve(key: key) var result = mlx_array_new() mlx_random_categorical_num_samples( @@ -661,9 +746,9 @@ public enum MLXRandom { /// - scale: scale "b" of the distribution public static func laplace( _ shape: [Int] = [], dtype: DType = .float32, loc: Float = 0, scale: Float = 1, - key: MLXArray? = nil, stream: StreamOrDevice = .default + key: RandomStateOrKey? = nil, stream: StreamOrDevice = .default ) -> MLXArray { - let key = key ?? globalState.next() + let key = resolve(key: key) var result = mlx_array_new() mlx_random_laplace( @@ -724,7 +809,7 @@ public func split(key: MLXArray, stream: StreamOrDevice = .default) -> (MLXArray /// let array = MLXRandom.uniform(0.5 ..< 1, [50], key: key) /// ``` public func uniform( - _ range: Range, _ shape: [Int] = [], type: T.Type = Float.self, key: MLXArray? = nil, + _ range: Range, _ shape: [Int] = [], type: T.Type = Float.self, key: RandomStateOrKey? = nil, stream: StreamOrDevice = .default ) -> MLXArray where T: HasDType, T: BinaryFloatingPoint { return MLXRandom.uniform(range, shape, type: type, key: key, stream: stream) @@ -740,7 +825,7 @@ public func uniform( /// ``` public func uniform( _ range: Range = 0 ..< 1, _ shape: [Int] = [], type: T.Type = Float.self, - key: MLXArray? = nil, stream: StreamOrDevice = .default + key: RandomStateOrKey? = nil, stream: StreamOrDevice = .default ) -> MLXArray where T: HasDType, T: BinaryFloatingPoint { return MLXRandom.uniform(range, shape, type: type, key: key, stream: stream) } @@ -761,7 +846,7 @@ public func uniform( /// ``` public func uniform( low: ScalarOrArray, high: ScalarOrArray, _ shape: [Int]? = nil, type: T.Type = Float.self, - key: MLXArray? = nil, stream: StreamOrDevice = .default + key: RandomStateOrKey? = nil, stream: StreamOrDevice = .default ) -> MLXArray where T: HasDType, T: BinaryFloatingPoint { return MLXRandom.uniform(low: low, high: high, shape, type: type, key: key, stream: stream) } @@ -782,7 +867,7 @@ public func uniform( /// ``` public func uniform( low: ScalarOrArray, high: ScalarOrArray, _ shape: [Int]? = nil, dtype: DType = .float32, - key: MLXArray? = nil, stream: StreamOrDevice = .default + key: RandomStateOrKey? = nil, stream: StreamOrDevice = .default ) -> MLXArray { return MLXRandom.uniform(low: low, high: high, shape, dtype: dtype, key: key, stream: stream) } @@ -810,7 +895,7 @@ public func uniform( /// - key: PRNG key public func normal( _ shape: [Int] = [], type: T.Type = Float.self, loc: Float = 0, scale: Float = 1, - key: MLXArray? = nil, + key: RandomStateOrKey? = nil, stream: StreamOrDevice = .default ) -> MLXArray where T: HasDType, T: BinaryFloatingPoint { return MLXRandom.normal(shape, type: type, loc: loc, scale: scale, key: key, stream: stream) @@ -839,7 +924,7 @@ public func normal( /// - key: PRNG key public func normal( _ shape: [Int] = [], dtype: DType = .float32, loc: Float = 0, scale: Float = 1, - key: MLXArray? = nil, + key: RandomStateOrKey? = nil, stream: StreamOrDevice = .default ) -> MLXArray { return MLXRandom.normal(shape, dtype: dtype, loc: loc, scale: scale, key: key, stream: stream) @@ -863,7 +948,7 @@ public func normal( /// - key: PRNG key public func multivariateNormal( mean: MLXArray, covariance: MLXArray, shape: [Int] = [], dtype: DType = .float32, - key: MLXArray? = nil, stream: StreamOrDevice = .default + key: RandomStateOrKey? = nil, stream: StreamOrDevice = .default ) -> MLXArray { return MLXRandom.multivariateNormal( mean: mean, covariance: covariance, shape: shape, dtype: dtype, key: key, stream: stream) @@ -885,7 +970,8 @@ public func multivariateNormal( /// let array = MLXRandom.randInt(Int32(0) ..< 100, [50], key: key) /// ``` public func randInt( - _ range: Range, _ shape: [Int] = [], key: MLXArray? = nil, stream: StreamOrDevice = .default + _ range: Range, _ shape: [Int] = [], key: RandomStateOrKey? = nil, + stream: StreamOrDevice = .default ) -> MLXArray where T: HasDType, T: BinaryInteger { return MLXRandom.randInt(range, shape, key: key, stream: stream) } @@ -904,7 +990,7 @@ public func randInt( /// let array = MLXRandom.randInt(low: [0, 10], high: [10, 100], key: key) /// ``` public func randInt( - low: ScalarOrArray, high: ScalarOrArray, _ shape: [Int]? = nil, key: MLXArray? = nil, + low: ScalarOrArray, high: ScalarOrArray, _ shape: [Int]? = nil, key: RandomStateOrKey? = nil, stream: StreamOrDevice = .default ) -> MLXArray { return MLXRandom.randInt(low: low, high: high, shape, key: key, stream: stream) @@ -926,7 +1012,7 @@ public func randInt( /// ``` public func randInt( low: ScalarOrArray, high: ScalarOrArray, _ shape: [Int]? = nil, type: T.Type, - key: MLXArray? = nil, stream: StreamOrDevice = .default + key: RandomStateOrKey? = nil, stream: StreamOrDevice = .default ) -> MLXArray where T: HasDType, T: BinaryInteger { return MLXRandom.randInt(low: low, high: high, shape, type: type, key: key, stream: stream) } @@ -945,7 +1031,9 @@ public func randInt( /// // generate an array of shape [50, 2] of random Bool /// let array = MLXRandom.bernoulli([50, 2], key: key) /// ``` -public func bernoulli(_ shape: [Int] = [], key: MLXArray? = nil, stream: StreamOrDevice = .default) +public func bernoulli( + _ shape: [Int] = [], key: RandomStateOrKey? = nil, stream: StreamOrDevice = .default +) -> MLXArray { return MLXRandom.bernoulli(shape, key: key, stream: stream) @@ -970,7 +1058,7 @@ public func bernoulli(_ shape: [Int] = [], key: MLXArray? = nil, stream: StreamO /// let array = MLXRandom.bernoulli(MLXArray(convert: [0.1, 0.5, 0.8]), key: key) /// ``` public func bernoulli( - _ p: ScalarOrArray, _ shape: [Int]? = nil, key: MLXArray? = nil, + _ p: ScalarOrArray, _ shape: [Int]? = nil, key: RandomStateOrKey? = nil, stream: StreamOrDevice = .default ) -> MLXArray { return MLXRandom.bernoulli(p, shape, key: key, stream: stream) @@ -995,7 +1083,7 @@ public func bernoulli( /// ### See also /// - [JAX Documentation](https://jax.readthedocs.io/en/latest/_modules/jax/_src/random.html#truncated_normal) public func truncatedNormal( - _ range: Range, _ shape: [Int] = [], type: T.Type = Float.self, key: MLXArray? = nil, + _ range: Range, _ shape: [Int] = [], type: T.Type = Float.self, key: RandomStateOrKey? = nil, stream: StreamOrDevice = .default ) -> MLXArray where T: HasDType, T: BinaryFloatingPoint { return MLXRandom.truncatedNormal(range, shape, type: type, key: key, stream: stream) @@ -1010,7 +1098,8 @@ public func truncatedNormal( /// let array = MLXRandom.truncatedNormal(0.5 ..< 1, [50], key: key) /// ``` public func truncatedNormal( - _ range: Range, _ shape: [Int] = [], type: T.Type = Float.self, key: MLXArray? = nil, + _ range: Range, _ shape: [Int] = [], type: T.Type = Float.self, + key: RandomStateOrKey? = nil, stream: StreamOrDevice = .default ) -> MLXArray where T: HasDType, T: BinaryFloatingPoint { return MLXRandom.truncatedNormal(range, shape, type: type, key: key, stream: stream) @@ -1031,7 +1120,7 @@ public func truncatedNormal( /// ``` public func truncatedNormal( low: ScalarOrArray, high: ScalarOrArray, _ shape: [Int]? = nil, type: T.Type = Float.self, - key: MLXArray? = nil, stream: StreamOrDevice = .default + key: RandomStateOrKey? = nil, stream: StreamOrDevice = .default ) -> MLXArray where T: HasDType, T: BinaryFloatingPoint { return MLXRandom.truncatedNormal( low: low, high: high, shape, type: type, key: key, stream: stream) @@ -1052,7 +1141,7 @@ public func truncatedNormal( /// ``` public func truncatedNormal( low: ScalarOrArray, high: ScalarOrArray, _ shape: [Int]? = nil, dtype: DType = .float32, - key: MLXArray? = nil, stream: StreamOrDevice = .default + key: RandomStateOrKey? = nil, stream: StreamOrDevice = .default ) -> MLXArray { return MLXRandom.truncatedNormal( low: low, high: high, shape, dtype: dtype, key: key, stream: stream) @@ -1073,7 +1162,7 @@ public func truncatedNormal( /// let array = MLXRandom.gumbel([10, 5], key: key) /// ``` public func gumbel( - _ shape: [Int] = [], type: T.Type = Float.self, key: MLXArray? = nil, + _ shape: [Int] = [], type: T.Type = Float.self, key: RandomStateOrKey? = nil, stream: StreamOrDevice = .default ) -> MLXArray where T: HasDType, T: BinaryFloatingPoint { return MLXRandom.gumbel(shape, type: type, key: key, stream: stream) @@ -1094,7 +1183,7 @@ public func gumbel( /// let array = MLXRandom.gumbel([10, 5], key: key) /// ``` public func gumbel( - _ shape: [Int] = [], dtype: DType = .float32, key: MLXArray? = nil, + _ shape: [Int] = [], dtype: DType = .float32, key: RandomStateOrKey? = nil, stream: StreamOrDevice = .default ) -> MLXArray { return MLXRandom.gumbel(shape, dtype: dtype, key: key, stream: stream) @@ -1119,7 +1208,7 @@ public func gumbel( /// - Parameters: /// - logits: The *unnormalized* categorical distribution(s). public func categorical( - _ logits: MLXArray, axis: Int = -1, shape: [Int]? = nil, key: MLXArray? = nil, + _ logits: MLXArray, axis: Int = -1, shape: [Int]? = nil, key: RandomStateOrKey? = nil, stream: StreamOrDevice = .default ) -> MLXArray { return MLXRandom.categorical(logits, axis: axis, shape: shape, key: key, stream: stream) @@ -1142,7 +1231,7 @@ public func categorical( /// - Parameters: /// - logits: The *unnormalized* categorical distribution(s). public func categorical( - _ logits: MLXArray, axis: Int = -1, count: Int, key: MLXArray? = nil, + _ logits: MLXArray, axis: Int = -1, count: Int, key: RandomStateOrKey? = nil, stream: StreamOrDevice = .default ) -> MLXArray { return MLXRandom.categorical(logits, axis: axis, count: count, key: key, stream: stream) @@ -1157,7 +1246,7 @@ public func categorical( /// - scale: scale "b" of the distribution public func laplace( _ shape: [Int] = [], dtype: DType = .float32, loc: Float = 0, scale: Float = 1, - key: MLXArray? = nil, stream: StreamOrDevice = .default + key: RandomStateOrKey? = nil, stream: StreamOrDevice = .default ) -> MLXArray { return MLXRandom.laplace(shape, dtype: dtype, loc: loc, scale: scale, key: key, stream: stream) } diff --git a/Source/MLX/State.swift b/Source/MLX/State.swift index 460e4f56..63e59aae 100644 --- a/Source/MLX/State.swift +++ b/Source/MLX/State.swift @@ -2,27 +2,52 @@ import Foundation +/// Protocol for types that can be used as a provider of random keys, e.g. for ``MLXRandom``. +public protocol RandomStateOrKey { + func asRandomKey() -> MLXArray +} + +extension MLXArray: RandomStateOrKey { + public func asRandomKey() -> MLXArray { + self + } +} + extension MLXRandom { - /// Global random state. + + /// Random state factory. + /// + /// /// /// Note: although this type is thread-safe, the MLXArrays that it returns are not -- do not /// evaluate these values or expressions that depend on them across multiple threads /// simultaneously. - public class RandomState: Updatable, Evaluatable, @unchecked (Sendable) { + /// + /// ### See Also + /// - ``globalState`` + /// - ``withRandomState(_:body:)-6i2p1`` + public class RandomState: RandomStateOrKey, Updatable, Evaluatable, @unchecked (Sendable) { private var state: MLXArray private let lock = NSLock() - init() { + /// Initialize the RandomState with a seed based on the current time. + public init() { let now = mach_approximate_time() state = MLXRandom.key(now) } + /// Initialize the RandomState with the given seed value. + public init(seed: UInt64) { + state = MLXRandom.key(seed) + } + public func innerState() -> [MLXArray] { lock.withLock { [state] } } + /// Split the current state and return a new Key. public func next() -> MLXArray { lock.withLock { let (a, b) = MLXRandom.split(key: state) @@ -31,11 +56,16 @@ extension MLXRandom { } } + /// Reset the random state. public func seed(_ seed: UInt64) { lock.withLock { state = MLXRandom.key(seed) } } + + public func asRandomKey() -> MLXArray { + next() + } } /// Global random state. @@ -46,4 +76,33 @@ extension MLXRandom { /// - ``seed(_:)`` public static let globalState = RandomState() + /// See ``withRandomState(_:body:)`` and ``resolve(key:)`` + @TaskLocal + static fileprivate var taskLocalRandomState: MLXRandom.RandomState? + } // MLXRandom + +/// Resolve the given key to a concrete MLXArray (random key). +/// +/// This will use the following values (in order until one is found) to resolve the +/// random key: +/// +/// - the passed key, either an ``MLXArray`` or ``MLXRandom/RandomState`` +/// - the task-local ``MLXRandom/RandomState``, see ``withRandomState(_:body:)-6i2p1`` +/// - the global RandomState, ``MLXRandom/globalState`` +public func resolve(key: RandomStateOrKey?) -> MLXArray { + key?.asRandomKey() ?? MLXRandom.taskLocalRandomState?.asRandomKey() + ?? MLXRandom.globalState.next() +} + +/// Use the given ``MLXRandom/RandomState`` scoped to the current task and body. +public func withRandomState(_ state: MLXRandom.RandomState, body: () throws -> R) rethrows -> R { + try MLXRandom.$taskLocalRandomState.withValue(state, operation: body) +} + +/// Use the given ``MLXRandom/RandomState`` scoped to the current task and body. +public func withRandomState(_ state: MLXRandom.RandomState, body: () async throws -> R) + async rethrows -> R +{ + try await MLXRandom.$taskLocalRandomState.withValue(state, operation: body) +} diff --git a/Tests/MLXTests/MLXRandomTests.swift b/Tests/MLXTests/MLXRandomTests.swift index 8613c360..07a3125f 100644 --- a/Tests/MLXTests/MLXRandomTests.swift +++ b/Tests/MLXTests/MLXRandomTests.swift @@ -159,4 +159,106 @@ class MLXRandomTests: XCTestCase { assertEqual(result, expected) } + func testRandomStateOrKeySame() { + // these should all produce the same value since they + // all resolve to the same key + + let key = MLXRandom.key(0) + let (_, k1) = split(key: key) + + let state = MLXRandom.RandomState(seed: 0) + MLXRandom.seed(0) + + // global state + let v0 = uniform(0 ..< 1, [5]) + + // explicit key + let v1 = uniform(0 ..< 1, [5], key: k1) + + // local RandomState + let v2 = uniform(0 ..< 1, [5], key: state) + + assertEqual(v0, v1) + assertEqual(v1, v2) + } + + func testRandomStateOrKeyDifferent() { + // these should all produce different values as they + // use different keys -- note this is otherwise identical + // to testRandomStateOrKeySame + + let key = MLXRandom.key(7) + let (_, k1) = split(key: key) + + let state = MLXRandom.RandomState(seed: 11) + MLXRandom.seed(31) + + // global state + let v0 = uniform(0 ..< 1, [5]) + + // explicit key + let v1 = uniform(0 ..< 1, [5], key: k1) + + // local RandomState + let v2 = uniform(0 ..< 1, [5], key: state) + + assertNotEqual(v0, v1) + assertNotEqual(v1, v2) + assertNotEqual(v0, v2) + } + + func testRandomThreadsSame() async { + // several threads using task local random state with a constant + // seed will produce the same value + await withTaskGroup(of: Float.self) { group in + for _ in 0 ..< 10 { + group.addTask { + let state = MLXRandom.RandomState(seed: 23) + return withRandomState(state) { + var t: Float = 0.0 + for _ in 0 ..< 100 { + t += uniform(0 ..< 1, [10, 10]).sum().item(Float.self) + } + return t + } + } + } + + var result = [Float]() + for await v in group { + result.append(v) + } + + let unique = Set(result) + XCTAssertEqual(unique.count, 1, "Different values: \(result)") + } + } + + func testRandomThreadsDifferent() async { + // several threads using task local random state with different + // seeds will produce different values + await withTaskGroup(of: Float.self) { group in + for i in 0 ..< 10 { + group.addTask { + let state = MLXRandom.RandomState(seed: UInt64(i)) + return withRandomState(state) { + var t: Float = 0.0 + for _ in 0 ..< 100 { + t += uniform(0 ..< 1, [10, 10]).sum().item(Float.self) + } + return t + } + } + } + + var result = [Float]() + for await v in group { + result.append(v) + } + + let unique = Set(result) + XCTAssertEqual(unique.count, 10, "Same values: \(result)") + } + } + } diff --git a/Tests/MLXTests/Utils.swift b/Tests/MLXTests/Utils.swift index 73e58e23..520306bd 100644 --- a/Tests/MLXTests/Utils.swift +++ b/Tests/MLXTests/Utils.swift @@ -24,6 +24,16 @@ func assertEqual( } } +func assertNotEqual( + _ array1: MLXArray, _ array2: MLXArray, rtol: Double = 1e-5, atol: Double = 1e-8, + file: StaticString = #filePath, line: UInt = #line +) { + XCTAssertEqual(array1.shape, array2.shape, "shapes differ: \(array1.shape) != \(array2.shape)") + XCTAssertFalse( + array1.allClose(array2, rtol: rtol, atol: atol).item(Bool.self), + "contents same:\n\(array1)\n\(array2)") +} + func setDefaultDevice() { MLX.Device.setDefault(device: .gpu) }