Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Send a request to config.json when downloading model from HF repository #31

Merged
merged 4 commits into from
Nov 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
283 changes: 161 additions & 122 deletions android/src/main/java/com/rnexecutorch/Fetcher.kt
Original file line number Diff line number Diff line change
Expand Up @@ -5,104 +5,135 @@ import okhttp3.Call
import okhttp3.Callback
import okhttp3.OkHttpClient
import okhttp3.Request
import okhttp3.RequestBody
import okhttp3.Response
import okio.IOException
import java.io.File
import java.io.FileOutputStream
import java.net.URL

enum class ResourceType {
TOKENIZER,
MODEL
TOKENIZER,
MODEL,
}

class Fetcher {
companion object {
private fun saveResponseToFile(
response: Response,
directory: File,
fileName: String
): File {
val file = File(directory.path, fileName)
file.outputStream().use { outputStream ->
response.body?.byteStream()?.copyTo(outputStream)
}
return file
companion object {
private fun saveResponseToFile(
response: Response,
directory: File,
fileName: String,
): File {
val file = File(directory.path, fileName)
file.outputStream().use { outputStream ->
response.body?.byteStream()?.copyTo(outputStream)
}
return file
}

private fun hasValidExtension(fileName: String, resourceType: ResourceType): Boolean {
return when (resourceType) {
ResourceType.TOKENIZER -> {
fileName.endsWith(".bin")
}

private fun hasValidExtension(fileName: String, resourceType: ResourceType): Boolean {
return when (resourceType) {
ResourceType.TOKENIZER -> {
fileName.endsWith(".bin")
}
ResourceType.MODEL -> {
fileName.endsWith(".pte")
}
}
}

ResourceType.MODEL -> {
fileName.endsWith(".pte")
}
}
private fun extractFileName(url: URL): String {
if (url.path == "/assets/") {
val pathSegments = url.toString().split('/')
return pathSegments[pathSegments.size - 1].split("?")[0]
} else if (url.protocol == "file") {
val localPath = url.toString().split("://")[1]
val file = File(localPath)
if (file.exists()) {
return localPath
}

private fun extractFileName(url: URL): String {
if (url.path == "/assets/") {
val pathSegments = url.toString().split('/')
return pathSegments[pathSegments.size - 1].split("?")[0]
} else if (url.protocol == "file") {
val localPath = url.toString().split("://")[1]
val file = File(localPath)
if (file.exists()) {
return localPath
throw Exception("file_not_found")
} else {
return url.path.substringAfterLast('/')
}
}

private fun fetchModel(
file: File,
validFile: File,
client: OkHttpClient,
url: URL,
onComplete: (String?, Exception?) -> Unit,
listener: ProgressResponseBody.ProgressListener? = null,
) {
val request = Request.Builder().url(url).build()
client.newCall(request).enqueue(object : Callback {
override fun onFailure(call: Call, e: IOException) {
onComplete(null, e)
}

override fun onResponse(call: Call, response: Response) {
if (!response.isSuccessful) {
onComplete(null, Exception("download_error"))
}

response.body?.let { body ->
val progressBody = listener?.let { ProgressResponseBody(body, it) }
val inputStream = progressBody?.source()?.inputStream()
inputStream?.use { input ->
FileOutputStream(file).use { output ->
val buffer = ByteArray(2048)
var bytesRead: Int
while (input.read(buffer).also { bytesRead = it } != -1) {
output.write(buffer, 0, bytesRead)
}
}
}

throw Exception("file_not_found")
if (file.renameTo(validFile)) {
onComplete(validFile.absolutePath, null)
} else {
return url.path.substringAfterLast('/')
onComplete(null, Exception("Failed to move the file to the valid location"))
}
}
}
})
}

private fun fetchModel(file: File, validFile: File, client: OkHttpClient, url: URL, onComplete: (String?, Exception?) -> Unit,
listener: ProgressResponseBody.ProgressListener? = null){
val request = Request.Builder().url(url).build()
client.newCall(request).enqueue(object : Callback {
override fun onFailure(call: Call, e: IOException) {
onComplete(null, e)
}
/**
 * Tells whether [url] lives on huggingface.co under the `/software-mansion/`
 * path prefix.
 *
 * @param url URL of the resource that is about to be downloaded.
 * @return `true` only when the host is exactly `huggingface.co` and the path
 *   starts with `/software-mansion/`; `false` otherwise.
 */
private fun isUrlPointingToHfRepo(url: URL): Boolean =
    url.host == "huggingface.co" && url.path.startsWith("/software-mansion/")

override fun onResponse(call: Call, response: Response) {
if (!response.isSuccessful) {
onComplete(null, Exception("download_error"))
}

response.body?.let { body ->
val progressBody = listener?.let { ProgressResponseBody(body, it) }
val inputStream = progressBody?.source()?.inputStream()
inputStream?.use { input ->
FileOutputStream(file).use { output ->
val buffer = ByteArray(2048)
var bytesRead: Int
while (input.read(buffer).also { bytesRead = it } != -1) {
output.write(buffer, 0, bytesRead)
}
}
}

if (file.renameTo(validFile)) {
onComplete(validFile.absolutePath, null)
} else {
onComplete(null, Exception("Failed to move the file to the valid location"))
}
}
}
})
}
/**
 * Derives the URL of `config.json` on the `main` revision from a model URL.
 *
 * The result is assembled from the model URL's protocol, host and the part of
 * its path that precedes the first `resolve/`; the port, query string and
 * fragment of [modelUrl] are not carried over.
 *
 * @param modelUrl URL of the model file.
 * @return A new [URL] ending in `resolve/main/config.json`.
 */
private fun resolveConfigUrlFromModelUrl(modelUrl: URL): URL {
    // substringBefore returns the whole path when "resolve/" is absent, so
    // the suffix is then appended directly to the full path.
    val repoPath = modelUrl.path.substringBefore("resolve/")
    return URL("${modelUrl.protocol}://${modelUrl.host}${repoPath}resolve/main/config.json")
}

fun downloadResource(
context: Context,
client: OkHttpClient,
url: URL,
resourceType: ResourceType,
onComplete: (String?, Exception?) -> Unit,
listener: ProgressResponseBody.ProgressListener? = null
) {
/**
 * Builds a request for [url] with the given HTTP [method] and optional
 * [body], then runs it through [client] with `Call.execute()`.
 *
 * @param url Target of the request.
 * @param method HTTP method name passed to `Request.Builder.method`.
 * @param body Request body, or `null` for methods that carry none.
 * @param client OkHttp client used to create the call.
 * @return The [Response] produced by `execute()`; it is not closed here.
 */
private fun sendRequestToUrl(
    url: URL,
    method: String,
    body: RequestBody?,
    client: OkHttpClient,
): Response =
    client.newCall(
        Request.Builder()
            .url(url)
            .method(method, body)
            .build(),
    ).execute()

fun downloadResource(
context: Context,
client: OkHttpClient,
url: URL,
resourceType: ResourceType,
onComplete: (String?, Exception?) -> Unit,
listener: ProgressResponseBody.ProgressListener? = null,
) {
/*
Fetching model and tokenizer file
1. Extract file name from provided URL
Expand All @@ -115,57 +146,65 @@ class Fetcher {
6. If the file does not exist, and is a tokenizer, fetch the file
7. If the file is a model, fetch the file with ProgressResponseBody
*/
val fileName: String

try {
fileName = extractFileName(url)
} catch (e: Exception) {
onComplete(null, e)
return
}

if(fileName.contains("/")){
onComplete(fileName, null)
return
}

if (!hasValidExtension(fileName, resourceType)) {
onComplete(null, Exception("invalid_resource_extension"))
return
}

var tempFile = File(context.filesDir, fileName)
if(tempFile.exists()){
tempFile.delete()
}
val fileName: String

try {
fileName = extractFileName(url)
} catch (e: Exception) {
onComplete(null, e)
return
}

if (fileName.contains("/")) {
onComplete(fileName, null)
return
}

if (!hasValidExtension(fileName, resourceType)) {
onComplete(null, Exception("invalid_resource_extension"))
return
}

var tempFile = File(context.filesDir, fileName)
if (tempFile.exists()) {
tempFile.delete()
}

val modelsDirectory = File(context.filesDir, "models").apply {
if (!exists()) {
mkdirs()
}
}

val modelsDirectory = File(context.filesDir, "models").apply {
if (!exists()) {
mkdirs()
}
}
var validFile = File(modelsDirectory, fileName)
if (validFile.exists()) {
onComplete(validFile.absolutePath, null)
return
}

var validFile = File(modelsDirectory, fileName)
if (validFile.exists()) {
onComplete(validFile.absolutePath, null)
return
}
if (resourceType == ResourceType.TOKENIZER) {
val request = Request.Builder().url(url).build()
val response = client.newCall(request).execute()

if (resourceType == ResourceType.TOKENIZER) {
val request = Request.Builder().url(url).build()
val response = client.newCall(request).execute()
if (!response.isSuccessful) {
onComplete(null, Exception("download_error"))
return
}

if (!response.isSuccessful) {
onComplete(null, Exception("download_error"))
return
}
validFile = saveResponseToFile(response, modelsDirectory, fileName)
onComplete(validFile.absolutePath, null)
return
}

validFile = saveResponseToFile(response, modelsDirectory, fileName)
onComplete(validFile.absolutePath, null)
return
}
// If the url points to a Software Mansion HuggingFace repo, send a HEAD
// request for its config.json file; this increments the HF download counter.
// https://huggingface.co/docs/hub/models-download-stats
if (isUrlPointingToHfRepo(url)) {
val configUrl = resolveConfigUrlFromModelUrl(url)
sendRequestToUrl(configUrl, "HEAD", null, client)
}

fetchModel(tempFile, validFile, client, url, onComplete, listener)
}
fetchModel(tempFile, validFile, client, url, onComplete, listener)
}
}
}
}
23 changes: 23 additions & 0 deletions ios/utils/LargeFileFetcher.mm
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,13 @@ - (void)URLSession:(NSURLSession *)session downloadTask:(NSURLSessionDownloadTas
}
}

// Starts a HEAD request for `url` on this object's session. No completion
// handler is attached to the task here.
- (void)sendHeadRequestToURL:(NSURL *)url {
  NSMutableURLRequest *headRequest = [NSMutableURLRequest requestWithURL:url];
  headRequest.HTTPMethod = @"HEAD";
  [[_session dataTaskWithRequest:headRequest] resume];
}

- (void)startDownloadingFileFromURL:(NSURL *)url {
//Check if file is a valid url; if not, check if it's a path to a local file
if (![Fetcher isValidURL:url]) {
Expand All @@ -77,6 +84,22 @@ - (void)startDownloadingFileFromURL:(NSURL *)url {
[self executeCompletionWithSuccess:filePath];
return;
}

// If the url points to a Software Mansion HuggingFace repo, send a HEAD
// request for its config.json file; this increments the HF download counter.
// https://huggingface.co/docs/hub/models-download-stats
NSString *huggingFaceOrgNSString = @"https://huggingface.co/software-mansion/";
NSString *modelURLNSString = [url absoluteString];

if ([modelURLNSString hasPrefix:huggingFaceOrgNSString]) {
NSRange resolveRange = [modelURLNSString rangeOfString:@"resolve"];
if (resolveRange.location != NSNotFound) {
NSString *configURLNSString = [modelURLNSString substringToIndex:resolveRange.location + resolveRange.length];
configURLNSString = [configURLNSString stringByAppendingString:@"/main/config.json"];
NSURL *configNSURL = [NSURL URLWithString:configURLNSString];
[self sendHeadRequestToURL:configNSURL];
}
}

//Cancel all running background download tasks and start new one
_destination = filePath;
Expand Down
Loading