-
Notifications
You must be signed in to change notification settings - Fork 92
various Random improvements #219
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
davidkoski
commented
Apr 8, 2025
- allow a Key or a RandomState to be used with Random functions
- users can create RandomState
- users can do withRandomState to scope the use of a RandomState -- useful with multiple threads
- allow a Key or a RandomState to be used with Random functions - users can create RandomState - users can do withRandomState to scope the use of a RandomState -- useful with multiple threads
@@ -2,21 +2,36 @@ | |||
|
|||
import Foundation | |||
|
|||
public protocol RandomStateOrKey { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
- various need documentation
@@ -67,12 +67,13 @@ public enum MLXRandom { | |||
/// let array = MLXRandom.uniform(0.5 ..< 1, [50], key: key) | |||
/// ``` | |||
public static func uniform<R: HasDType, T>( | |||
_ range: Range<R>, _ shape: [Int] = [], type: T.Type = Float.self, key: MLXArray? = nil, | |||
_ range: Range<R>, _ shape: [Int] = [], type: T.Type = Float.self, | |||
key: RandomStateOrKey? = nil, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
allow passing a Key (MLXArray) or RandomState (vendor of keys)
private var state: MLXArray | ||
private let lock = NSLock() | ||
|
||
init() { | ||
public init() { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These inits should be public so that callers can create these as needed.
Source/MLX/State.swift
Outdated
private var taskLocalRandomState: MLXRandom.RandomState? | ||
|
||
public func resolve(key: RandomStateOrKey?) -> MLXArray { | ||
key?.asRandomKey() ?? taskLocalRandomState?.asRandomKey() ?? MLXRandom.globalState.next() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This lets callers of random functions use:
- a passed in Key or RandomState
- a task (roughly thread) local RandomState, see
withRandomState
in tests - the global RandomState
Previously this was either a passed in MLXArray OR the global RandomState, which made it hard to control if the use of MLXRandom was buried deeply in a model or sampling.
for _ in 0 ..< 10 { | ||
group.addTask { | ||
let state = MLXRandom.RandomState(seed: 23) | ||
return withRandomState(state) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This will use per-task (thread) RandomState -- this would crash if done before this change.
Source/MLX/Random.swift
Outdated
/// You can also use a ``RandomState`` to generate different random numbers but the same sequence | ||
/// each time: | ||
/// | ||
/// ```swift | ||
/// let state = RandomState(seed: 0) | ||
/// for _ in 0 ..< 3 { | ||
/// print(MLXRandom.uniform(key: state)) | ||
/// } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it would be helpful in the exposition here to add an example inbetween this one and the previous one that explains how to split the array random key. So concretely:
- Example with array key -> same numbers
- Example with array key + split -> different numbers
- Same thing as 2 but nicer syntax with random state
@davidkoski right now I see that the random state in Swift is already "thread safe" as you guard the state with a lock. Is that correct? |
This one I'm not sure about. Can you say more about the motivation for adding that? I kind of feel like for people that want control over the individual calls.. they can use the splittable key. It's slightly more verbose but it's also an edge case.. And then for people that don't need control over individual calls, they can create the state and use it in a local context and/or seed the global state. |
Sort of -- although the container itself is thread safe (so I can mark it Sendable), the MLXArrays that it vends are not thread safe. It is kind of a useless distinction but one I need to be able to send these between threads (the |
If you wanted to have a private random state for something, e.g. a categorical sampler in a LLM inference pipeline, you would have to use the key split and manage state that way. This is just a convenience if you want to have a random state without the need to manually manage the splitting. I don't think it is critical for this change -- the I think it is exactly as you said in point 3:
This: var key = MLXRandom.key(0)
for _ in 0 ..< 3 {
let (a, b) = MLXRandom.split(key: key)
// use b to generate a different value each time
print(MLXRandom.uniform(key: b))
// new random state is a
key = a
} vs let state = RandomState(seed: 0)
for _ in 0 ..< 3 {
print(MLXRandom.uniform(key: state))
} They are exactly the same operations. |
public func resolve(key: RandomStateOrKey?) -> MLXArray { | ||
key?.asRandomKey() ?? MLXRandom.taskLocalRandomState?.asRandomKey() | ||
?? MLXRandom.globalState.next() | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why did you make that public? Shouldn't it not be an implementation only function?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm ok merging this. I think my hesitation is that the randomKeyOrState
in the random API is causes some divergence from the rest of the MLX APIs (Python / C++). Other than that I don't have a problem with it and I can see how it'd be useful.
@davidkoski I leave it up to you to keep or not depending on how useful you think it is.
OK, let's go with what we have here then -- I think passing RandomState is a nice alternative to using implicit TaskLocals, for those who desire explicit state instead. We can see if this appeals to anyone and it would be easy enough to add on the python/c++ side if there is uptake. Given current use of Random I suspect most people will just use TaskLocal. C++ might need that approach, though python probably does not. |