Skip to content

various Random improvements #219

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Apr 18, 2025
Merged

various Random improvements #219

merged 5 commits into from
Apr 18, 2025

Conversation

davidkoski
Copy link
Collaborator

  • allow a Key or a RandomState to be used with Random functions
  • users can create RandomState
  • users can do withRandomState to scope the use of a RandomState -- useful with multiple threads

- allow a Key or a RandomState to be used with Random functions
- users can create RandomState
- users can do withRandomState to scope the use of a RandomState -- useful with multiple threads
@davidkoski davidkoski requested review from awni and barronalex April 8, 2025 19:47
@@ -2,21 +2,36 @@

import Foundation

public protocol RandomStateOrKey {
Copy link
Collaborator Author

@davidkoski davidkoski Apr 8, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  • various need documentation

@@ -67,12 +67,13 @@ public enum MLXRandom {
/// let array = MLXRandom.uniform(0.5 ..< 1, [50], key: key)
/// ```
public static func uniform<R: HasDType, T>(
_ range: Range<R>, _ shape: [Int] = [], type: T.Type = Float.self, key: MLXArray? = nil,
_ range: Range<R>, _ shape: [Int] = [], type: T.Type = Float.self,
key: RandomStateOrKey? = nil,
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

allow passing a Key (MLXArray) or RandomState (vendor of keys)

private var state: MLXArray
private let lock = NSLock()

init() {
public init() {
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These inits should be public so that callers can create these as needed.

private var taskLocalRandomState: MLXRandom.RandomState?

public func resolve(key: RandomStateOrKey?) -> MLXArray {
key?.asRandomKey() ?? taskLocalRandomState?.asRandomKey() ?? MLXRandom.globalState.next()
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This lets callers of random functions use:

  • a passed in Key or RandomState
  • a task (roughly thread) local RandomState, see withRandomState in tests
  • the global RandomState

Previously this was either a passed in MLXArray OR the global RandomState, which made it hard to control if the use of MLXRandom was buried deeply in a model or sampling.

for _ in 0 ..< 10 {
group.addTask {
let state = MLXRandom.RandomState(seed: 23)
return withRandomState(state) {
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will use per-task (thread) RandomState -- this would crash if done before this change.

Comment on lines 35 to 42
/// You can also use a ``RandomState`` to generate different random numbers but the same sequence
/// each time:
///
/// ```swift
/// let state = RandomState(seed: 0)
/// for _ in 0 ..< 3 {
/// print(MLXRandom.uniform(key: state))
/// }
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it would be helpful in the exposition here to add an example inbetween this one and the previous one that explains how to split the array random key. So concretely:

  1. Example with array key -> same numbers
  2. Example with array key + split -> different numbers
  3. Same thing as 2 but nicer syntax with random state

@awni
Copy link
Member

awni commented Apr 9, 2025

@davidkoski right now I see that the random state in Swift is already "thread safe" as you guard the state with a lock. Is that correct?

@awni
Copy link
Member

awni commented Apr 9, 2025

allow a Key or a RandomState to be used with Random functions

This one I'm not sure about. Can you say more about the motivation for adding that?

I kind of feel like for people that want control over the individual calls.. they can use the splittable key. It's slightly more verbose but it's also an edge case.. And then for people that don't need control over individual calls, they can create the state and use it in a local context and/or seed the global state.

@davidkoski
Copy link
Collaborator Author

@davidkoski right now I see that the random state in Swift is already "thread safe" as you guard the state with a lock. Is that correct?

Sort of -- although the container itself is thread safe (so I can mark it Sendable), the MLXArrays that it vends are not thread safe. It is kind of a useless distinction but one I need to be able to send these between threads (the @TaskLocal requires it for instance).

@davidkoski
Copy link
Collaborator Author

davidkoski commented Apr 9, 2025

allow a Key or a RandomState to be used with Random functions

This one I'm not sure about. Can you say more about the motivation for adding that?

I kind of feel like for people that want control over the individual calls.. they can use the splittable key. It's slightly more verbose but it's also an edge case.. And then for people that don't need control over individual calls, they can create the state and use it in a local context and/or seed the global state.

If you wanted to have a private random state for something, e.g. a categorical sampler in a LLM inference pipeline, you would have to use the key split and manage state that way. This is just a convenience if you want to have a random state without the need to manually manage the splitting.

I don't think it is critical for this change -- the @TaskLocal to inject a scoped "global" random state is the more important part, but I do think that using RandomState is more convenient than manually managing the splits (yet equivalent).

I think it is exactly as you said in point 3:

  • Example with array key -> same numbers
  • Example with array key + split -> different numbers
  • Same thing as 2 but nicer syntax with random state

This:

var key = MLXRandom.key(0)
for _ in 0 ..< 3 {
  let (a, b) = MLXRandom.split(key: key)
  // use b to generate a different value each time
  print(MLXRandom.uniform(key: b))
  // new random state is a
  key = a
}

vs

let state = RandomState(seed: 0)
for _ in 0 ..< 3 {
  print(MLXRandom.uniform(key: state))
}

They are exactly the same operations.

Comment on lines +93 to +96
public func resolve(key: RandomStateOrKey?) -> MLXArray {
key?.asRandomKey() ?? MLXRandom.taskLocalRandomState?.asRandomKey()
?? MLXRandom.globalState.next()
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why did you make that public? Shouldn't it not be an implementation only function?

Copy link
Member

@awni awni left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm ok merging this. I think my hesitation is that the randomKeyOrState in the random API is causes some divergence from the rest of the MLX APIs (Python / C++). Other than that I don't have a problem with it and I can see how it'd be useful.

@davidkoski I leave it up to you to keep or not depending on how useful you think it is.

@davidkoski
Copy link
Collaborator Author

OK, let's go with what we have here then -- I think passing RandomState is a nice alternative to using implicit TaskLocals, for those who desire explicit state instead. We can see if this appeals to anyone and it would be easy enough to add on the python/c++ side if there is uptake.

Given current use of Random I suspect most people will just use TaskLocal. C++ might need that approach, though python probably does not.

@davidkoski davidkoski merged commit df5d5c7 into main Apr 18, 2025
1 check passed
@davidkoski davidkoski deleted the random branch April 18, 2025 16:54
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants