diff --git a/app/build.gradle.kts b/app/build.gradle.kts index 0a04367b..dc92b956 100755 --- a/app/build.gradle.kts +++ b/app/build.gradle.kts @@ -72,7 +72,8 @@ dependencies { implementation(project(":storage")) implementation(project(":domain")) implementation(project(":feature:auth")) - implementation(project(":feature:diffusion")) + implementation(project(":feature:local-diffusion-onnx")) + implementation(project(":feature:local-diffusion-cpp")) implementation(project(":feature:mediapipe")) implementation(project(":feature:work")) implementation(project(":data")) diff --git a/app/src/main/java/com/shifthackz/aisdv1/app/di/FeatureModule.kt b/app/src/main/java/com/shifthackz/aisdv1/app/di/FeatureModule.kt index a98cf95f..30c407e9 100644 --- a/app/src/main/java/com/shifthackz/aisdv1/app/di/FeatureModule.kt +++ b/app/src/main/java/com/shifthackz/aisdv1/app/di/FeatureModule.kt @@ -1,11 +1,13 @@ package com.shifthackz.aisdv1.app.di import com.shifthackz.aisdv1.feature.auth.di.authModule -import com.shifthackz.aisdv1.feature.diffusion.di.diffusionModule +import com.shifthackz.aisdv1.feature.localdiffusion.cpp.di.cppLocalDiffusionModule +import com.shifthackz.aisdv1.feature.localdiffusion.onnx.di.onnxLocalDiffusionModule import com.shifthackz.aisdv1.feature.mediapipe.di.mediaPipeModule val featureModule = arrayOf( authModule, - diffusionModule, + cppLocalDiffusionModule, + onnxLocalDiffusionModule, mediaPipeModule, ) diff --git a/app/src/main/java/com/shifthackz/aisdv1/app/di/ProvidersModule.kt b/app/src/main/java/com/shifthackz/aisdv1/app/di/ProvidersModule.kt index 5aaa0ea2..ecc8c682 100755 --- a/app/src/main/java/com/shifthackz/aisdv1/app/di/ProvidersModule.kt +++ b/app/src/main/java/com/shifthackz/aisdv1/app/di/ProvidersModule.kt @@ -15,9 +15,10 @@ import com.shifthackz.aisdv1.domain.entity.ServerSource import com.shifthackz.aisdv1.domain.feature.auth.AuthorizationCredentials import com.shifthackz.aisdv1.domain.feature.auth.AuthorizationStore import com.shifthackz.aisdv1.domain.preference.PreferenceManager -import com.shifthackz.aisdv1.feature.diffusion.entity.LocalDiffusionFlag -import com.shifthackz.aisdv1.feature.diffusion.environment.DeviceNNAPIFlagProvider -import com.shifthackz.aisdv1.feature.diffusion.environment.LocalModelIdProvider +import com.shifthackz.aisdv1.feature.localdiffusion.cpp.environment.LocalCppModelIdProvider +import com.shifthackz.aisdv1.feature.localdiffusion.onnx.entity.LocalDiffusionFlag +import com.shifthackz.aisdv1.feature.localdiffusion.onnx.environment.DeviceNNAPIFlagProvider +import com.shifthackz.aisdv1.feature.localdiffusion.onnx.environment.LocalOnnxModelIdProvider import com.shifthackz.aisdv1.network.qualifiers.ApiKeyProvider import com.shifthackz.aisdv1.network.qualifiers.ApiUrlProvider import com.shifthackz.aisdv1.network.qualifiers.CredentialsProvider @@ -184,7 +185,11 @@ val providersModule = module { } single { - LocalModelIdProvider { get().localOnnxModelId } + LocalOnnxModelIdProvider { get().localOnnxModelId } + } + + single { + LocalCppModelIdProvider { get().localCppModelId } } single { diff --git a/build-logic/convention/src/main/kotlin/com/shifthackz/aisdv1/buildlogic/Android.kt b/build-logic/convention/src/main/kotlin/com/shifthackz/aisdv1/buildlogic/Android.kt index 343b5456..44187c7e 100644 --- a/build-logic/convention/src/main/kotlin/com/shifthackz/aisdv1/buildlogic/Android.kt +++ b/build-logic/convention/src/main/kotlin/com/shifthackz/aisdv1/buildlogic/Android.kt @@ -11,7 +11,8 @@ internal fun Project.configureKotlinAndroid( commonExtension: CommonExtension<*, *, *, *, *, *>, ) { commonExtension.apply { - compileSdk = libs.findVersion("compileSdk").get().toString().toInt() + compileSdk = libs.findVersion("compileSdk").get().toString().toInt() + ndkVersion = libs.findVersion("ndk").get().toString() defaultConfig { minSdk = libs.findVersion("minSdk").get().toString().toInt() diff --git a/core/localization/src/main/res/values/strings.xml b/core/localization/src/main/res/values/strings.xml index c767358e..206968c9 100755 --- a/core/localization/src/main/res/values/strings.xml +++ b/core/localization/src/main/res/values/strings.xml @@ -71,7 +71,9 @@ Horde AI Cloud Horde Local Diffusion Microsoft ONNX (Beta) + Local Diffusion C++ (Beta) ONNX + CPP Local Diffusion Google AI MediaPipe (Beta) MediaPipe Hugging Face Inference @@ -156,6 +158,9 @@ This configuration uses Microsoft ONNX runtime and allows to run Stable Diffusion AI generations on your phone, with no need to connect to remote server/cloud. Warning! Local Diffusion functionality is in beta-test. Don\'t expect for high quality images using local mode. \n\nThis implementation may not work well on non-powerful phones. Generation performance and speed depends on your phone resources (CPU, RAM) and the size of generated image (the smaller the image size, the faster the generation). + Local Diffusion C++ + Changeme + Local Diffusion Google AI MediaPipe This configuration uses Google AI MediaPipe and allows to run Stable Diffusion AI generations on your phone, with no need to connect to remote server/cloud. diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/di/RepositoryModule.kt b/data/src/main/java/com/shifthackz/aisdv1/data/di/RepositoryModule.kt index 1358c8a9..4c30a35e 100755 --- a/data/src/main/java/com/shifthackz/aisdv1/data/di/RepositoryModule.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/di/RepositoryModule.kt @@ -8,7 +8,8 @@ import com.shifthackz.aisdv1.data.repository.GenerationResultRepositoryImpl import com.shifthackz.aisdv1.data.repository.HordeGenerationRepositoryImpl import com.shifthackz.aisdv1.data.repository.HuggingFaceGenerationRepositoryImpl import com.shifthackz.aisdv1.data.repository.HuggingFaceModelsRepositoryImpl -import com.shifthackz.aisdv1.data.repository.LocalDiffusionGenerationRepositoryImpl +import com.shifthackz.aisdv1.data.repository.LocalDiffusionCppGenerationRepositoryImpl +import com.shifthackz.aisdv1.data.repository.LocalDiffusionOnnxGenerationRepositoryImpl import com.shifthackz.aisdv1.data.repository.LorasRepositoryImpl import com.shifthackz.aisdv1.data.repository.MediaPipeGenerationRepositoryImpl import com.shifthackz.aisdv1.data.repository.OpenAiGenerationRepositoryImpl @@ -33,7 +34,8 @@ import com.shifthackz.aisdv1.domain.repository.GenerationResultRepository import com.shifthackz.aisdv1.domain.repository.HordeGenerationRepository import com.shifthackz.aisdv1.domain.repository.HuggingFaceGenerationRepository import com.shifthackz.aisdv1.domain.repository.HuggingFaceModelsRepository -import com.shifthackz.aisdv1.domain.repository.LocalDiffusionGenerationRepository +import com.shifthackz.aisdv1.domain.repository.LocalDiffusionCppGenerationRepository +import com.shifthackz.aisdv1.domain.repository.LocalDiffusionOnnxGenerationRepository import com.shifthackz.aisdv1.domain.repository.LorasRepository import com.shifthackz.aisdv1.domain.repository.MediaPipeGenerationRepository import com.shifthackz.aisdv1.domain.repository.OpenAiGenerationRepository @@ -66,7 +68,8 @@ val repositoryModule = module { } singleOf(::TemporaryGenerationResultRepositoryImpl) bind TemporaryGenerationResultRepository::class - factoryOf(::LocalDiffusionGenerationRepositoryImpl) bind LocalDiffusionGenerationRepository::class + factoryOf(::LocalDiffusionOnnxGenerationRepositoryImpl) bind LocalDiffusionOnnxGenerationRepository::class + factoryOf(::LocalDiffusionCppGenerationRepositoryImpl) bind LocalDiffusionCppGenerationRepository::class factoryOf(::MediaPipeGenerationRepositoryImpl) bind MediaPipeGenerationRepository::class factoryOf(::HordeGenerationRepositoryImpl) bind HordeGenerationRepository::class factoryOf(::HuggingFaceGenerationRepositoryImpl) bind HuggingFaceGenerationRepository::class diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/local/DownloadableModelLocalDataSource.kt b/data/src/main/java/com/shifthackz/aisdv1/data/local/DownloadableModelLocalDataSource.kt index 4ab031c2..2eadba9b 100644 --- a/data/src/main/java/com/shifthackz/aisdv1/data/local/DownloadableModelLocalDataSource.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/local/DownloadableModelLocalDataSource.kt @@ -36,6 +36,19 @@ internal class DownloadableModelLocalDataSource( } .flatMap { models -> models.withLocalData() } + override fun getAllCpp(): Single> = dao + .queryByType(LocalAiModel.Type.Cpp.key) + .map(List::mapEntityToDomain) + .map { models -> + buildList { + addAll(models) + if (buildInfoProvider.type != BuildType.PLAY) { + add(LocalAiModel.CustomCpp) + } + } + } + .flatMap { models -> models.withLocalData() } + override fun getAllMediaPipe(): Single> = dao .queryByType(LocalAiModel.Type.MediaPipe.key) .map(List::mapEntityToDomain) @@ -89,20 +102,19 @@ internal class DownloadableModelLocalDataSource( try { when (model.id) { LocalAiModel.CustomOnnx.id, + LocalAiModel.CustomCpp.id, LocalAiModel.CustomMediaPipe.id -> emitter.onSuccess(true) - else -> { - - when (model.type) { - LocalAiModel.Type.ONNX -> { - val files = getLocalModelFiles(model.id).filter { it.isDirectory } - emitter.onSuccess(files.size == 4) - } + else -> when (model.type) { + LocalAiModel.Type.ONNX -> { + val files = getLocalModelFiles(model.id).filter { it.isDirectory } + emitter.onSuccess(files.size == 4) + } - LocalAiModel.Type.MediaPipe -> { - val files = getLocalModelFiles(model.id) - emitter.onSuccess(files.isNotEmpty()) - } + LocalAiModel.Type.MediaPipe, + LocalAiModel.Type.Cpp -> { + val files = getLocalModelFiles(model.id) + emitter.onSuccess(files.isNotEmpty()) } } } @@ -133,6 +145,7 @@ internal class DownloadableModelLocalDataSource( selected = when (this.type) { LocalAiModel.Type.ONNX -> preferenceManager.localOnnxModelId == id LocalAiModel.Type.MediaPipe -> preferenceManager.localMediaPipeModelId == id + LocalAiModel.Type.Cpp -> preferenceManager.localCppModelId == id }, ) } diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/preference/PreferenceManagerImpl.kt b/data/src/main/java/com/shifthackz/aisdv1/data/preference/PreferenceManagerImpl.kt index 56191be3..69c16b9d 100644 --- a/data/src/main/java/com/shifthackz/aisdv1/data/preference/PreferenceManagerImpl.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/preference/PreferenceManagerImpl.kt @@ -62,6 +62,11 @@ class PreferenceManagerImpl( default = LOCAL_DIFFUSION_CUSTOM_PATH, ) + override var localCppCustomModelPath: String by preferences.delegates.string( + key = KEY_LOCAL_DIFFUSION_CPP_CUSTOM_MODEL_PATH, + default = LOCAL_DIFFUSION_CUSTOM_PATH, + ) + override var localOnnxCustomModelPath: String by preferences.delegates.string( key = KEY_LOCAL_DIFFUSION_CUSTOM_MODEL_PATH, default = LOCAL_DIFFUSION_CUSTOM_PATH, @@ -166,7 +171,12 @@ class PreferenceManagerImpl( ) override var localOnnxModelId: String by preferences.delegates.string( - key = KEY_LOCAL_MODEL_ID, + key = KEY_LOCAL_ONNX_MODEL_ID, + onChanged = ::onPreferencesChanged, + ) + + override var localCppModelId: String by preferences.delegates.string( + key = KEY_LOCAL_CPP_MODEL_ID, onChanged = ::onPreferencesChanged, ) @@ -266,6 +276,7 @@ class PreferenceManagerImpl( const val KEY_SWARM_MODEL = "key_swarm_model" const val KEY_DEMO_MODE = "key_demo_mode" const val KEY_DEVELOPER_MODE = "key_developer_mode" + const val KEY_LOCAL_DIFFUSION_CPP_CUSTOM_MODEL_PATH = "key_local_diffusion_cpp_custom_model_path" const val KEY_LOCAL_DIFFUSION_CUSTOM_MODEL_PATH = "key_local_diffusion_custom_model_path" const val KEY_MEDIA_PIPE_CUSTOM_MODEL_PATH = "key_mediapipe_custom_model_path" const val KEY_ALLOW_LOCAL_DIFFUSION_CANCEL = "key_allow_local_diffusion_cancel" @@ -286,7 +297,8 @@ class PreferenceManagerImpl( const val KEY_ON_BOARDING_COMPLETE = "key_on_boarding_complete" const val KEY_FORCE_SETUP_AFTER_UPDATE = "force_upd_setup_v0.x.x-v0.6.2" const val KEY_MEDIA_PIPE_MODEL_ID = "key_mediapipe_model_id" - const val KEY_LOCAL_MODEL_ID = "key_local_model_id" + const val KEY_LOCAL_ONNX_MODEL_ID = "key_local_model_id" + const val KEY_LOCAL_CPP_MODEL_ID = "key_local_cpp_model_id" const val KEY_LOCAL_NN_API = "key_local_nn_api" const val KEY_DESIGN_DYNAMIC_COLORS = "key_design_dynamic_colors" const val KEY_DESIGN_SYSTEM_DARK_THEME = "key_design_system_dark_theme" diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/remote/DownloadableModelRemoteDataSource.kt b/data/src/main/java/com/shifthackz/aisdv1/data/remote/DownloadableModelRemoteDataSource.kt index 1a552534..9b5a5f5c 100644 --- a/data/src/main/java/com/shifthackz/aisdv1/data/remote/DownloadableModelRemoteDataSource.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/remote/DownloadableModelRemoteDataSource.kt @@ -22,12 +22,28 @@ internal class DownloadableModelRemoteDataSource( api .fetchOnnxModels() .map { it.mapRawToCheckpointDomain(LocalAiModel.Type.ONNX) }, +// api +// .fetchCppModels() +// .map { it.mapRawToCheckpointDomain(LocalAiModel.Type.Cpp) }, + Single.just( + listOf( + LocalAiModel( + id = "abb51642-58f9-4401-9ca1-c0eac29e5c1d", + type = LocalAiModel.Type.Cpp, + name = "Test", + size = "Unknown", + sources = listOf( + "https://share.moroz.cc/SDAI/cpp/epicphotogasmLCM_ultimatefidelity.zip" + ), + ) + ) + ), api .fetchMediaPipeModels() .map { it.mapRawToCheckpointDomain(LocalAiModel.Type.MediaPipe) }, - ::Pair, + ::Triple, ) - .map { (onnx, mediapipe) -> listOf(onnx, mediapipe).flatten() } + .map { (onnx, cpp, mediapipe) -> listOf(onnx, cpp, mediapipe).flatten() } override fun download(id: String, url: String): Observable = Completable .fromAction { diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/repository/DownloadableModelRepositoryImpl.kt b/data/src/main/java/com/shifthackz/aisdv1/data/repository/DownloadableModelRepositoryImpl.kt index 0e46f816..2d1918c6 100644 --- a/data/src/main/java/com/shifthackz/aisdv1/data/repository/DownloadableModelRepositoryImpl.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/repository/DownloadableModelRepositoryImpl.kt @@ -23,6 +23,12 @@ internal class DownloadableModelRepositoryImpl( .andThen(localDataSource.getAllOnnx()) .onErrorResumeNext { localDataSource.getAllOnnx() } + override fun getAllCpp(): Single> = remoteDataSource + .fetch() + .flatMapCompletable(localDataSource::save) + .andThen(localDataSource.getAllCpp()) + .onErrorResumeNext { localDataSource.getAllCpp() } + override fun getAllMediaPipe(): Single> { if (buildInfoProvider.type == BuildType.FOSS) { return Single.just(emptyList()) diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/repository/LocalDiffusionCppGenerationRepositoryImpl.kt b/data/src/main/java/com/shifthackz/aisdv1/data/repository/LocalDiffusionCppGenerationRepositoryImpl.kt new file mode 100644 index 00000000..1143e0d5 --- /dev/null +++ b/data/src/main/java/com/shifthackz/aisdv1/data/repository/LocalDiffusionCppGenerationRepositoryImpl.kt @@ -0,0 +1,42 @@ +package com.shifthackz.aisdv1.data.repository + +import com.shifthackz.aisdv1.core.common.schedulers.SchedulersProvider +import com.shifthackz.aisdv1.core.imageprocessing.Base64ToBitmapConverter +import com.shifthackz.aisdv1.core.imageprocessing.BitmapToBase64Converter +import com.shifthackz.aisdv1.data.core.CoreGenerationRepository +import com.shifthackz.aisdv1.data.mappers.mapLocalDiffusionToAiGenResult +import com.shifthackz.aisdv1.domain.datasource.GenerationResultDataSource +import com.shifthackz.aisdv1.domain.entity.TextToImagePayload +import com.shifthackz.aisdv1.domain.feature.diffusion.LocalDiffusionCpp +import com.shifthackz.aisdv1.domain.feature.work.BackgroundWorkObserver +import com.shifthackz.aisdv1.domain.gateway.MediaStoreGateway +import com.shifthackz.aisdv1.domain.preference.PreferenceManager +import com.shifthackz.aisdv1.domain.repository.LocalDiffusionCppGenerationRepository + +internal class LocalDiffusionCppGenerationRepositoryImpl( + mediaStoreGateway: MediaStoreGateway, + base64ToBitmapConverter: Base64ToBitmapConverter, + localDataSource: GenerationResultDataSource.Local, + backgroundWorkObserver: BackgroundWorkObserver, + private val preferenceManager: PreferenceManager, + private val localDiffusion: LocalDiffusionCpp, + private val bitmapToBase64Converter: BitmapToBase64Converter, + private val schedulersProvider: SchedulersProvider, +) : CoreGenerationRepository( + mediaStoreGateway = mediaStoreGateway, + base64ToBitmapConverter = base64ToBitmapConverter, + localDataSource = localDataSource, + preferenceManager = preferenceManager, + backgroundWorkObserver = backgroundWorkObserver, +), LocalDiffusionCppGenerationRepository { + + override fun generateFromText(payload: TextToImagePayload) = localDiffusion + .process(payload) + .subscribeOn(schedulersProvider.byToken(preferenceManager.localOnnxSchedulerThread)) + .map(BitmapToBase64Converter::Input) + .flatMap(bitmapToBase64Converter::invoke) + .map(BitmapToBase64Converter.Output::base64ImageString) + .map { base64 -> payload to base64 } + .map(Pair::mapLocalDiffusionToAiGenResult) + .flatMap(::insertGenerationResult) +} diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/repository/LocalDiffusionGenerationRepositoryImpl.kt b/data/src/main/java/com/shifthackz/aisdv1/data/repository/LocalDiffusionOnnxGenerationRepositoryImpl.kt similarity index 91% rename from data/src/main/java/com/shifthackz/aisdv1/data/repository/LocalDiffusionGenerationRepositoryImpl.kt rename to data/src/main/java/com/shifthackz/aisdv1/data/repository/LocalDiffusionOnnxGenerationRepositoryImpl.kt index 625a1ca8..9a739f60 100644 --- a/data/src/main/java/com/shifthackz/aisdv1/data/repository/LocalDiffusionGenerationRepositoryImpl.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/repository/LocalDiffusionOnnxGenerationRepositoryImpl.kt @@ -8,20 +8,20 @@ import com.shifthackz.aisdv1.data.mappers.mapLocalDiffusionToAiGenResult import com.shifthackz.aisdv1.domain.datasource.DownloadableModelDataSource import com.shifthackz.aisdv1.domain.datasource.GenerationResultDataSource import com.shifthackz.aisdv1.domain.entity.TextToImagePayload -import com.shifthackz.aisdv1.domain.feature.diffusion.LocalDiffusion +import com.shifthackz.aisdv1.domain.feature.diffusion.LocalDiffusionONNX import com.shifthackz.aisdv1.domain.feature.work.BackgroundWorkObserver import com.shifthackz.aisdv1.domain.gateway.MediaStoreGateway import com.shifthackz.aisdv1.domain.preference.PreferenceManager -import com.shifthackz.aisdv1.domain.repository.LocalDiffusionGenerationRepository +import com.shifthackz.aisdv1.domain.repository.LocalDiffusionOnnxGenerationRepository import io.reactivex.rxjava3.core.Single -internal class LocalDiffusionGenerationRepositoryImpl( +internal class LocalDiffusionOnnxGenerationRepositoryImpl( mediaStoreGateway: MediaStoreGateway, base64ToBitmapConverter: Base64ToBitmapConverter, localDataSource: GenerationResultDataSource.Local, backgroundWorkObserver: BackgroundWorkObserver, private val preferenceManager: PreferenceManager, - private val localDiffusion: LocalDiffusion, + private val localDiffusion: LocalDiffusionONNX, private val downloadableLocalDataSource: DownloadableModelDataSource.Local, private val bitmapToBase64Converter: BitmapToBase64Converter, private val schedulersProvider: SchedulersProvider, @@ -31,7 +31,7 @@ internal class LocalDiffusionGenerationRepositoryImpl( localDataSource = localDataSource, preferenceManager = preferenceManager, backgroundWorkObserver = backgroundWorkObserver, -), LocalDiffusionGenerationRepository { +), LocalDiffusionOnnxGenerationRepository { override fun observeStatus() = localDiffusion.observeStatus() diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/preference/PreferenceManagerImplTest.kt b/data/src/test/java/com/shifthackz/aisdv1/data/preference/PreferenceManagerImplTest.kt index 5f68e23b..2a1fe694 100644 --- a/data/src/test/java/com/shifthackz/aisdv1/data/preference/PreferenceManagerImplTest.kt +++ b/data/src/test/java/com/shifthackz/aisdv1/data/preference/PreferenceManagerImplTest.kt @@ -19,7 +19,7 @@ import com.shifthackz.aisdv1.data.preference.PreferenceManagerImpl.Companion.KEY import com.shifthackz.aisdv1.data.preference.PreferenceManagerImpl.Companion.KEY_HORDE_API_KEY import com.shifthackz.aisdv1.data.preference.PreferenceManagerImpl.Companion.KEY_HUGGING_FACE_API_KEY import com.shifthackz.aisdv1.data.preference.PreferenceManagerImpl.Companion.KEY_HUGGING_FACE_MODEL_KEY -import com.shifthackz.aisdv1.data.preference.PreferenceManagerImpl.Companion.KEY_LOCAL_MODEL_ID +import com.shifthackz.aisdv1.data.preference.PreferenceManagerImpl.Companion.KEY_LOCAL_ONNX_MODEL_ID import com.shifthackz.aisdv1.data.preference.PreferenceManagerImpl.Companion.KEY_LOCAL_NN_API import com.shifthackz.aisdv1.data.preference.PreferenceManagerImpl.Companion.KEY_MONITOR_CONNECTIVITY import com.shifthackz.aisdv1.data.preference.PreferenceManagerImpl.Companion.KEY_OPEN_AI_API_KEY @@ -372,12 +372,12 @@ class PreferenceManagerImplTest { @Test fun `given user reads default localModelId, changes it, expected default value, then changed value`() { - whenever(stubPreference.getString(eq(KEY_LOCAL_MODEL_ID), any())) + whenever(stubPreference.getString(eq(KEY_LOCAL_ONNX_MODEL_ID), any())) .thenReturn("") Assert.assertEquals("", preferenceManager.localOnnxModelId) - whenever(stubPreference.getString(eq(KEY_LOCAL_MODEL_ID), any())) + whenever(stubPreference.getString(eq(KEY_LOCAL_ONNX_MODEL_ID), any())) .thenReturn("key") preferenceManager.localOnnxModelId = "key" diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/repository/LocalDiffusionGenerationRepositoryImplTest.kt b/data/src/test/java/com/shifthackz/aisdv1/data/repository/LocalDiffusionONNXGenerationRepositoryImplTest.kt similarity index 97% rename from data/src/test/java/com/shifthackz/aisdv1/data/repository/LocalDiffusionGenerationRepositoryImplTest.kt rename to data/src/test/java/com/shifthackz/aisdv1/data/repository/LocalDiffusionONNXGenerationRepositoryImplTest.kt index 126edac1..03d75502 100644 --- a/data/src/test/java/com/shifthackz/aisdv1/data/repository/LocalDiffusionGenerationRepositoryImplTest.kt +++ b/data/src/test/java/com/shifthackz/aisdv1/data/repository/LocalDiffusionONNXGenerationRepositoryImplTest.kt @@ -11,7 +11,7 @@ import com.shifthackz.aisdv1.domain.datasource.DownloadableModelDataSource import com.shifthackz.aisdv1.domain.datasource.GenerationResultDataSource import com.shifthackz.aisdv1.domain.entity.AiGenerationResult import com.shifthackz.aisdv1.domain.entity.LocalDiffusionStatus -import com.shifthackz.aisdv1.domain.feature.diffusion.LocalDiffusion +import com.shifthackz.aisdv1.domain.feature.diffusion.LocalDiffusionONNX import com.shifthackz.aisdv1.domain.feature.work.BackgroundWorkObserver import com.shifthackz.aisdv1.domain.gateway.MediaStoreGateway import com.shifthackz.aisdv1.domain.preference.PreferenceManager @@ -28,7 +28,7 @@ import org.junit.Test import java.util.concurrent.Executor import java.util.concurrent.Executors -class LocalDiffusionGenerationRepositoryImplTest { +class LocalDiffusionONNXGenerationRepositoryImplTest { private val stubBitmap = mockk() private val stubException = Throwable("Something went wrong.") @@ -38,7 +38,7 @@ class LocalDiffusionGenerationRepositoryImplTest { private val stubBitmapToBase64Converter = mockk() private val stubLocalDataSource = mockk() private val stubPreferenceManager = mockk() - private val stubLocalDiffusion = mockk() + private val stubLocalDiffusion = mockk() private val stubDownloadableLocalDataSource = mockk() private val stubBackgroundWorkObserver = mockk() @@ -49,7 +49,7 @@ class LocalDiffusionGenerationRepositoryImplTest { override val singleThread: Executor = Executors.newSingleThreadExecutor() } - private val repository = LocalDiffusionGenerationRepositoryImpl( + private val repository = LocalDiffusionOnnxGenerationRepositoryImpl( mediaStoreGateway = stubMediaStoreGateway, base64ToBitmapConverter = stubBase64ToBitmapConverter, localDataSource = stubLocalDataSource, diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/datasource/DownloadableModelDataSource.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/datasource/DownloadableModelDataSource.kt index cb56c621..7af1d92e 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/datasource/DownloadableModelDataSource.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/datasource/DownloadableModelDataSource.kt @@ -16,6 +16,7 @@ sealed interface DownloadableModelDataSource { interface Local : DownloadableModelDataSource { fun getAllOnnx(): Single> + fun getAllCpp(): Single> fun getAllMediaPipe(): Single> fun getById(id: String): Single fun getSelectedOnnx(): Single diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/di/DomainModule.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/di/DomainModule.kt index 3a5f59eb..7de22954 100755 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/di/DomainModule.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/di/DomainModule.kt @@ -36,6 +36,8 @@ import com.shifthackz.aisdv1.domain.usecase.downloadable.DeleteModelUseCase import com.shifthackz.aisdv1.domain.usecase.downloadable.DeleteModelUseCaseImpl import com.shifthackz.aisdv1.domain.usecase.downloadable.DownloadModelUseCase import com.shifthackz.aisdv1.domain.usecase.downloadable.DownloadModelUseCaseImpl +import com.shifthackz.aisdv1.domain.usecase.downloadable.GetLocalCppModelsUseCase +import com.shifthackz.aisdv1.domain.usecase.downloadable.GetLocalCppModelsUseCaseImpl import com.shifthackz.aisdv1.domain.usecase.downloadable.GetLocalMediaPipeModelsUseCase import com.shifthackz.aisdv1.domain.usecase.downloadable.GetLocalMediaPipeModelsUseCaseImpl import com.shifthackz.aisdv1.domain.usecase.downloadable.GetLocalModelUseCase @@ -98,8 +100,10 @@ import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToHordeUseCase import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToHordeUseCaseImpl import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToHuggingFaceUseCase import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToHuggingFaceUseCaseImpl -import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToLocalDiffusionUseCase -import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToLocalDiffusionUseCaseImpl +import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToLocalDiffusionCppUseCase +import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToLocalDiffusionCppUseCaseImpl +import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToLocalDiffusionOnnxUseCase +import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToLocalDiffusionOnnxUseCaseImpl import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToMediaPipeUseCase import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToMediaPipeUseCaseImpl import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToOpenAiUseCase @@ -167,6 +171,7 @@ internal val useCasesModule = module { factoryOf(::GetLastResultFromCacheUseCaseImpl) bind GetLastResultFromCacheUseCase::class factoryOf(::ObserveLocalDiffusionProcessStatusUseCaseImpl) bind ObserveLocalDiffusionProcessStatusUseCase::class factoryOf(::GetLocalOnnxModelsUseCaseImpl) bind GetLocalOnnxModelsUseCase::class + factoryOf(::GetLocalCppModelsUseCaseImpl) bind GetLocalCppModelsUseCase::class factoryOf(::GetLocalMediaPipeModelsUseCaseImpl) bind GetLocalMediaPipeModelsUseCase::class factoryOf(::DownloadModelUseCaseImpl) bind DownloadModelUseCase::class factoryOf(::ObserveLocalOnnxModelsUseCaseImpl) bind ObserveLocalOnnxModelsUseCase::class @@ -175,7 +180,8 @@ internal val useCasesModule = module { factoryOf(::ReleaseWakeLockUseCaseImpl) bind ReleaseWakeLockUseCase::class factoryOf(::InterruptGenerationUseCaseImpl) bind InterruptGenerationUseCase::class factoryOf(::ConnectToHordeUseCaseImpl) bind ConnectToHordeUseCase::class - factoryOf(::ConnectToLocalDiffusionUseCaseImpl) bind ConnectToLocalDiffusionUseCase::class + factoryOf(::ConnectToLocalDiffusionOnnxUseCaseImpl) bind ConnectToLocalDiffusionOnnxUseCase::class + factoryOf(::ConnectToLocalDiffusionCppUseCaseImpl) bind ConnectToLocalDiffusionCppUseCase::class factoryOf(::ConnectToMediaPipeUseCaseImpl) bind ConnectToMediaPipeUseCase::class factoryOf(::ConnectToA1111UseCaseImpl) bind ConnectToA1111UseCase::class factoryOf(::ConnectToSwarmUiUseCaseImpl) bind ConnectToSwarmUiUseCase::class diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/Configuration.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/Configuration.kt index 40c5e10a..c3bda1bc 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/Configuration.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/Configuration.kt @@ -15,6 +15,8 @@ data class Configuration( val stabilityAiApiKey: String = "", val stabilityAiEngineId: String = "", val authCredentials: AuthorizationCredentials = AuthorizationCredentials.None, + val localCppModelId: String = "", + val localCppModelPath: String = "", val localOnnxModelId: String = "", val localOnnxModelPath: String = "", val localMediaPipeModelId: String = "", diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/LocalAiModel.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/LocalAiModel.kt index b7958512..469c82ae 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/LocalAiModel.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/LocalAiModel.kt @@ -11,6 +11,7 @@ data class LocalAiModel( ) { enum class Type(val key: String) { ONNX("onnx"), + Cpp("cpp"), MediaPipe("mediapipe"); companion object { @@ -27,6 +28,14 @@ data class LocalAiModel( sources = emptyList(), ) + val CustomCpp = LocalAiModel( + id = "CUSTOM_CPP", + type = Type.Cpp, + name = "Custom", + size = "NaN", + sources = emptyList(), + ) + val CustomMediaPipe = LocalAiModel( id = "CUSTOM_MP", type = Type.MediaPipe, diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/ServerSource.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/ServerSource.kt index 15244e1d..c568a51d 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/ServerSource.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/ServerSource.kt @@ -73,6 +73,14 @@ enum class ServerSource( FeatureTag.MultipleModels, ), ), + LOCAL_CPP( + key = "local_cpp", + featureTags = setOf( + FeatureTag.Offline, + FeatureTag.Txt2Img, + FeatureTag.MultipleModels, + ), + ), LOCAL_GOOGLE_MEDIA_PIPE( key = "local_google_media_pipe", featureTags = setOf( diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/feature/diffusion/LocalDiffusionCpp.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/feature/diffusion/LocalDiffusionCpp.kt new file mode 100644 index 00000000..19b88e21 --- /dev/null +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/feature/diffusion/LocalDiffusionCpp.kt @@ -0,0 +1,9 @@ +package com.shifthackz.aisdv1.domain.feature.diffusion + +import android.graphics.Bitmap +import com.shifthackz.aisdv1.domain.entity.TextToImagePayload +import io.reactivex.rxjava3.core.Single + +interface LocalDiffusionCpp { + fun process(payload: TextToImagePayload): Single +} diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/feature/diffusion/LocalDiffusion.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/feature/diffusion/LocalDiffusionONNX.kt similarity index 94% rename from domain/src/main/java/com/shifthackz/aisdv1/domain/feature/diffusion/LocalDiffusion.kt rename to domain/src/main/java/com/shifthackz/aisdv1/domain/feature/diffusion/LocalDiffusionONNX.kt index afcdefd0..90dfe9e9 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/feature/diffusion/LocalDiffusion.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/feature/diffusion/LocalDiffusionONNX.kt @@ -7,7 +7,7 @@ import io.reactivex.rxjava3.core.Completable import io.reactivex.rxjava3.core.Observable import io.reactivex.rxjava3.core.Single -interface LocalDiffusion { +interface LocalDiffusionONNX { fun process(payload: TextToImagePayload): Single fun interrupt(): Completable fun observeStatus(): Observable diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/interactor/settings/SetupConnectionInterActor.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/interactor/settings/SetupConnectionInterActor.kt index 5959a035..201cde21 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/interactor/settings/SetupConnectionInterActor.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/interactor/settings/SetupConnectionInterActor.kt @@ -3,7 +3,8 @@ package com.shifthackz.aisdv1.domain.interactor.settings import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToA1111UseCase import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToHordeUseCase import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToHuggingFaceUseCase -import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToLocalDiffusionUseCase +import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToLocalDiffusionCppUseCase +import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToLocalDiffusionOnnxUseCase import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToMediaPipeUseCase import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToOpenAiUseCase import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToStabilityAiUseCase @@ -11,7 +12,8 @@ import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToSwarmUiUseCase interface SetupConnectionInterActor { val connectToHorde: ConnectToHordeUseCase - val connectToLocal: ConnectToLocalDiffusionUseCase + val connectToLocalOnnx: ConnectToLocalDiffusionOnnxUseCase + val connectToLocalCpp: ConnectToLocalDiffusionCppUseCase val connectToMediaPipe: ConnectToMediaPipeUseCase val connectToA1111: ConnectToA1111UseCase val connectToHuggingFace: ConnectToHuggingFaceUseCase diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/interactor/settings/SetupConnectionInterActorImpl.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/interactor/settings/SetupConnectionInterActorImpl.kt index 306da094..34d63126 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/interactor/settings/SetupConnectionInterActorImpl.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/interactor/settings/SetupConnectionInterActorImpl.kt @@ -3,7 +3,8 @@ package com.shifthackz.aisdv1.domain.interactor.settings import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToA1111UseCase import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToHordeUseCase import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToHuggingFaceUseCase -import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToLocalDiffusionUseCase +import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToLocalDiffusionCppUseCase +import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToLocalDiffusionOnnxUseCase import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToMediaPipeUseCase import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToOpenAiUseCase import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToStabilityAiUseCase @@ -11,7 +12,8 @@ import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToSwarmUiUseCase internal data class SetupConnectionInterActorImpl( override val connectToHorde: ConnectToHordeUseCase, - override val connectToLocal: ConnectToLocalDiffusionUseCase, + override val connectToLocalOnnx: ConnectToLocalDiffusionOnnxUseCase, + override val connectToLocalCpp: ConnectToLocalDiffusionCppUseCase, override val connectToMediaPipe: ConnectToMediaPipeUseCase, override val connectToA1111: ConnectToA1111UseCase, override val connectToHuggingFace: ConnectToHuggingFaceUseCase, diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/preference/PreferenceManager.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/preference/PreferenceManager.kt index b049cf33..e0d2a724 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/preference/PreferenceManager.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/preference/PreferenceManager.kt @@ -14,6 +14,7 @@ interface PreferenceManager { var demoMode: Boolean var developerMode: Boolean var localMediaPipeCustomModelPath: String + var localCppCustomModelPath: String var localOnnxCustomModelPath: String var localOnnxAllowCancel: Boolean var localOnnxSchedulerThread: SchedulersToken @@ -33,6 +34,7 @@ interface PreferenceManager { var onBoardingComplete: Boolean var forceSetupAfterUpdate: Boolean var localOnnxModelId: String + var localCppModelId: String var localOnnxUseNNAPI: Boolean var localMediaPipeModelId: String var designUseSystemColorPalette: Boolean diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/repository/DownloadableModelRepository.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/repository/DownloadableModelRepository.kt index 9d124726..4e649c44 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/repository/DownloadableModelRepository.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/repository/DownloadableModelRepository.kt @@ -11,6 +11,7 @@ interface DownloadableModelRepository { fun download(id: String, url: String): Observable fun delete(id: String): Completable fun getAllOnnx(): Single> + fun getAllCpp(): Single> fun getAllMediaPipe(): Single> fun observeAllOnnx(): Flowable> } diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/repository/LocalDiffusionCppGenerationRepository.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/repository/LocalDiffusionCppGenerationRepository.kt new file mode 100644 index 00000000..18bef831 --- /dev/null +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/repository/LocalDiffusionCppGenerationRepository.kt @@ -0,0 +1,9 @@ +package com.shifthackz.aisdv1.domain.repository + +import com.shifthackz.aisdv1.domain.entity.AiGenerationResult +import com.shifthackz.aisdv1.domain.entity.TextToImagePayload +import io.reactivex.rxjava3.core.Single + +interface LocalDiffusionCppGenerationRepository { + fun generateFromText(payload: TextToImagePayload): Single +} diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/repository/LocalDiffusionGenerationRepository.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/repository/LocalDiffusionOnnxGenerationRepository.kt similarity index 91% rename from domain/src/main/java/com/shifthackz/aisdv1/domain/repository/LocalDiffusionGenerationRepository.kt rename to domain/src/main/java/com/shifthackz/aisdv1/domain/repository/LocalDiffusionOnnxGenerationRepository.kt index 3ed5c370..535f34e2 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/repository/LocalDiffusionGenerationRepository.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/repository/LocalDiffusionOnnxGenerationRepository.kt @@ -7,7 +7,7 @@ import io.reactivex.rxjava3.core.Completable import io.reactivex.rxjava3.core.Observable import io.reactivex.rxjava3.core.Single -interface LocalDiffusionGenerationRepository { +interface LocalDiffusionOnnxGenerationRepository { fun observeStatus(): Observable fun generateFromText(payload: TextToImagePayload): Single fun interruptGeneration(): Completable diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/GetLocalCppModelsUseCase.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/GetLocalCppModelsUseCase.kt new file mode 100644 index 00000000..0cc0e805 --- /dev/null +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/GetLocalCppModelsUseCase.kt @@ -0,0 +1,8 @@ +package com.shifthackz.aisdv1.domain.usecase.downloadable + +import com.shifthackz.aisdv1.domain.entity.LocalAiModel +import io.reactivex.rxjava3.core.Single + +interface GetLocalCppModelsUseCase { + operator fun invoke(): Single> +} diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/GetLocalCppModelsUseCaseImpl.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/GetLocalCppModelsUseCaseImpl.kt new file mode 100644 index 00000000..3ca4deda --- /dev/null +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/GetLocalCppModelsUseCaseImpl.kt @@ -0,0 +1,10 @@ +package com.shifthackz.aisdv1.domain.usecase.downloadable + +import com.shifthackz.aisdv1.domain.repository.DownloadableModelRepository + +internal class GetLocalCppModelsUseCaseImpl( + private val downloadableModelRepository: DownloadableModelRepository, +) : GetLocalCppModelsUseCase { + + override fun invoke() = downloadableModelRepository.getAllCpp() +} diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/GetLocalMediaPipeModelsUseCaseImpl.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/GetLocalMediaPipeModelsUseCaseImpl.kt index 4da8eeba..5ccd114c 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/GetLocalMediaPipeModelsUseCaseImpl.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/GetLocalMediaPipeModelsUseCaseImpl.kt @@ -4,7 +4,7 @@ import com.shifthackz.aisdv1.domain.repository.DownloadableModelRepository internal class GetLocalMediaPipeModelsUseCaseImpl( private val downloadableModelRepository: DownloadableModelRepository, - ) : GetLocalMediaPipeModelsUseCase { +) : GetLocalMediaPipeModelsUseCase { override fun invoke() = downloadableModelRepository.getAllMediaPipe() } diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/generation/InterruptGenerationUseCaseImpl.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/generation/InterruptGenerationUseCaseImpl.kt index 37f6720c..883b12d7 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/generation/InterruptGenerationUseCaseImpl.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/generation/InterruptGenerationUseCaseImpl.kt @@ -3,21 +3,21 @@ package com.shifthackz.aisdv1.domain.usecase.generation import com.shifthackz.aisdv1.domain.entity.ServerSource import com.shifthackz.aisdv1.domain.preference.PreferenceManager import com.shifthackz.aisdv1.domain.repository.HordeGenerationRepository -import com.shifthackz.aisdv1.domain.repository.LocalDiffusionGenerationRepository +import com.shifthackz.aisdv1.domain.repository.LocalDiffusionOnnxGenerationRepository import com.shifthackz.aisdv1.domain.repository.StableDiffusionGenerationRepository import io.reactivex.rxjava3.core.Completable internal class InterruptGenerationUseCaseImpl( private val stableDiffusionGenerationRepository: StableDiffusionGenerationRepository, private val hordeGenerationRepository: HordeGenerationRepository, - private val localDiffusionGenerationRepository: LocalDiffusionGenerationRepository, + private val localDiffusionOnnxGenerationRepository: LocalDiffusionOnnxGenerationRepository, private val preferenceManager: PreferenceManager, ) : InterruptGenerationUseCase { override fun invoke() = when (preferenceManager.source) { ServerSource.AUTOMATIC1111 -> stableDiffusionGenerationRepository.interruptGeneration() ServerSource.HORDE -> hordeGenerationRepository.interruptGeneration() - ServerSource.LOCAL_MICROSOFT_ONNX -> localDiffusionGenerationRepository.interruptGeneration() + ServerSource.LOCAL_MICROSOFT_ONNX -> localDiffusionOnnxGenerationRepository.interruptGeneration() else -> Completable.complete() } } diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/generation/ObserveLocalDiffusionProcessStatusUseCaseImpl.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/generation/ObserveLocalDiffusionProcessStatusUseCaseImpl.kt index 8e8430dd..0600db45 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/generation/ObserveLocalDiffusionProcessStatusUseCaseImpl.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/generation/ObserveLocalDiffusionProcessStatusUseCaseImpl.kt @@ -1,12 +1,12 @@ package com.shifthackz.aisdv1.domain.usecase.generation -import com.shifthackz.aisdv1.domain.repository.LocalDiffusionGenerationRepository +import com.shifthackz.aisdv1.domain.repository.LocalDiffusionOnnxGenerationRepository internal class ObserveLocalDiffusionProcessStatusUseCaseImpl( - private val localDiffusionGenerationRepository: LocalDiffusionGenerationRepository, + private val localDiffusionOnnxGenerationRepository: LocalDiffusionOnnxGenerationRepository, ) : ObserveLocalDiffusionProcessStatusUseCase { - override fun invoke() = localDiffusionGenerationRepository + override fun invoke() = localDiffusionOnnxGenerationRepository .observeStatus() .distinctUntilChanged() } diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/generation/TextToImageUseCaseImpl.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/generation/TextToImageUseCaseImpl.kt index f823a153..2fb6812f 100755 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/generation/TextToImageUseCaseImpl.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/generation/TextToImageUseCaseImpl.kt @@ -1,19 +1,18 @@ package com.shifthackz.aisdv1.domain.usecase.generation -import com.shifthackz.aisdv1.domain.entity.AiGenerationResult import com.shifthackz.aisdv1.domain.entity.ServerSource import com.shifthackz.aisdv1.domain.entity.TextToImagePayload import com.shifthackz.aisdv1.domain.preference.PreferenceManager import com.shifthackz.aisdv1.domain.repository.HordeGenerationRepository import com.shifthackz.aisdv1.domain.repository.HuggingFaceGenerationRepository -import com.shifthackz.aisdv1.domain.repository.LocalDiffusionGenerationRepository +import com.shifthackz.aisdv1.domain.repository.LocalDiffusionCppGenerationRepository +import com.shifthackz.aisdv1.domain.repository.LocalDiffusionOnnxGenerationRepository import com.shifthackz.aisdv1.domain.repository.MediaPipeGenerationRepository import com.shifthackz.aisdv1.domain.repository.OpenAiGenerationRepository import com.shifthackz.aisdv1.domain.repository.StabilityAiGenerationRepository import com.shifthackz.aisdv1.domain.repository.StableDiffusionGenerationRepository import com.shifthackz.aisdv1.domain.repository.SwarmUiGenerationRepository import io.reactivex.rxjava3.core.Observable -import io.reactivex.rxjava3.core.Single internal class TextToImageUseCaseImpl( private val stableDiffusionGenerationRepository: StableDiffusionGenerationRepository, @@ -22,26 +21,26 @@ internal class TextToImageUseCaseImpl( private val openAiGenerationRepository: OpenAiGenerationRepository, private val stabilityAiGenerationRepository: StabilityAiGenerationRepository, private val swarmUiGenerationRepository: SwarmUiGenerationRepository, - private val localDiffusionGenerationRepository: LocalDiffusionGenerationRepository, + private val localDiffusionOnnxGenerationRepository: LocalDiffusionOnnxGenerationRepository, + private val localDiffusionCppGenerationRepository: LocalDiffusionCppGenerationRepository, private val mediaPipeGenerationRepository: MediaPipeGenerationRepository, private val preferenceManager: PreferenceManager, ) : TextToImageUseCase { - override operator fun invoke( - payload: TextToImagePayload, - ): Single> = Observable + override operator fun invoke(payload: TextToImagePayload) = Observable .range(1, payload.batchCount) .flatMapSingle { generate(payload) } .toList() private fun generate(payload: TextToImagePayload) = when (preferenceManager.source) { ServerSource.HORDE -> hordeGenerationRepository.generateFromText(payload) - ServerSource.LOCAL_MICROSOFT_ONNX -> localDiffusionGenerationRepository.generateFromText(payload) + ServerSource.LOCAL_MICROSOFT_ONNX -> localDiffusionOnnxGenerationRepository.generateFromText(payload) ServerSource.HUGGING_FACE -> huggingFaceGenerationRepository.generateFromText(payload) ServerSource.AUTOMATIC1111 -> stableDiffusionGenerationRepository.generateFromText(payload) ServerSource.OPEN_AI -> openAiGenerationRepository.generateFromText(payload) ServerSource.STABILITY_AI -> stabilityAiGenerationRepository.generateFromText(payload) ServerSource.SWARM_UI -> swarmUiGenerationRepository.generateFromText(payload) ServerSource.LOCAL_GOOGLE_MEDIA_PIPE -> mediaPipeGenerationRepository.generateFromText(payload) + ServerSource.LOCAL_CPP -> localDiffusionCppGenerationRepository.generateFromText(payload) } } diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/ConnectToLocalDiffusionUseCase.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/ConnectToLocalDiffusionCppUseCase.kt similarity index 77% rename from domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/ConnectToLocalDiffusionUseCase.kt rename to domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/ConnectToLocalDiffusionCppUseCase.kt index 31c7d1fa..7aba83b2 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/ConnectToLocalDiffusionUseCase.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/ConnectToLocalDiffusionCppUseCase.kt @@ -2,6 +2,6 @@ package com.shifthackz.aisdv1.domain.usecase.settings import io.reactivex.rxjava3.core.Single -interface ConnectToLocalDiffusionUseCase { +interface ConnectToLocalDiffusionCppUseCase { operator fun invoke(modelId: String): Single> } diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/ConnectToLocalDiffusionCppUseCaseImpl.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/ConnectToLocalDiffusionCppUseCaseImpl.kt new file mode 100644 index 00000000..42a50e39 --- /dev/null +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/ConnectToLocalDiffusionCppUseCaseImpl.kt @@ -0,0 +1,21 @@ +package com.shifthackz.aisdv1.domain.usecase.settings + +import com.shifthackz.aisdv1.domain.entity.ServerSource +import io.reactivex.rxjava3.core.Single + +internal class ConnectToLocalDiffusionCppUseCaseImpl( + private val getConfigurationUseCase: GetConfigurationUseCase, + private val setServerConfigurationUseCase: SetServerConfigurationUseCase, +) : ConnectToLocalDiffusionCppUseCase { + + override fun invoke(modelId: String) = getConfigurationUseCase() + .map { originalConfiguration -> + originalConfiguration.copy( + source = ServerSource.LOCAL_CPP, + localCppModelId = modelId, + ) + } + .flatMapCompletable(setServerConfigurationUseCase::invoke) + .andThen(Single.just(Result.success(Unit))) + .onErrorResumeNext { t -> Single.just(Result.failure(t)) } +} diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/ConnectToLocalDiffusionOnnxUseCase.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/ConnectToLocalDiffusionOnnxUseCase.kt new file mode 100644 index 00000000..26943421 --- /dev/null +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/ConnectToLocalDiffusionOnnxUseCase.kt @@ -0,0 +1,7 @@ +package com.shifthackz.aisdv1.domain.usecase.settings + +import io.reactivex.rxjava3.core.Single + +interface ConnectToLocalDiffusionOnnxUseCase { + operator fun invoke(modelId: String): Single> +} diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/ConnectToLocalDiffusionUseCaseImpl.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/ConnectToLocalDiffusionOnnxUseCaseImpl.kt similarity index 88% rename from domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/ConnectToLocalDiffusionUseCaseImpl.kt rename to domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/ConnectToLocalDiffusionOnnxUseCaseImpl.kt index d2519425..fe50db9d 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/ConnectToLocalDiffusionUseCaseImpl.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/ConnectToLocalDiffusionOnnxUseCaseImpl.kt @@ -3,10 +3,10 @@ package com.shifthackz.aisdv1.domain.usecase.settings import com.shifthackz.aisdv1.domain.entity.ServerSource import io.reactivex.rxjava3.core.Single -internal class ConnectToLocalDiffusionUseCaseImpl( +internal class ConnectToLocalDiffusionOnnxUseCaseImpl( private val getConfigurationUseCase: GetConfigurationUseCase, private val setServerConfigurationUseCase: SetServerConfigurationUseCase, -) : ConnectToLocalDiffusionUseCase { +) : ConnectToLocalDiffusionOnnxUseCase { override fun invoke(modelId: String) = getConfigurationUseCase() .map { originalConfiguration -> diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/GetConfigurationUseCaseImpl.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/GetConfigurationUseCaseImpl.kt index e88ebd63..dd5ef054 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/GetConfigurationUseCaseImpl.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/GetConfigurationUseCaseImpl.kt @@ -24,6 +24,8 @@ internal class GetConfigurationUseCaseImpl( stabilityAiApiKey = preferenceManager.stabilityAiApiKey, stabilityAiEngineId = preferenceManager.stabilityAiEngineId, authCredentials = authorizationStore.getAuthorizationCredentials(), + localCppModelId = preferenceManager.localCppModelId, + localCppModelPath = preferenceManager.localCppCustomModelPath, localOnnxModelId = preferenceManager.localOnnxModelId, localOnnxModelPath = preferenceManager.localOnnxCustomModelPath, localMediaPipeModelId = preferenceManager.localMediaPipeModelId, diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/SetServerConfigurationUseCaseImpl.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/SetServerConfigurationUseCaseImpl.kt index aab29abc..2be05e14 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/SetServerConfigurationUseCaseImpl.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/SetServerConfigurationUseCaseImpl.kt @@ -24,6 +24,8 @@ internal class SetServerConfigurationUseCaseImpl( preferenceManager.huggingFaceModel = configuration.huggingFaceModel preferenceManager.stabilityAiApiKey = configuration.stabilityAiApiKey preferenceManager.stabilityAiEngineId = configuration.stabilityAiEngineId + preferenceManager.localCppModelId = configuration.localCppModelId + preferenceManager.localCppCustomModelPath = configuration.localCppModelPath preferenceManager.localOnnxModelId = configuration.localOnnxModelId preferenceManager.localOnnxCustomModelPath = configuration.localOnnxModelPath preferenceManager.localMediaPipeModelId = configuration.localMediaPipeModelId diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/generation/InterruptGenerationUseCaseImplTest.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/generation/InterruptGenerationUseCaseImplTest.kt index 895bb8a8..0f753980 100644 --- a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/generation/InterruptGenerationUseCaseImplTest.kt +++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/generation/InterruptGenerationUseCaseImplTest.kt @@ -5,7 +5,7 @@ import com.nhaarman.mockitokotlin2.whenever import com.shifthackz.aisdv1.domain.entity.ServerSource import com.shifthackz.aisdv1.domain.preference.PreferenceManager import com.shifthackz.aisdv1.domain.repository.HordeGenerationRepository -import com.shifthackz.aisdv1.domain.repository.LocalDiffusionGenerationRepository +import com.shifthackz.aisdv1.domain.repository.LocalDiffusionOnnxGenerationRepository import com.shifthackz.aisdv1.domain.repository.StableDiffusionGenerationRepository import io.reactivex.rxjava3.core.Completable import org.junit.Test @@ -15,13 +15,13 @@ class InterruptGenerationUseCaseImplTest { private val stubException = Throwable("Can not interrupt generation.") private val stubStableDiffusionGenerationRepository = mock() private val stubHordeGenerationRepository = mock() - private val stubLocalDiffusionGenerationRepository = mock() + private val stubLocalDiffusionOnnxGenerationRepository = mock() private val stubPreferenceManager = mock() private val useCase = InterruptGenerationUseCaseImpl( stableDiffusionGenerationRepository = stubStableDiffusionGenerationRepository, hordeGenerationRepository = stubHordeGenerationRepository, - localDiffusionGenerationRepository = stubLocalDiffusionGenerationRepository, + localDiffusionOnnxGenerationRepository = stubLocalDiffusionOnnxGenerationRepository, preferenceManager = stubPreferenceManager, ) @@ -90,7 +90,7 @@ class InterruptGenerationUseCaseImplTest { whenever(stubPreferenceManager.source) .thenReturn(ServerSource.LOCAL_MICROSOFT_ONNX) - whenever(stubLocalDiffusionGenerationRepository.interruptGeneration()) + whenever(stubLocalDiffusionOnnxGenerationRepository.interruptGeneration()) .thenReturn(Completable.complete()) useCase() @@ -105,7 +105,7 @@ class InterruptGenerationUseCaseImplTest { whenever(stubPreferenceManager.source) .thenReturn(ServerSource.LOCAL_MICROSOFT_ONNX) - whenever(stubLocalDiffusionGenerationRepository.interruptGeneration()) + whenever(stubLocalDiffusionOnnxGenerationRepository.interruptGeneration()) .thenReturn(Completable.error(stubException)) useCase() diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/generation/ObserveLocalDiffusionProcessStatusUseCaseImplTest.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/generation/ObserveLocalDiffusionONNXProcessStatusUseCaseImplTest.kt similarity index 91% rename from domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/generation/ObserveLocalDiffusionProcessStatusUseCaseImplTest.kt rename to domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/generation/ObserveLocalDiffusionONNXProcessStatusUseCaseImplTest.kt index 47836c4e..c67f3b43 100644 --- a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/generation/ObserveLocalDiffusionProcessStatusUseCaseImplTest.kt +++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/generation/ObserveLocalDiffusionONNXProcessStatusUseCaseImplTest.kt @@ -3,17 +3,17 @@ package com.shifthackz.aisdv1.domain.usecase.generation import com.nhaarman.mockitokotlin2.mock import com.nhaarman.mockitokotlin2.whenever import com.shifthackz.aisdv1.domain.entity.LocalDiffusionStatus -import com.shifthackz.aisdv1.domain.repository.LocalDiffusionGenerationRepository +import com.shifthackz.aisdv1.domain.repository.LocalDiffusionOnnxGenerationRepository import io.reactivex.rxjava3.core.Observable import io.reactivex.rxjava3.subjects.BehaviorSubject import org.junit.Before import org.junit.Test -class ObserveLocalDiffusionProcessStatusUseCaseImplTest { +class ObserveLocalDiffusionONNXProcessStatusUseCaseImplTest { private val stubException = Throwable("Error loading Local Diffusion.") private val stubLocalStatus = BehaviorSubject.create() - private val stubRepository = mock() + private val stubRepository = mock() private val useCase = ObserveLocalDiffusionProcessStatusUseCaseImpl(stubRepository) diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/generation/TextToImageUseCaseImplTest.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/generation/TextToImageUseCaseImplTest.kt index 98505eac..69a2bfd6 100644 --- a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/generation/TextToImageUseCaseImplTest.kt +++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/generation/TextToImageUseCaseImplTest.kt @@ -9,7 +9,7 @@ import com.shifthackz.aisdv1.domain.mocks.mockTextToImagePayload import com.shifthackz.aisdv1.domain.preference.PreferenceManager import com.shifthackz.aisdv1.domain.repository.HordeGenerationRepository import com.shifthackz.aisdv1.domain.repository.HuggingFaceGenerationRepository -import com.shifthackz.aisdv1.domain.repository.LocalDiffusionGenerationRepository +import com.shifthackz.aisdv1.domain.repository.LocalDiffusionOnnxGenerationRepository import com.shifthackz.aisdv1.domain.repository.MediaPipeGenerationRepository import com.shifthackz.aisdv1.domain.repository.OpenAiGenerationRepository import com.shifthackz.aisdv1.domain.repository.StabilityAiGenerationRepository @@ -27,7 +27,7 @@ class TextToImageUseCaseImplTest { private val stubOpenAiGenerationRepository = mock() private val stubStabilityAiGenerationRepository = mock() private val stubSwarmUiGenerationRepository = mock() - private val stubLocalDiffusionGenerationRepository = mock() + private val stubLocalDiffusionOnnxGenerationRepository = mock() private val stubMediaPipeGenerationRepository = mock() private val stubPreferenceManager = mock() @@ -37,7 +37,7 @@ class TextToImageUseCaseImplTest { huggingFaceGenerationRepository = stubHuggingFaceGenerationRepository, openAiGenerationRepository = stubOpenAiGenerationRepository, stabilityAiGenerationRepository = stubStabilityAiGenerationRepository, - localDiffusionGenerationRepository = stubLocalDiffusionGenerationRepository, + localDiffusionOnnxGenerationRepository = stubLocalDiffusionOnnxGenerationRepository, swarmUiGenerationRepository = stubSwarmUiGenerationRepository, mediaPipeGenerationRepository = stubMediaPipeGenerationRepository, preferenceManager = stubPreferenceManager, @@ -368,7 +368,7 @@ class TextToImageUseCaseImplTest { whenever(stubPreferenceManager.source) .thenReturn(ServerSource.LOCAL_MICROSOFT_ONNX) - whenever(stubLocalDiffusionGenerationRepository.generateFromText(any())) + whenever(stubLocalDiffusionOnnxGenerationRepository.generateFromText(any())) .thenReturn(Single.just(mockAiGenerationResult)) val stubBatchCount = 1 @@ -391,7 +391,7 @@ class TextToImageUseCaseImplTest { whenever(stubPreferenceManager.source) .thenReturn(ServerSource.LOCAL_MICROSOFT_ONNX) - whenever(stubLocalDiffusionGenerationRepository.generateFromText(any())) + whenever(stubLocalDiffusionOnnxGenerationRepository.generateFromText(any())) .thenReturn(Single.just(mockAiGenerationResult)) val stubBatchCount = 10 @@ -414,7 +414,7 @@ class TextToImageUseCaseImplTest { whenever(stubPreferenceManager.source) .thenReturn(ServerSource.LOCAL_MICROSOFT_ONNX) - whenever(stubLocalDiffusionGenerationRepository.generateFromText(any())) + whenever(stubLocalDiffusionOnnxGenerationRepository.generateFromText(any())) .thenReturn(Single.error(stubException)) val stubBatchCount = 1 diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/settings/ConnectToLocalDiffusionUseCaseImplTest.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/settings/ConnectToLocalDiffusionONNXUseCaseImplTest.kt similarity index 93% rename from domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/settings/ConnectToLocalDiffusionUseCaseImplTest.kt rename to domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/settings/ConnectToLocalDiffusionONNXUseCaseImplTest.kt index 33a09c57..0785ba32 100644 --- a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/settings/ConnectToLocalDiffusionUseCaseImplTest.kt +++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/settings/ConnectToLocalDiffusionONNXUseCaseImplTest.kt @@ -7,13 +7,13 @@ import io.reactivex.rxjava3.core.Completable import io.reactivex.rxjava3.core.Single import org.junit.Test -class ConnectToLocalDiffusionUseCaseImplTest { +class ConnectToLocalDiffusionONNXUseCaseImplTest { private val stubThrowable = Throwable("Something went wrong.") private val stubGetConfigurationUseCase = mockk() private val stubSetServerConfigurationUseCase = mockk() - private val useCase = ConnectToLocalDiffusionUseCaseImpl( + private val useCase = ConnectToLocalDiffusionOnnxUseCaseImpl( getConfigurationUseCase = stubGetConfigurationUseCase, setServerConfigurationUseCase = stubSetServerConfigurationUseCase, ) diff --git a/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/di/DiffusionModule.kt b/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/di/DiffusionModule.kt deleted file mode 100644 index 4aa7ef14..00000000 --- a/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/di/DiffusionModule.kt +++ /dev/null @@ -1,19 +0,0 @@ -package com.shifthackz.aisdv1.feature.diffusion.di - -import com.shifthackz.aisdv1.domain.feature.diffusion.LocalDiffusion -import com.shifthackz.aisdv1.feature.diffusion.LocalDiffusionImpl -import com.shifthackz.aisdv1.feature.diffusion.ai.tokenizer.EnglishTextTokenizer -import com.shifthackz.aisdv1.feature.diffusion.ai.tokenizer.LocalDiffusionTextTokenizer -import com.shifthackz.aisdv1.feature.diffusion.ai.unet.UNet -import com.shifthackz.aisdv1.feature.diffusion.environment.OrtEnvironmentProvider -import com.shifthackz.aisdv1.feature.diffusion.environment.OrtEnvironmentProviderImpl -import org.koin.core.module.dsl.singleOf -import org.koin.dsl.bind -import org.koin.dsl.module - -val diffusionModule = module { - singleOf(::UNet) - singleOf(::EnglishTextTokenizer) bind LocalDiffusionTextTokenizer::class - singleOf(::LocalDiffusionImpl) bind LocalDiffusion::class - singleOf(::OrtEnvironmentProviderImpl) bind OrtEnvironmentProvider::class -} diff --git a/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/entity/LocalDiffusionAlias.kt b/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/entity/LocalDiffusionAlias.kt deleted file mode 100644 index debe29cb..00000000 --- a/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/entity/LocalDiffusionAlias.kt +++ /dev/null @@ -1,3 +0,0 @@ -package com.shifthackz.aisdv1.feature.diffusion.entity - -typealias Array3D = Array>> diff --git a/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/environment/DeviceNNAPIFlagProvider.kt b/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/environment/DeviceNNAPIFlagProvider.kt deleted file mode 100644 index 692b24b2..00000000 --- a/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/environment/DeviceNNAPIFlagProvider.kt +++ /dev/null @@ -1,5 +0,0 @@ -package com.shifthackz.aisdv1.feature.diffusion.environment - -fun interface DeviceNNAPIFlagProvider { - fun get(): Int -} diff --git a/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/environment/LocalModelIdProvider.kt b/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/environment/LocalModelIdProvider.kt deleted file mode 100644 index 3c24f58b..00000000 --- a/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/environment/LocalModelIdProvider.kt +++ /dev/null @@ -1,5 +0,0 @@ -package com.shifthackz.aisdv1.feature.diffusion.environment - -fun interface LocalModelIdProvider { - fun get(): String -} diff --git a/feature/diffusion/.gitignore b/feature/local-diffusion-cpp/.gitignore similarity index 100% rename from feature/diffusion/.gitignore rename to feature/local-diffusion-cpp/.gitignore diff --git a/feature/local-diffusion-cpp/build.gradle.kts b/feature/local-diffusion-cpp/build.gradle.kts new file mode 100644 index 00000000..cabef098 --- /dev/null +++ b/feature/local-diffusion-cpp/build.gradle.kts @@ -0,0 +1,20 @@ +plugins { + alias(libs.plugins.generic.library) +} + +android { + namespace = "com.shifthackz.aisdv1.feature.localdiffusion.cpp" + + sourceSets { + getByName("main") { + jniLibs.srcDirs("src/main/jniLibs") + } + } +} + +dependencies { + implementation(project(":core:common")) + implementation(project(":domain")) + implementation(libs.koin.core) + implementation(libs.rx.kotlin) +} diff --git a/feature/diffusion/consumer-rules.pro b/feature/local-diffusion-cpp/consumer-rules.pro similarity index 100% rename from feature/diffusion/consumer-rules.pro rename to feature/local-diffusion-cpp/consumer-rules.pro diff --git a/feature/local-diffusion-cpp/proguard-rules.pro b/feature/local-diffusion-cpp/proguard-rules.pro new file mode 100644 index 00000000..481bb434 --- /dev/null +++ b/feature/local-diffusion-cpp/proguard-rules.pro @@ -0,0 +1,21 @@ +# Add project specific ProGuard rules here. +# You can control the set of applied configuration files using the +# proguardFiles setting in build.gradle. +# +# For more details, see +# http://developer.android.com/guide/developing/tools/proguard.html + +# If your project uses WebView with JS, uncomment the following +# and specify the fully qualified class name to the JavaScript interface +# class: +#-keepclassmembers class fqcn.of.javascript.interface.for.webview { +# public *; +#} + +# Uncomment this to preserve the line number information for +# debugging stack traces. +#-keepattributes SourceFile,LineNumberTable + +# If you keep the line number information, uncomment this to +# hide the original source file name. +#-renamesourcefileattribute SourceFile \ No newline at end of file diff --git a/feature/diffusion/src/main/AndroidManifest.xml b/feature/local-diffusion-cpp/src/main/AndroidManifest.xml similarity index 100% rename from feature/diffusion/src/main/AndroidManifest.xml rename to feature/local-diffusion-cpp/src/main/AndroidManifest.xml diff --git a/feature/local-diffusion-cpp/src/main/java/com/shifthackz/aisdv1/feature/localdiffusion/cpp/LibStableDiffusion.kt b/feature/local-diffusion-cpp/src/main/java/com/shifthackz/aisdv1/feature/localdiffusion/cpp/LibStableDiffusion.kt new file mode 100644 index 00000000..95d84234 --- /dev/null +++ b/feature/local-diffusion-cpp/src/main/java/com/shifthackz/aisdv1/feature/localdiffusion/cpp/LibStableDiffusion.kt @@ -0,0 +1,75 @@ +package com.shifthackz.aisdv1.feature.localdiffusion.cpp + +import com.shifthackz.aisdv1.feature.localdiffusion.cpp.entity.Backend +import com.shifthackz.aisdv1.feature.localdiffusion.cpp.entity.SDImage +import com.shifthackz.aisdv1.feature.localdiffusion.cpp.entity.SDLogCallback +import com.shifthackz.aisdv1.feature.localdiffusion.cpp.entity.SDProgressCallback +import com.shifthackz.aisdv1.feature.localdiffusion.cpp.entity.SampleMethod + +internal class LibStableDiffusion(backend: Backend) { + + init { + System.loadLibrary(backend.libraryName) + } + + external fun getNumPhysicalCores(): Int + + external fun sdSetLogCallback(sdLogCb: SDLogCallback, data: Long) + + external fun sdSetProgressCallback(sdProgressCb: SDProgressCallback, data: Long) + + external fun newSdContext( + modelPath: String, + clipLPath: String, + clipGPath: String, + t5xxlPath: String, + diffusionModelPath: String, + vaePath: String, + taesdPath: String, + controlNetPathCStr: String, + loraModelDir: String, + embedDirCStr: String, + stackedIdEmbedDirCStr: String, + vaeDecodeOnly: Boolean, + vaeTiling: Boolean, + freeParamsImmediately: Boolean, + nThreads: Int, + wtype: Int, + rngType: Int, + scheduleType: Int, + keepClipOnCpu: Boolean, + keepControlNetCpu: Boolean, + keepVaeOnCpu: Boolean, + diffusionFlashAttn: Boolean + ): Long + + external fun txt2img( + sdCtx: Long, + prompt: String, + negativePrompt: String, + clipSkip: Int, + cfgScale: Float, + guidance: Float, + eta: Float, + width: Int, + height: Int, + sampleMethod: SampleMethod, + sampleSteps: Int, + seed: Long, + batchCount: Int, + controlCond: SDImage?, + controlStrength: Float, + styleStrength: Float, + normalizeInput: Boolean, + inputIdImagesPath: String?, + skipLayers: IntArray?, + skipLayersCount: Long, + slgScale: Float, + skipLayerStart: Float, + skipLayerEnd: Float + ): SDImage + + external fun freeSdCtx(ctx: Long) + + external fun freeUpscalerCtx(ctx: Long) +} diff --git a/feature/local-diffusion-cpp/src/main/java/com/shifthackz/aisdv1/feature/localdiffusion/cpp/LocalDiffusionCppImpl.kt b/feature/local-diffusion-cpp/src/main/java/com/shifthackz/aisdv1/feature/localdiffusion/cpp/LocalDiffusionCppImpl.kt new file mode 100644 index 00000000..44ba5b7b --- /dev/null +++ b/feature/local-diffusion-cpp/src/main/java/com/shifthackz/aisdv1/feature/localdiffusion/cpp/LocalDiffusionCppImpl.kt @@ -0,0 +1,104 @@ +package com.shifthackz.aisdv1.feature.localdiffusion.cpp + +import android.graphics.Bitmap +import android.graphics.BitmapFactory +import com.shifthackz.aisdv1.core.common.file.FileProviderDescriptor +import com.shifthackz.aisdv1.core.common.schedulers.SchedulersProvider +import com.shifthackz.aisdv1.domain.entity.TextToImagePayload +import com.shifthackz.aisdv1.domain.feature.diffusion.LocalDiffusionCpp +import com.shifthackz.aisdv1.domain.preference.PreferenceManager +import com.shifthackz.aisdv1.feature.localdiffusion.cpp.entity.Backend +import com.shifthackz.aisdv1.feature.localdiffusion.cpp.entity.SDImage +import com.shifthackz.aisdv1.feature.localdiffusion.cpp.entity.SDType +import com.shifthackz.aisdv1.feature.localdiffusion.cpp.entity.SampleMethod +import com.shifthackz.aisdv1.feature.localdiffusion.cpp.environment.LocalCppModelIdProvider +import com.shifthackz.aisdv1.feature.localdiffusion.cpp.extensions.modelPathPrefix +import io.reactivex.rxjava3.core.Single + +internal class LocalDiffusionCppImpl( + private val preferenceManager: PreferenceManager, + private val fileProviderDescriptor: FileProviderDescriptor, + private val localCppModelIdProvider: LocalCppModelIdProvider, + private val schedulersProvider: SchedulersProvider, +) : LocalDiffusionCpp { + + override fun process(payload: TextToImagePayload): Single = Single.fromCallable { + val libStableDiffusion = LibStableDiffusion(Backend.OpenCL) + + val ctx = libStableDiffusion.newSdContext( + modelPath = modelPathPrefix(preferenceManager, fileProviderDescriptor, localCppModelIdProvider) + "model.safetensors", + clipLPath = "", + clipGPath = "", + t5xxlPath = "", + diffusionModelPath = "", + vaePath = "", + taesdPath = "", + controlNetPathCStr = "", + loraModelDir = "", + embedDirCStr = "", + stackedIdEmbedDirCStr = "", + vaeDecodeOnly = false, + vaeTiling = false, + freeParamsImmediately = false, + nThreads = libStableDiffusion.getNumPhysicalCores() * 2, + wtype = SDType.NONE.ordinal, + rngType = 0, + scheduleType = 0, + keepClipOnCpu = false, + keepControlNetCpu = false, + keepVaeOnCpu = false, + diffusionFlashAttn = false + ) + + val result = libStableDiffusion.txt2img( + sdCtx = ctx, + prompt = payload.prompt, + negativePrompt = payload.negativePrompt, + clipSkip = 2, + cfgScale = payload.cfgScale, + guidance = 0f, + eta = 0f, + width = payload.width, + height = payload.height, + sampleMethod = SampleMethod.EULER, + sampleSteps = payload.samplingSteps, + seed = payload.seed.toLongOrNull() ?: 0L, + batchCount = 1, + controlCond = null, + controlStrength = 1.0f, + styleStrength = 0f, + normalizeInput = false, + inputIdImagesPath = null, + skipLayers = null, + skipLayersCount = 0, + slgScale = 0f, + skipLayerStart = 0f, + skipLayerEnd = 0f, + ) + + val bitmap = convertSDImageToBitmap(result) + + bitmap + } + + private fun convertSDImageToBitmap(sdImage: SDImage): Bitmap { + if (sdImage.channel != 3) { + throw IllegalArgumentException("Unexpected channel count: ${sdImage.channel}") + } + + val width = sdImage.width + val height = sdImage.height + val bytes = sdImage.data + + val rgbaBytes = ByteArray(width * height * 4) + + for (i in 0 until width * height) { + rgbaBytes[i * 4] = bytes[i * 3] + rgbaBytes[i * 4 + 1] = bytes[i * 3 + 1] + rgbaBytes[i * 4 + 2] = bytes[i * 3 + 2] + rgbaBytes[i * 4 + 3] = 255.toByte() + } + + return BitmapFactory.decodeByteArray(rgbaBytes, 0, rgbaBytes.size) + } +} diff --git a/feature/local-diffusion-cpp/src/main/java/com/shifthackz/aisdv1/feature/localdiffusion/cpp/di/CppLocalDiffusionModule.kt b/feature/local-diffusion-cpp/src/main/java/com/shifthackz/aisdv1/feature/localdiffusion/cpp/di/CppLocalDiffusionModule.kt new file mode 100644 index 00000000..ff35b4b1 --- /dev/null +++ b/feature/local-diffusion-cpp/src/main/java/com/shifthackz/aisdv1/feature/localdiffusion/cpp/di/CppLocalDiffusionModule.kt @@ -0,0 +1,11 @@ +package com.shifthackz.aisdv1.feature.localdiffusion.cpp.di + +import com.shifthackz.aisdv1.domain.feature.diffusion.LocalDiffusionCpp +import com.shifthackz.aisdv1.feature.localdiffusion.cpp.LocalDiffusionCppImpl +import org.koin.core.module.dsl.factoryOf +import org.koin.dsl.bind +import org.koin.dsl.module + +val cppLocalDiffusionModule = module { + factoryOf(::LocalDiffusionCppImpl) bind LocalDiffusionCpp::class +} diff --git a/feature/local-diffusion-cpp/src/main/java/com/shifthackz/aisdv1/feature/localdiffusion/cpp/entity/Backend.kt b/feature/local-diffusion-cpp/src/main/java/com/shifthackz/aisdv1/feature/localdiffusion/cpp/entity/Backend.kt new file mode 100644 index 00000000..95573307 --- /dev/null +++ b/feature/local-diffusion-cpp/src/main/java/com/shifthackz/aisdv1/feature/localdiffusion/cpp/entity/Backend.kt @@ -0,0 +1,7 @@ +package com.shifthackz.aisdv1.feature.localdiffusion.cpp.entity + +enum class Backend(val libraryName: String) { + Vulkan("stable-diffusion_vulkan"), + OpenCL("stable-diffusion_opencl"), + Cpu("stable-diffusion"); +} diff --git a/feature/local-diffusion-cpp/src/main/java/com/shifthackz/aisdv1/feature/localdiffusion/cpp/entity/SDImage.kt b/feature/local-diffusion-cpp/src/main/java/com/shifthackz/aisdv1/feature/localdiffusion/cpp/entity/SDImage.kt new file mode 100644 index 00000000..935bfdd4 --- /dev/null +++ b/feature/local-diffusion-cpp/src/main/java/com/shifthackz/aisdv1/feature/localdiffusion/cpp/entity/SDImage.kt @@ -0,0 +1,30 @@ +package com.shifthackz.aisdv1.feature.localdiffusion.cpp.entity + +data class SDImage( + val width: Int, + val height: Int, + val channel: Int, + val data: ByteArray, +) { + override fun equals(other: Any?): Boolean { + if (this === other) return true + if (javaClass != other?.javaClass) return false + + other as SDImage + + if (width != other.width) return false + if (height != other.height) return false + if (channel != other.channel) return false + if (!data.contentEquals(other.data)) return false + + return true + } + + override fun hashCode(): Int { + var result = width + result = 31 * result + height + result = 31 * result + channel + result = 31 * result + data.contentHashCode() + return result + } +} diff --git a/feature/local-diffusion-cpp/src/main/java/com/shifthackz/aisdv1/feature/localdiffusion/cpp/entity/SDLogCallback.kt b/feature/local-diffusion-cpp/src/main/java/com/shifthackz/aisdv1/feature/localdiffusion/cpp/entity/SDLogCallback.kt new file mode 100644 index 00000000..7367cc75 --- /dev/null +++ b/feature/local-diffusion-cpp/src/main/java/com/shifthackz/aisdv1/feature/localdiffusion/cpp/entity/SDLogCallback.kt @@ -0,0 +1,3 @@ +package com.shifthackz.aisdv1.feature.localdiffusion.cpp.entity + +typealias SDLogCallback = (level: Int, text: String, data: Long) -> Unit diff --git a/feature/local-diffusion-cpp/src/main/java/com/shifthackz/aisdv1/feature/localdiffusion/cpp/entity/SDProgressCallback.kt b/feature/local-diffusion-cpp/src/main/java/com/shifthackz/aisdv1/feature/localdiffusion/cpp/entity/SDProgressCallback.kt new file mode 100644 index 00000000..9d7e7312 --- /dev/null +++ b/feature/local-diffusion-cpp/src/main/java/com/shifthackz/aisdv1/feature/localdiffusion/cpp/entity/SDProgressCallback.kt @@ -0,0 +1,3 @@ +package com.shifthackz.aisdv1.feature.localdiffusion.cpp.entity + +typealias SDProgressCallback = (step: Int, steps: Int, time: Float, data: Long) -> Unit diff --git a/feature/local-diffusion-cpp/src/main/java/com/shifthackz/aisdv1/feature/localdiffusion/cpp/entity/SDType.kt b/feature/local-diffusion-cpp/src/main/java/com/shifthackz/aisdv1/feature/localdiffusion/cpp/entity/SDType.kt new file mode 100644 index 00000000..1c033e16 --- /dev/null +++ b/feature/local-diffusion-cpp/src/main/java/com/shifthackz/aisdv1/feature/localdiffusion/cpp/entity/SDType.kt @@ -0,0 +1,17 @@ +package com.shifthackz.aisdv1.feature.localdiffusion.cpp.entity + +internal enum class SDType { + NONE, + SD_TYPE_Q8_0, + SD_TYPE_Q8_1, + SD_TYPE_Q8_K, + SD_TYPE_Q6_K, + SD_TYPE_Q5_0, + SD_TYPE_Q5_1, + SD_TYPE_Q5_K, + SD_TYPE_Q4_0, + SD_TYPE_Q4_1, + SD_TYPE_Q4_K, + SD_TYPE_Q3_K, + SD_TYPE_Q2_K, +} diff --git a/feature/local-diffusion-cpp/src/main/java/com/shifthackz/aisdv1/feature/localdiffusion/cpp/entity/SampleMethod.kt b/feature/local-diffusion-cpp/src/main/java/com/shifthackz/aisdv1/feature/localdiffusion/cpp/entity/SampleMethod.kt new file mode 100644 index 00000000..b626656e --- /dev/null +++ b/feature/local-diffusion-cpp/src/main/java/com/shifthackz/aisdv1/feature/localdiffusion/cpp/entity/SampleMethod.kt @@ -0,0 +1,24 @@ +package com.shifthackz.aisdv1.feature.localdiffusion.cpp.entity + +enum class SampleMethod(val value: Int) { + EULER_A(0), + EULER(1), + HEUN(2), + DPM2(3), + DPMPP2S_A(4), + DPMPP2M(5), + DPMPP2Mv2(6), + IPNDM(7), + IPNDM_V(8), + LCM(9), + DDIM_TRAILING(10), + TCD(11), + N_SAMPLE_METHODS(12); + + companion object { + fun fromInt(value: Int): SampleMethod { + return entries.firstOrNull { it.value == value } + ?: throw IllegalArgumentException("Unknown SampleMethod value: $value") + } + } +} diff --git a/feature/local-diffusion-cpp/src/main/java/com/shifthackz/aisdv1/feature/localdiffusion/cpp/environment/LocalCppModelIdProvider.kt b/feature/local-diffusion-cpp/src/main/java/com/shifthackz/aisdv1/feature/localdiffusion/cpp/environment/LocalCppModelIdProvider.kt new file mode 100644 index 00000000..9d14c52f --- /dev/null +++ b/feature/local-diffusion-cpp/src/main/java/com/shifthackz/aisdv1/feature/localdiffusion/cpp/environment/LocalCppModelIdProvider.kt @@ -0,0 +1,5 @@ +package com.shifthackz.aisdv1.feature.localdiffusion.cpp.environment + +fun interface LocalCppModelIdProvider { + fun get(): String +} diff --git a/feature/local-diffusion-cpp/src/main/java/com/shifthackz/aisdv1/feature/localdiffusion/cpp/extensions/LocalDiffusionCppPaths.kt b/feature/local-diffusion-cpp/src/main/java/com/shifthackz/aisdv1/feature/localdiffusion/cpp/extensions/LocalDiffusionCppPaths.kt new file mode 100644 index 00000000..034e75b7 --- /dev/null +++ b/feature/local-diffusion-cpp/src/main/java/com/shifthackz/aisdv1/feature/localdiffusion/cpp/extensions/LocalDiffusionCppPaths.kt @@ -0,0 +1,19 @@ +package com.shifthackz.aisdv1.feature.localdiffusion.cpp.extensions + +import com.shifthackz.aisdv1.core.common.file.FileProviderDescriptor +import com.shifthackz.aisdv1.domain.entity.LocalAiModel +import com.shifthackz.aisdv1.domain.preference.PreferenceManager +import com.shifthackz.aisdv1.feature.localdiffusion.cpp.environment.LocalCppModelIdProvider + +internal fun modelPathPrefix( + preferenceManager: PreferenceManager, + fileProviderDescriptor: FileProviderDescriptor, + localCppModelIdProvider: LocalCppModelIdProvider, +): String { + val modelId = localCppModelIdProvider.get() + return if (modelId == LocalAiModel.CustomOnnx.id) { + preferenceManager.localOnnxCustomModelPath + } else { + "${fileProviderDescriptor.localModelDirPath}/${modelId}" + } +} diff --git a/feature/local-diffusion-cpp/src/main/jniLibs/arm64-v8a/libOpenCL.so b/feature/local-diffusion-cpp/src/main/jniLibs/arm64-v8a/libOpenCL.so new file mode 100644 index 00000000..7ab23be3 Binary files /dev/null and b/feature/local-diffusion-cpp/src/main/jniLibs/arm64-v8a/libOpenCL.so differ diff --git a/feature/local-diffusion-cpp/src/main/jniLibs/arm64-v8a/libc.so b/feature/local-diffusion-cpp/src/main/jniLibs/arm64-v8a/libc.so new file mode 100644 index 00000000..d7899aa6 Binary files /dev/null and b/feature/local-diffusion-cpp/src/main/jniLibs/arm64-v8a/libc.so differ diff --git a/feature/local-diffusion-cpp/src/main/jniLibs/arm64-v8a/libdl.so b/feature/local-diffusion-cpp/src/main/jniLibs/arm64-v8a/libdl.so new file mode 100644 index 00000000..cfd27309 Binary files /dev/null and b/feature/local-diffusion-cpp/src/main/jniLibs/arm64-v8a/libdl.so differ diff --git a/feature/local-diffusion-cpp/src/main/jniLibs/arm64-v8a/libm.so b/feature/local-diffusion-cpp/src/main/jniLibs/arm64-v8a/libm.so new file mode 100644 index 00000000..920c65f6 Binary files /dev/null and b/feature/local-diffusion-cpp/src/main/jniLibs/arm64-v8a/libm.so differ diff --git a/feature/local-diffusion-cpp/src/main/jniLibs/arm64-v8a/libomp.so b/feature/local-diffusion-cpp/src/main/jniLibs/arm64-v8a/libomp.so new file mode 100644 index 00000000..85dd2801 Binary files /dev/null and b/feature/local-diffusion-cpp/src/main/jniLibs/arm64-v8a/libomp.so differ diff --git a/feature/local-diffusion-cpp/src/main/jniLibs/arm64-v8a/libstable-diffusion.so b/feature/local-diffusion-cpp/src/main/jniLibs/arm64-v8a/libstable-diffusion.so new file mode 100644 index 00000000..6e214e4a Binary files /dev/null and b/feature/local-diffusion-cpp/src/main/jniLibs/arm64-v8a/libstable-diffusion.so differ diff --git a/feature/local-diffusion-cpp/src/main/jniLibs/arm64-v8a/libstable-diffusion_opencl.so b/feature/local-diffusion-cpp/src/main/jniLibs/arm64-v8a/libstable-diffusion_opencl.so new file mode 100644 index 00000000..8a1a7f0a Binary files /dev/null and b/feature/local-diffusion-cpp/src/main/jniLibs/arm64-v8a/libstable-diffusion_opencl.so differ diff --git a/feature/local-diffusion-cpp/src/main/jniLibs/arm64-v8a/libstable-diffusion_vulkan.so b/feature/local-diffusion-cpp/src/main/jniLibs/arm64-v8a/libstable-diffusion_vulkan.so new file mode 100644 index 00000000..80bcfbfb Binary files /dev/null and b/feature/local-diffusion-cpp/src/main/jniLibs/arm64-v8a/libstable-diffusion_vulkan.so differ diff --git a/feature/local-diffusion-onnx/.gitignore b/feature/local-diffusion-onnx/.gitignore new file mode 100644 index 00000000..42afabfd --- /dev/null +++ b/feature/local-diffusion-onnx/.gitignore @@ -0,0 +1 @@ +/build \ No newline at end of file diff --git a/feature/diffusion/build.gradle.kts b/feature/local-diffusion-onnx/build.gradle.kts similarity index 81% rename from feature/diffusion/build.gradle.kts rename to feature/local-diffusion-onnx/build.gradle.kts index 03c679e4..2cf6a5c0 100644 --- a/feature/diffusion/build.gradle.kts +++ b/feature/local-diffusion-onnx/build.gradle.kts @@ -3,7 +3,7 @@ plugins { } android { - namespace = "com.shifthackz.aisdv1.feature.diffusion" + namespace = "com.shifthackz.aisdv1.feature.localdiffusion.onnx" } dependencies { diff --git a/feature/local-diffusion-onnx/consumer-rules.pro b/feature/local-diffusion-onnx/consumer-rules.pro new file mode 100644 index 00000000..e69de29b diff --git a/feature/diffusion/proguard-rules.pro b/feature/local-diffusion-onnx/proguard-rules.pro similarity index 100% rename from feature/diffusion/proguard-rules.pro rename to feature/local-diffusion-onnx/proguard-rules.pro diff --git a/feature/local-diffusion-onnx/src/main/AndroidManifest.xml b/feature/local-diffusion-onnx/src/main/AndroidManifest.xml new file mode 100644 index 00000000..8072ee00 --- /dev/null +++ b/feature/local-diffusion-onnx/src/main/AndroidManifest.xml @@ -0,0 +1,2 @@ + + diff --git a/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/LocalDiffusionContract.kt b/feature/local-diffusion-onnx/src/main/java/com/shifthackz/aisdv1/feature/localdiffusion/onnx/LocalDiffusionContract.kt similarity index 96% rename from feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/LocalDiffusionContract.kt rename to feature/local-diffusion-onnx/src/main/java/com/shifthackz/aisdv1/feature/localdiffusion/onnx/LocalDiffusionContract.kt index a59a7afc..e912f715 100644 --- a/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/LocalDiffusionContract.kt +++ b/feature/local-diffusion-onnx/src/main/java/com/shifthackz/aisdv1/feature/localdiffusion/onnx/LocalDiffusionContract.kt @@ -1,6 +1,6 @@ @file:Suppress("SpellCheckingInspection") -package com.shifthackz.aisdv1.feature.diffusion +package com.shifthackz.aisdv1.feature.localdiffusion.onnx internal object LocalDiffusionContract { //region LOGGING diff --git a/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/LocalDiffusionImpl.kt b/feature/local-diffusion-onnx/src/main/java/com/shifthackz/aisdv1/feature/localdiffusion/onnx/LocalDiffusionONNXImpl.kt similarity index 89% rename from feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/LocalDiffusionImpl.kt rename to feature/local-diffusion-onnx/src/main/java/com/shifthackz/aisdv1/feature/localdiffusion/onnx/LocalDiffusionONNXImpl.kt index 775cad2c..7edebf26 100644 --- a/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/LocalDiffusionImpl.kt +++ b/feature/local-diffusion-onnx/src/main/java/com/shifthackz/aisdv1/feature/localdiffusion/onnx/LocalDiffusionONNXImpl.kt @@ -1,4 +1,4 @@ -package com.shifthackz.aisdv1.feature.diffusion +package com.shifthackz.aisdv1.feature.localdiffusion.onnx import ai.onnxruntime.OnnxTensor import android.graphics.Bitmap @@ -6,20 +6,20 @@ import com.shifthackz.aisdv1.core.common.log.debugLog import com.shifthackz.aisdv1.core.common.log.errorLog import com.shifthackz.aisdv1.domain.entity.LocalDiffusionStatus import com.shifthackz.aisdv1.domain.entity.TextToImagePayload -import com.shifthackz.aisdv1.domain.feature.diffusion.LocalDiffusion -import com.shifthackz.aisdv1.feature.diffusion.LocalDiffusionContract.TAG -import com.shifthackz.aisdv1.feature.diffusion.ai.tokenizer.LocalDiffusionTextTokenizer -import com.shifthackz.aisdv1.feature.diffusion.ai.unet.UNet -import com.shifthackz.aisdv1.feature.diffusion.environment.OrtEnvironmentProvider +import com.shifthackz.aisdv1.domain.feature.diffusion.LocalDiffusionONNX +import com.shifthackz.aisdv1.feature.localdiffusion.onnx.LocalDiffusionContract.TAG +import com.shifthackz.aisdv1.feature.localdiffusion.onnx.ai.tokenizer.LocalDiffusionTextTokenizer +import com.shifthackz.aisdv1.feature.localdiffusion.onnx.ai.unet.UNet +import com.shifthackz.aisdv1.feature.localdiffusion.onnx.environment.OrtEnvironmentProvider import io.reactivex.rxjava3.core.Completable import io.reactivex.rxjava3.core.Single import io.reactivex.rxjava3.subjects.PublishSubject -internal class LocalDiffusionImpl( +internal class LocalDiffusionONNXImpl( private val uNet: UNet, private val tokenizer: LocalDiffusionTextTokenizer, private val ortEnvironmentProvider: OrtEnvironmentProvider, -) : LocalDiffusion { +) : LocalDiffusionONNX { private val statusSubject: PublishSubject = PublishSubject.create() diff --git a/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/ai/extensions/ArrayExtensions.kt b/feature/local-diffusion-onnx/src/main/java/com/shifthackz/aisdv1/feature/localdiffusion/onnx/ai/extensions/ArrayExtensions.kt similarity index 92% rename from feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/ai/extensions/ArrayExtensions.kt rename to feature/local-diffusion-onnx/src/main/java/com/shifthackz/aisdv1/feature/localdiffusion/onnx/ai/extensions/ArrayExtensions.kt index 5f844ee7..7051274a 100644 --- a/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/ai/extensions/ArrayExtensions.kt +++ b/feature/local-diffusion-onnx/src/main/java/com/shifthackz/aisdv1/feature/localdiffusion/onnx/ai/extensions/ArrayExtensions.kt @@ -1,6 +1,6 @@ -package com.shifthackz.aisdv1.feature.diffusion.ai.extensions +package com.shifthackz.aisdv1.feature.localdiffusion.onnx.ai.extensions -import com.shifthackz.aisdv1.feature.diffusion.entity.Array3D +import com.shifthackz.aisdv1.feature.localdiffusion.onnx.entity.Array3D import kotlin.math.ceil import java.util.function.Function diff --git a/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/ai/extensions/StringExtensions.kt b/feature/local-diffusion-onnx/src/main/java/com/shifthackz/aisdv1/feature/localdiffusion/onnx/ai/extensions/StringExtensions.kt similarity index 91% rename from feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/ai/extensions/StringExtensions.kt rename to feature/local-diffusion-onnx/src/main/java/com/shifthackz/aisdv1/feature/localdiffusion/onnx/ai/extensions/StringExtensions.kt index 89519bc9..a767debd 100644 --- a/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/ai/extensions/StringExtensions.kt +++ b/feature/local-diffusion-onnx/src/main/java/com/shifthackz/aisdv1/feature/localdiffusion/onnx/ai/extensions/StringExtensions.kt @@ -1,4 +1,4 @@ -package com.shifthackz.aisdv1.feature.diffusion.ai.extensions +package com.shifthackz.aisdv1.feature.localdiffusion.onnx.ai.extensions internal fun String.halfCorner(): String { var output = this diff --git a/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/ai/extensions/TensorExtensions.kt b/feature/local-diffusion-onnx/src/main/java/com/shifthackz/aisdv1/feature/localdiffusion/onnx/ai/extensions/TensorExtensions.kt similarity index 88% rename from feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/ai/extensions/TensorExtensions.kt rename to feature/local-diffusion-onnx/src/main/java/com/shifthackz/aisdv1/feature/localdiffusion/onnx/ai/extensions/TensorExtensions.kt index 97ba816b..ce013298 100644 --- a/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/ai/extensions/TensorExtensions.kt +++ b/feature/local-diffusion-onnx/src/main/java/com/shifthackz/aisdv1/feature/localdiffusion/onnx/ai/extensions/TensorExtensions.kt @@ -1,12 +1,12 @@ @file:Suppress("KotlinConstantConditions") -package com.shifthackz.aisdv1.feature.diffusion.ai.extensions +package com.shifthackz.aisdv1.feature.localdiffusion.onnx.ai.extensions import ai.onnxruntime.OnnxTensor import android.util.Pair -import com.shifthackz.aisdv1.feature.diffusion.entity.Array3D -import com.shifthackz.aisdv1.feature.diffusion.entity.LocalDiffusionTensor -import com.shifthackz.aisdv1.feature.diffusion.environment.OrtEnvironmentProvider +import com.shifthackz.aisdv1.feature.localdiffusion.onnx.entity.Array3D +import com.shifthackz.aisdv1.feature.localdiffusion.onnx.entity.LocalDiffusionTensor +import com.shifthackz.aisdv1.feature.localdiffusion.onnx.environment.OrtEnvironmentProvider import org.koin.java.KoinJavaComponent.inject import java.nio.FloatBuffer diff --git a/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/ai/scheduler/EulerAncestralDiscreteLocalDiffusionScheduler.kt b/feature/local-diffusion-onnx/src/main/java/com/shifthackz/aisdv1/feature/localdiffusion/onnx/ai/scheduler/EulerAncestralDiscreteLocalDiffusionScheduler.kt similarity index 89% rename from feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/ai/scheduler/EulerAncestralDiscreteLocalDiffusionScheduler.kt rename to feature/local-diffusion-onnx/src/main/java/com/shifthackz/aisdv1/feature/localdiffusion/onnx/ai/scheduler/EulerAncestralDiscreteLocalDiffusionScheduler.kt index 68722b00..877ecb58 100644 --- a/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/ai/scheduler/EulerAncestralDiscreteLocalDiffusionScheduler.kt +++ b/feature/local-diffusion-onnx/src/main/java/com/shifthackz/aisdv1/feature/localdiffusion/onnx/ai/scheduler/EulerAncestralDiscreteLocalDiffusionScheduler.kt @@ -1,20 +1,20 @@ @file:Suppress("UNCHECKED_CAST") -package com.shifthackz.aisdv1.feature.diffusion.ai.scheduler +package com.shifthackz.aisdv1.feature.localdiffusion.onnx.ai.scheduler import ai.onnxruntime.OnnxTensor -import com.shifthackz.aisdv1.feature.diffusion.LocalDiffusionContract.BETA_SCHEDULER_LINEAR -import com.shifthackz.aisdv1.feature.diffusion.LocalDiffusionContract.BETA_SCHEDULER_SCALED_LINEAR -import com.shifthackz.aisdv1.feature.diffusion.LocalDiffusionContract.BETA_SCHEDULER_SQUARED_v2 -import com.shifthackz.aisdv1.feature.diffusion.LocalDiffusionContract.PREDICTION_EPSILON -import com.shifthackz.aisdv1.feature.diffusion.LocalDiffusionContract.PREDICTION_V -import com.shifthackz.aisdv1.feature.diffusion.ai.extensions.arrange -import com.shifthackz.aisdv1.feature.diffusion.ai.extensions.interpolate -import com.shifthackz.aisdv1.feature.diffusion.ai.extensions.lineSpace -import com.shifthackz.aisdv1.feature.diffusion.entity.Array3D -import com.shifthackz.aisdv1.feature.diffusion.entity.LocalDiffusionTensor -import com.shifthackz.aisdv1.feature.diffusion.environment.OrtEnvironmentProvider -import com.shifthackz.aisdv1.feature.diffusion.entity.LocalDiffusionConfig +import com.shifthackz.aisdv1.feature.localdiffusion.onnx.LocalDiffusionContract.BETA_SCHEDULER_LINEAR +import com.shifthackz.aisdv1.feature.localdiffusion.onnx.LocalDiffusionContract.BETA_SCHEDULER_SCALED_LINEAR +import com.shifthackz.aisdv1.feature.localdiffusion.onnx.LocalDiffusionContract.BETA_SCHEDULER_SQUARED_v2 +import com.shifthackz.aisdv1.feature.localdiffusion.onnx.LocalDiffusionContract.PREDICTION_EPSILON +import com.shifthackz.aisdv1.feature.localdiffusion.onnx.LocalDiffusionContract.PREDICTION_V +import com.shifthackz.aisdv1.feature.localdiffusion.onnx.ai.extensions.arrange +import com.shifthackz.aisdv1.feature.localdiffusion.onnx.ai.extensions.interpolate +import com.shifthackz.aisdv1.feature.localdiffusion.onnx.ai.extensions.lineSpace +import com.shifthackz.aisdv1.feature.localdiffusion.onnx.entity.Array3D +import com.shifthackz.aisdv1.feature.localdiffusion.onnx.entity.LocalDiffusionTensor +import com.shifthackz.aisdv1.feature.localdiffusion.onnx.environment.OrtEnvironmentProvider +import com.shifthackz.aisdv1.feature.localdiffusion.onnx.entity.LocalDiffusionConfig import org.koin.core.component.KoinComponent import org.koin.core.component.inject import java.nio.FloatBuffer diff --git a/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/ai/scheduler/LocalDiffusionScheduler.kt b/feature/local-diffusion-onnx/src/main/java/com/shifthackz/aisdv1/feature/localdiffusion/onnx/ai/scheduler/LocalDiffusionScheduler.kt similarity index 69% rename from feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/ai/scheduler/LocalDiffusionScheduler.kt rename to feature/local-diffusion-onnx/src/main/java/com/shifthackz/aisdv1/feature/localdiffusion/onnx/ai/scheduler/LocalDiffusionScheduler.kt index 3a7f4f3b..a7bf08bc 100644 --- a/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/ai/scheduler/LocalDiffusionScheduler.kt +++ b/feature/local-diffusion-onnx/src/main/java/com/shifthackz/aisdv1/feature/localdiffusion/onnx/ai/scheduler/LocalDiffusionScheduler.kt @@ -1,6 +1,6 @@ -package com.shifthackz.aisdv1.feature.diffusion.ai.scheduler +package com.shifthackz.aisdv1.feature.localdiffusion.onnx.ai.scheduler -import com.shifthackz.aisdv1.feature.diffusion.entity.LocalDiffusionTensor +import com.shifthackz.aisdv1.feature.localdiffusion.onnx.entity.LocalDiffusionTensor internal interface LocalDiffusionScheduler { val initNoiseSigma: Double diff --git a/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/ai/tokenizer/EnglishTextTokenizer.kt b/feature/local-diffusion-onnx/src/main/java/com/shifthackz/aisdv1/feature/localdiffusion/onnx/ai/tokenizer/EnglishTextTokenizer.kt similarity index 88% rename from feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/ai/tokenizer/EnglishTextTokenizer.kt rename to feature/local-diffusion-onnx/src/main/java/com/shifthackz/aisdv1/feature/localdiffusion/onnx/ai/tokenizer/EnglishTextTokenizer.kt index 0683aeff..748d209e 100644 --- a/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/ai/tokenizer/EnglishTextTokenizer.kt +++ b/feature/local-diffusion-onnx/src/main/java/com/shifthackz/aisdv1/feature/localdiffusion/onnx/ai/tokenizer/EnglishTextTokenizer.kt @@ -1,6 +1,6 @@ @file:Suppress("KotlinConstantConditions") -package com.shifthackz.aisdv1.feature.diffusion.ai.tokenizer +package com.shifthackz.aisdv1.feature.localdiffusion.onnx.ai.tokenizer import ai.onnxruntime.OnnxTensor import ai.onnxruntime.OrtSession @@ -11,16 +11,16 @@ import com.shifthackz.aisdv1.core.common.file.FileProviderDescriptor import com.shifthackz.aisdv1.core.common.log.debugLog import com.shifthackz.aisdv1.core.common.log.errorLog import com.shifthackz.aisdv1.domain.preference.PreferenceManager -import com.shifthackz.aisdv1.feature.diffusion.LocalDiffusionContract -import com.shifthackz.aisdv1.feature.diffusion.LocalDiffusionContract.KEY_INPUT_IDS -import com.shifthackz.aisdv1.feature.diffusion.LocalDiffusionContract.ORT -import com.shifthackz.aisdv1.feature.diffusion.LocalDiffusionContract.ORT_KEY_MODEL_FORMAT -import com.shifthackz.aisdv1.feature.diffusion.LocalDiffusionContract.TAG -import com.shifthackz.aisdv1.feature.diffusion.ai.extensions.halfCorner -import com.shifthackz.aisdv1.feature.diffusion.ai.extensions.toArrays -import com.shifthackz.aisdv1.feature.diffusion.environment.LocalModelIdProvider -import com.shifthackz.aisdv1.feature.diffusion.environment.OrtEnvironmentProvider -import com.shifthackz.aisdv1.feature.diffusion.extensions.modelPathPrefix +import com.shifthackz.aisdv1.feature.localdiffusion.onnx.LocalDiffusionContract +import com.shifthackz.aisdv1.feature.localdiffusion.onnx.LocalDiffusionContract.KEY_INPUT_IDS +import com.shifthackz.aisdv1.feature.localdiffusion.onnx.LocalDiffusionContract.ORT +import com.shifthackz.aisdv1.feature.localdiffusion.onnx.LocalDiffusionContract.ORT_KEY_MODEL_FORMAT +import com.shifthackz.aisdv1.feature.localdiffusion.onnx.LocalDiffusionContract.TAG +import com.shifthackz.aisdv1.feature.localdiffusion.onnx.ai.extensions.halfCorner +import com.shifthackz.aisdv1.feature.localdiffusion.onnx.ai.extensions.toArrays +import com.shifthackz.aisdv1.feature.localdiffusion.onnx.environment.LocalOnnxModelIdProvider +import com.shifthackz.aisdv1.feature.localdiffusion.onnx.environment.OrtEnvironmentProvider +import com.shifthackz.aisdv1.feature.localdiffusion.onnx.extensions.modelPathPrefix import java.io.BufferedReader import java.io.FileInputStream import java.io.InputStreamReader @@ -32,7 +32,7 @@ import java.util.regex.Pattern internal class EnglishTextTokenizer( private val ortEnvironmentProvider: OrtEnvironmentProvider, private val fileProviderDescriptor: FileProviderDescriptor, - private val localModelIdProvider: LocalModelIdProvider, + private val localOnnxModelIdProvider: LocalOnnxModelIdProvider, private val preferenceManager: PreferenceManager, ) : LocalDiffusionTextTokenizer { @@ -55,7 +55,7 @@ internal class EnglishTextTokenizer( val options = OrtSession.SessionOptions() options.addConfigEntry(ORT_KEY_MODEL_FORMAT, ORT) session = ortEnvironmentProvider.get().createSession( - "${modelPathPrefix(preferenceManager, fileProviderDescriptor, localModelIdProvider)}/${LocalDiffusionContract.TOKENIZER_MODEL}", + "${modelPathPrefix(preferenceManager, fileProviderDescriptor, localOnnxModelIdProvider)}/${LocalDiffusionContract.TOKENIZER_MODEL}", options ) debugLog("{$TAG} {TOKENIZER} {initialize} Session created successfully!") @@ -231,7 +231,7 @@ internal class EnglishTextTokenizer( private fun loadEncoder(): Map { val map: MutableMap = HashMap() try { - val path = "${modelPathPrefix(preferenceManager, fileProviderDescriptor, localModelIdProvider)}/${LocalDiffusionContract.TOKENIZER_VOCABULARY}" + val path = "${modelPathPrefix(preferenceManager, fileProviderDescriptor, localOnnxModelIdProvider)}/${LocalDiffusionContract.TOKENIZER_VOCABULARY}" val jsonReader = JsonReader(InputStreamReader(FileInputStream(path))) jsonReader.beginObject() while (jsonReader.hasNext()) { @@ -257,7 +257,7 @@ internal class EnglishTextTokenizer( private fun loadBpeRanks(): Map, Int?> { val result: MutableMap, Int?> = HashMap() try { - val path = "${modelPathPrefix(preferenceManager, fileProviderDescriptor, localModelIdProvider)}/${LocalDiffusionContract.TOKENIZER_MERGES}" + val path = "${modelPathPrefix(preferenceManager, fileProviderDescriptor, localOnnxModelIdProvider)}/${LocalDiffusionContract.TOKENIZER_MERGES}" val reader = BufferedReader(InputStreamReader(FileInputStream(path))) var line: String var startLine = 1 diff --git a/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/ai/tokenizer/LocalDiffusionTextTokenizer.kt b/feature/local-diffusion-onnx/src/main/java/com/shifthackz/aisdv1/feature/localdiffusion/onnx/ai/tokenizer/LocalDiffusionTextTokenizer.kt similarity index 81% rename from feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/ai/tokenizer/LocalDiffusionTextTokenizer.kt rename to feature/local-diffusion-onnx/src/main/java/com/shifthackz/aisdv1/feature/localdiffusion/onnx/ai/tokenizer/LocalDiffusionTextTokenizer.kt index ee66bfdb..44d74cc1 100644 --- a/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/ai/tokenizer/LocalDiffusionTextTokenizer.kt +++ b/feature/local-diffusion-onnx/src/main/java/com/shifthackz/aisdv1/feature/localdiffusion/onnx/ai/tokenizer/LocalDiffusionTextTokenizer.kt @@ -1,4 +1,4 @@ -package com.shifthackz.aisdv1.feature.diffusion.ai.tokenizer +package com.shifthackz.aisdv1.feature.localdiffusion.onnx.ai.tokenizer import ai.onnxruntime.OnnxTensor diff --git a/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/ai/tokenizer/TokenizerByteSet.kt b/feature/local-diffusion-onnx/src/main/java/com/shifthackz/aisdv1/feature/localdiffusion/onnx/ai/tokenizer/TokenizerByteSet.kt similarity index 98% rename from feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/ai/tokenizer/TokenizerByteSet.kt rename to feature/local-diffusion-onnx/src/main/java/com/shifthackz/aisdv1/feature/localdiffusion/onnx/ai/tokenizer/TokenizerByteSet.kt index 615b5548..c82985ee 100644 --- a/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/ai/tokenizer/TokenizerByteSet.kt +++ b/feature/local-diffusion-onnx/src/main/java/com/shifthackz/aisdv1/feature/localdiffusion/onnx/ai/tokenizer/TokenizerByteSet.kt @@ -1,4 +1,4 @@ -package com.shifthackz.aisdv1.feature.diffusion.ai.tokenizer +package com.shifthackz.aisdv1.feature.localdiffusion.onnx.ai.tokenizer object TokenizerByteSet { val byteEncoder: MutableMap = HashMap() diff --git a/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/ai/unet/UNet.kt b/feature/local-diffusion-onnx/src/main/java/com/shifthackz/aisdv1/feature/localdiffusion/onnx/ai/unet/UNet.kt similarity index 83% rename from feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/ai/unet/UNet.kt rename to feature/local-diffusion-onnx/src/main/java/com/shifthackz/aisdv1/feature/localdiffusion/onnx/ai/unet/UNet.kt index 85cc131f..b9f0c73d 100644 --- a/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/ai/unet/UNet.kt +++ b/feature/local-diffusion-onnx/src/main/java/com/shifthackz/aisdv1/feature/localdiffusion/onnx/ai/unet/UNet.kt @@ -1,6 +1,6 @@ @file:Suppress("UNCHECKED_CAST", "MemberVisibilityCanBePrivate") -package com.shifthackz.aisdv1.feature.diffusion.ai.unet +package com.shifthackz.aisdv1.feature.localdiffusion.onnx.ai.unet import ai.onnxruntime.OnnxTensor import ai.onnxruntime.OrtSession @@ -11,27 +11,27 @@ import android.util.Pair import com.shifthackz.aisdv1.core.common.file.FileProviderDescriptor import com.shifthackz.aisdv1.core.common.log.debugLog import com.shifthackz.aisdv1.domain.preference.PreferenceManager -import com.shifthackz.aisdv1.feature.diffusion.LocalDiffusionContract -import com.shifthackz.aisdv1.feature.diffusion.LocalDiffusionContract.KEY_ENCODER_HIDDEN_STATES -import com.shifthackz.aisdv1.feature.diffusion.LocalDiffusionContract.KEY_LATENT_SAMPLE -import com.shifthackz.aisdv1.feature.diffusion.LocalDiffusionContract.KEY_SAMPLE -import com.shifthackz.aisdv1.feature.diffusion.LocalDiffusionContract.KEY_TIME_STEP -import com.shifthackz.aisdv1.feature.diffusion.LocalDiffusionContract.ORT -import com.shifthackz.aisdv1.feature.diffusion.LocalDiffusionContract.ORT_KEY_MODEL_FORMAT -import com.shifthackz.aisdv1.feature.diffusion.LocalDiffusionContract.TAG -import com.shifthackz.aisdv1.feature.diffusion.ai.extensions.duplicate -import com.shifthackz.aisdv1.feature.diffusion.ai.extensions.getSizes -import com.shifthackz.aisdv1.feature.diffusion.ai.extensions.multipleTensorsByFloat -import com.shifthackz.aisdv1.feature.diffusion.ai.extensions.splitTensor -import com.shifthackz.aisdv1.feature.diffusion.ai.scheduler.EulerAncestralDiscreteLocalDiffusionScheduler -import com.shifthackz.aisdv1.feature.diffusion.ai.vae.VaeDecoder -import com.shifthackz.aisdv1.feature.diffusion.entity.Array3D -import com.shifthackz.aisdv1.feature.diffusion.entity.LocalDiffusionFlag -import com.shifthackz.aisdv1.feature.diffusion.entity.LocalDiffusionTensor -import com.shifthackz.aisdv1.feature.diffusion.environment.DeviceNNAPIFlagProvider -import com.shifthackz.aisdv1.feature.diffusion.environment.LocalModelIdProvider -import com.shifthackz.aisdv1.feature.diffusion.environment.OrtEnvironmentProvider -import com.shifthackz.aisdv1.feature.diffusion.extensions.modelPathPrefix +import com.shifthackz.aisdv1.feature.localdiffusion.onnx.LocalDiffusionContract +import com.shifthackz.aisdv1.feature.localdiffusion.onnx.LocalDiffusionContract.KEY_ENCODER_HIDDEN_STATES +import com.shifthackz.aisdv1.feature.localdiffusion.onnx.LocalDiffusionContract.KEY_LATENT_SAMPLE +import com.shifthackz.aisdv1.feature.localdiffusion.onnx.LocalDiffusionContract.KEY_SAMPLE +import com.shifthackz.aisdv1.feature.localdiffusion.onnx.LocalDiffusionContract.KEY_TIME_STEP +import com.shifthackz.aisdv1.feature.localdiffusion.onnx.LocalDiffusionContract.ORT +import com.shifthackz.aisdv1.feature.localdiffusion.onnx.LocalDiffusionContract.ORT_KEY_MODEL_FORMAT +import com.shifthackz.aisdv1.feature.localdiffusion.onnx.LocalDiffusionContract.TAG +import com.shifthackz.aisdv1.feature.localdiffusion.onnx.ai.extensions.duplicate +import com.shifthackz.aisdv1.feature.localdiffusion.onnx.ai.extensions.getSizes +import com.shifthackz.aisdv1.feature.localdiffusion.onnx.ai.extensions.multipleTensorsByFloat +import com.shifthackz.aisdv1.feature.localdiffusion.onnx.ai.extensions.splitTensor +import com.shifthackz.aisdv1.feature.localdiffusion.onnx.ai.scheduler.EulerAncestralDiscreteLocalDiffusionScheduler +import com.shifthackz.aisdv1.feature.localdiffusion.onnx.ai.vae.VaeDecoder +import com.shifthackz.aisdv1.feature.localdiffusion.onnx.entity.Array3D +import com.shifthackz.aisdv1.feature.localdiffusion.onnx.entity.LocalDiffusionFlag +import com.shifthackz.aisdv1.feature.localdiffusion.onnx.entity.LocalDiffusionTensor +import com.shifthackz.aisdv1.feature.localdiffusion.onnx.environment.DeviceNNAPIFlagProvider +import com.shifthackz.aisdv1.feature.localdiffusion.onnx.environment.LocalOnnxModelIdProvider +import com.shifthackz.aisdv1.feature.localdiffusion.onnx.environment.OrtEnvironmentProvider +import com.shifthackz.aisdv1.feature.localdiffusion.onnx.extensions.modelPathPrefix import java.nio.IntBuffer import java.util.EnumSet import java.util.Random @@ -43,7 +43,7 @@ internal class UNet( private val deviceNNAPIFlagProvider: DeviceNNAPIFlagProvider, private val ortEnvironmentProvider: OrtEnvironmentProvider, private val fileProviderDescriptor: FileProviderDescriptor, - private val localModelIdProvider: LocalModelIdProvider, + private val localOnnxModelIdProvider: LocalOnnxModelIdProvider, private val preferenceManager: PreferenceManager, ) { @@ -62,7 +62,7 @@ internal class UNet( decoder = VaeDecoder( ortEnvironmentProvider, fileProviderDescriptor, - localModelIdProvider, + localOnnxModelIdProvider, preferenceManager, deviceNNAPIFlagProvider.get(), ) @@ -72,7 +72,7 @@ internal class UNet( options.addNnapi(EnumSet.of(NNAPIFlags.CPU_DISABLED)) } session = ortEnvironmentProvider.get().createSession( - "${modelPathPrefix(preferenceManager, fileProviderDescriptor, localModelIdProvider)}/${LocalDiffusionContract.UNET_MODEL}", + "${modelPathPrefix(preferenceManager, fileProviderDescriptor, localOnnxModelIdProvider)}/${LocalDiffusionContract.UNET_MODEL}", options ) } diff --git a/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/ai/vae/VaeDecoder.kt b/feature/local-diffusion-onnx/src/main/java/com/shifthackz/aisdv1/feature/localdiffusion/onnx/ai/vae/VaeDecoder.kt similarity index 71% rename from feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/ai/vae/VaeDecoder.kt rename to feature/local-diffusion-onnx/src/main/java/com/shifthackz/aisdv1/feature/localdiffusion/onnx/ai/vae/VaeDecoder.kt index 5dea5b55..2ef93f8e 100644 --- a/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/ai/vae/VaeDecoder.kt +++ b/feature/local-diffusion-onnx/src/main/java/com/shifthackz/aisdv1/feature/localdiffusion/onnx/ai/vae/VaeDecoder.kt @@ -1,4 +1,4 @@ -package com.shifthackz.aisdv1.feature.diffusion.ai.vae +package com.shifthackz.aisdv1.feature.localdiffusion.onnx.ai.vae import ai.onnxruntime.OnnxTensor import ai.onnxruntime.OrtSession @@ -8,21 +8,21 @@ import android.graphics.Bitmap import android.graphics.Color import com.shifthackz.aisdv1.core.common.file.FileProviderDescriptor import com.shifthackz.aisdv1.domain.preference.PreferenceManager -import com.shifthackz.aisdv1.feature.diffusion.LocalDiffusionContract -import com.shifthackz.aisdv1.feature.diffusion.LocalDiffusionContract.ORT -import com.shifthackz.aisdv1.feature.diffusion.LocalDiffusionContract.ORT_KEY_MODEL_FORMAT -import com.shifthackz.aisdv1.feature.diffusion.entity.Array3D -import com.shifthackz.aisdv1.feature.diffusion.entity.LocalDiffusionFlag -import com.shifthackz.aisdv1.feature.diffusion.environment.LocalModelIdProvider -import com.shifthackz.aisdv1.feature.diffusion.environment.OrtEnvironmentProvider -import com.shifthackz.aisdv1.feature.diffusion.extensions.modelPathPrefix +import com.shifthackz.aisdv1.feature.localdiffusion.onnx.LocalDiffusionContract +import com.shifthackz.aisdv1.feature.localdiffusion.onnx.LocalDiffusionContract.ORT +import com.shifthackz.aisdv1.feature.localdiffusion.onnx.LocalDiffusionContract.ORT_KEY_MODEL_FORMAT +import com.shifthackz.aisdv1.feature.localdiffusion.onnx.entity.Array3D +import com.shifthackz.aisdv1.feature.localdiffusion.onnx.entity.LocalDiffusionFlag +import com.shifthackz.aisdv1.feature.localdiffusion.onnx.environment.LocalOnnxModelIdProvider +import com.shifthackz.aisdv1.feature.localdiffusion.onnx.environment.OrtEnvironmentProvider +import com.shifthackz.aisdv1.feature.localdiffusion.onnx.extensions.modelPathPrefix import java.util.EnumSet import kotlin.math.roundToInt internal class VaeDecoder( private val ortEnvironmentProvider: OrtEnvironmentProvider, private val fileProviderDescriptor: FileProviderDescriptor, - private val localModelIdProvider: LocalModelIdProvider, + private val localOnnxModelIdProvider: LocalOnnxModelIdProvider, private val preferenceManager: PreferenceManager, private val deviceId: Int, ) { @@ -69,7 +69,7 @@ internal class VaeDecoder( options.addNnapi(EnumSet.of(NNAPIFlags.CPU_DISABLED)) } session = ortEnvironmentProvider.get().createSession( - "${modelPathPrefix(preferenceManager, fileProviderDescriptor, localModelIdProvider)}/${LocalDiffusionContract.VAE_MODEL}", + "${modelPathPrefix(preferenceManager, fileProviderDescriptor, localOnnxModelIdProvider)}/${LocalDiffusionContract.VAE_MODEL}", options ) } diff --git a/feature/local-diffusion-onnx/src/main/java/com/shifthackz/aisdv1/feature/localdiffusion/onnx/di/OnnxLocalDiffusionModule.kt b/feature/local-diffusion-onnx/src/main/java/com/shifthackz/aisdv1/feature/localdiffusion/onnx/di/OnnxLocalDiffusionModule.kt new file mode 100644 index 00000000..39c91449 --- /dev/null +++ b/feature/local-diffusion-onnx/src/main/java/com/shifthackz/aisdv1/feature/localdiffusion/onnx/di/OnnxLocalDiffusionModule.kt @@ -0,0 +1,19 @@ +package com.shifthackz.aisdv1.feature.localdiffusion.onnx.di + +import com.shifthackz.aisdv1.domain.feature.diffusion.LocalDiffusionONNX +import com.shifthackz.aisdv1.feature.localdiffusion.onnx.LocalDiffusionONNXImpl +import com.shifthackz.aisdv1.feature.localdiffusion.onnx.ai.tokenizer.EnglishTextTokenizer +import com.shifthackz.aisdv1.feature.localdiffusion.onnx.ai.tokenizer.LocalDiffusionTextTokenizer +import com.shifthackz.aisdv1.feature.localdiffusion.onnx.ai.unet.UNet +import com.shifthackz.aisdv1.feature.localdiffusion.onnx.environment.OrtEnvironmentProvider +import com.shifthackz.aisdv1.feature.localdiffusion.onnx.environment.OrtEnvironmentProviderImpl +import org.koin.core.module.dsl.singleOf +import org.koin.dsl.bind +import org.koin.dsl.module + +val onnxLocalDiffusionModule = module { + singleOf(::UNet) + singleOf(::EnglishTextTokenizer) bind LocalDiffusionTextTokenizer::class + singleOf(::LocalDiffusionONNXImpl) bind LocalDiffusionONNX::class + singleOf(::OrtEnvironmentProviderImpl) bind OrtEnvironmentProvider::class +} diff --git a/feature/local-diffusion-onnx/src/main/java/com/shifthackz/aisdv1/feature/localdiffusion/onnx/entity/LocalDiffusionAlias.kt b/feature/local-diffusion-onnx/src/main/java/com/shifthackz/aisdv1/feature/localdiffusion/onnx/entity/LocalDiffusionAlias.kt new file mode 100644 index 00000000..d4146fb3 --- /dev/null +++ b/feature/local-diffusion-onnx/src/main/java/com/shifthackz/aisdv1/feature/localdiffusion/onnx/entity/LocalDiffusionAlias.kt @@ -0,0 +1,3 @@ +package com.shifthackz.aisdv1.feature.localdiffusion.onnx.entity + +typealias Array3D = Array>> diff --git a/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/entity/LocalDiffusionConfig.kt b/feature/local-diffusion-onnx/src/main/java/com/shifthackz/aisdv1/feature/localdiffusion/onnx/entity/LocalDiffusionConfig.kt similarity index 58% rename from feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/entity/LocalDiffusionConfig.kt rename to feature/local-diffusion-onnx/src/main/java/com/shifthackz/aisdv1/feature/localdiffusion/onnx/entity/LocalDiffusionConfig.kt index 5d672808..9bbc7a41 100644 --- a/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/entity/LocalDiffusionConfig.kt +++ b/feature/local-diffusion-onnx/src/main/java/com/shifthackz/aisdv1/feature/localdiffusion/onnx/entity/LocalDiffusionConfig.kt @@ -1,9 +1,9 @@ -package com.shifthackz.aisdv1.feature.diffusion.entity +package com.shifthackz.aisdv1.feature.localdiffusion.onnx.entity -import com.shifthackz.aisdv1.feature.diffusion.LocalDiffusionContract.BETA_SCHEDULER_SCALED_LINEAR -import com.shifthackz.aisdv1.feature.diffusion.LocalDiffusionContract.DPM_SOLVER_PP -import com.shifthackz.aisdv1.feature.diffusion.LocalDiffusionContract.PREDICTION_EPSILON -import com.shifthackz.aisdv1.feature.diffusion.LocalDiffusionContract.SOLVER_MIDPOINT +import com.shifthackz.aisdv1.feature.localdiffusion.onnx.LocalDiffusionContract.BETA_SCHEDULER_SCALED_LINEAR +import com.shifthackz.aisdv1.feature.localdiffusion.onnx.LocalDiffusionContract.DPM_SOLVER_PP +import com.shifthackz.aisdv1.feature.localdiffusion.onnx.LocalDiffusionContract.PREDICTION_EPSILON +import com.shifthackz.aisdv1.feature.localdiffusion.onnx.LocalDiffusionContract.SOLVER_MIDPOINT internal data class LocalDiffusionConfig( val betaStart: Float = 0.00085f, diff --git a/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/entity/LocalDiffusionFlag.kt b/feature/local-diffusion-onnx/src/main/java/com/shifthackz/aisdv1/feature/localdiffusion/onnx/entity/LocalDiffusionFlag.kt similarity index 61% rename from feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/entity/LocalDiffusionFlag.kt rename to feature/local-diffusion-onnx/src/main/java/com/shifthackz/aisdv1/feature/localdiffusion/onnx/entity/LocalDiffusionFlag.kt index 34479d85..8b0229d4 100644 --- a/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/entity/LocalDiffusionFlag.kt +++ b/feature/local-diffusion-onnx/src/main/java/com/shifthackz/aisdv1/feature/localdiffusion/onnx/entity/LocalDiffusionFlag.kt @@ -1,6 +1,6 @@ @file:Suppress("unused") -package com.shifthackz.aisdv1.feature.diffusion.entity +package com.shifthackz.aisdv1.feature.localdiffusion.onnx.entity enum class LocalDiffusionFlag(val value: Int) { CPU(0), diff --git a/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/entity/LocalDiffusionTensor.kt b/feature/local-diffusion-onnx/src/main/java/com/shifthackz/aisdv1/feature/localdiffusion/onnx/entity/LocalDiffusionTensor.kt similarity index 93% rename from feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/entity/LocalDiffusionTensor.kt rename to feature/local-diffusion-onnx/src/main/java/com/shifthackz/aisdv1/feature/localdiffusion/onnx/entity/LocalDiffusionTensor.kt index 3287de5f..231ccd23 100644 --- a/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/entity/LocalDiffusionTensor.kt +++ b/feature/local-diffusion-onnx/src/main/java/com/shifthackz/aisdv1/feature/localdiffusion/onnx/entity/LocalDiffusionTensor.kt @@ -1,4 +1,4 @@ -package com.shifthackz.aisdv1.feature.diffusion.entity +package com.shifthackz.aisdv1.feature.localdiffusion.onnx.entity import ai.onnxruntime.OnnxTensor diff --git a/feature/local-diffusion-onnx/src/main/java/com/shifthackz/aisdv1/feature/localdiffusion/onnx/environment/DeviceNNAPIFlagProvider.kt b/feature/local-diffusion-onnx/src/main/java/com/shifthackz/aisdv1/feature/localdiffusion/onnx/environment/DeviceNNAPIFlagProvider.kt new file mode 100644 index 00000000..839edba7 --- /dev/null +++ b/feature/local-diffusion-onnx/src/main/java/com/shifthackz/aisdv1/feature/localdiffusion/onnx/environment/DeviceNNAPIFlagProvider.kt @@ -0,0 +1,5 @@ +package com.shifthackz.aisdv1.feature.localdiffusion.onnx.environment + +fun interface DeviceNNAPIFlagProvider { + fun get(): Int +} diff --git a/feature/local-diffusion-onnx/src/main/java/com/shifthackz/aisdv1/feature/localdiffusion/onnx/environment/LocalOnnxModelIdProvider.kt b/feature/local-diffusion-onnx/src/main/java/com/shifthackz/aisdv1/feature/localdiffusion/onnx/environment/LocalOnnxModelIdProvider.kt new file mode 100644 index 00000000..1ef3b24a --- /dev/null +++ b/feature/local-diffusion-onnx/src/main/java/com/shifthackz/aisdv1/feature/localdiffusion/onnx/environment/LocalOnnxModelIdProvider.kt @@ -0,0 +1,5 @@ +package com.shifthackz.aisdv1.feature.localdiffusion.onnx.environment + +fun interface LocalOnnxModelIdProvider { + fun get(): String +} diff --git a/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/environment/OrtEnvironmentProvider.kt b/feature/local-diffusion-onnx/src/main/java/com/shifthackz/aisdv1/feature/localdiffusion/onnx/environment/OrtEnvironmentProvider.kt similarity index 62% rename from feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/environment/OrtEnvironmentProvider.kt rename to feature/local-diffusion-onnx/src/main/java/com/shifthackz/aisdv1/feature/localdiffusion/onnx/environment/OrtEnvironmentProvider.kt index 2d666660..1809b60f 100644 --- a/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/environment/OrtEnvironmentProvider.kt +++ b/feature/local-diffusion-onnx/src/main/java/com/shifthackz/aisdv1/feature/localdiffusion/onnx/environment/OrtEnvironmentProvider.kt @@ -1,4 +1,4 @@ -package com.shifthackz.aisdv1.feature.diffusion.environment +package com.shifthackz.aisdv1.feature.localdiffusion.onnx.environment import ai.onnxruntime.OrtEnvironment diff --git a/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/environment/OrtEnvironmentProviderImpl.kt b/feature/local-diffusion-onnx/src/main/java/com/shifthackz/aisdv1/feature/localdiffusion/onnx/environment/OrtEnvironmentProviderImpl.kt similarity index 79% rename from feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/environment/OrtEnvironmentProviderImpl.kt rename to feature/local-diffusion-onnx/src/main/java/com/shifthackz/aisdv1/feature/localdiffusion/onnx/environment/OrtEnvironmentProviderImpl.kt index cc680035..533ab3cf 100644 --- a/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/environment/OrtEnvironmentProviderImpl.kt +++ b/feature/local-diffusion-onnx/src/main/java/com/shifthackz/aisdv1/feature/localdiffusion/onnx/environment/OrtEnvironmentProviderImpl.kt @@ -1,4 +1,4 @@ -package com.shifthackz.aisdv1.feature.diffusion.environment +package com.shifthackz.aisdv1.feature.localdiffusion.onnx.environment import ai.onnxruntime.OrtEnvironment diff --git a/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/extensions/LocalDiffusionPaths.kt b/feature/local-diffusion-onnx/src/main/java/com/shifthackz/aisdv1/feature/localdiffusion/onnx/extensions/LocalDiffusionPaths.kt similarity index 62% rename from feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/extensions/LocalDiffusionPaths.kt rename to feature/local-diffusion-onnx/src/main/java/com/shifthackz/aisdv1/feature/localdiffusion/onnx/extensions/LocalDiffusionPaths.kt index 1b22e6d9..cc149657 100644 --- a/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/extensions/LocalDiffusionPaths.kt +++ b/feature/local-diffusion-onnx/src/main/java/com/shifthackz/aisdv1/feature/localdiffusion/onnx/extensions/LocalDiffusionPaths.kt @@ -1,16 +1,16 @@ -package com.shifthackz.aisdv1.feature.diffusion.extensions +package com.shifthackz.aisdv1.feature.localdiffusion.onnx.extensions import com.shifthackz.aisdv1.core.common.file.FileProviderDescriptor import com.shifthackz.aisdv1.domain.entity.LocalAiModel import com.shifthackz.aisdv1.domain.preference.PreferenceManager -import com.shifthackz.aisdv1.feature.diffusion.environment.LocalModelIdProvider +import com.shifthackz.aisdv1.feature.localdiffusion.onnx.environment.LocalOnnxModelIdProvider -fun modelPathPrefix( +internal fun modelPathPrefix( preferenceManager: PreferenceManager, fileProviderDescriptor: FileProviderDescriptor, - localModelIdProvider: LocalModelIdProvider, + localOnnxModelIdProvider: LocalOnnxModelIdProvider, ): String { - val modelId = localModelIdProvider.get() + val modelId = localOnnxModelIdProvider.get() return if (modelId == LocalAiModel.CustomOnnx.id) { preferenceManager.localOnnxCustomModelPath } else { diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index 184c5449..6053c24e 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -4,6 +4,7 @@ versionCode = "190" targetSdk = "34" compileSdk = "35" minSdk = "24" +ndk = "28.0.13004108" agp = "8.9.1" kotlin = "2.1.20" ksp = "2.1.20-1.0.32" # Should match kotlin version @@ -48,6 +49,8 @@ roboelectric = "4.14.1" testCoroutines = "1.10.1" mediaPipeGenerator = "0.10.21" serialization = "1.8.1" +junitVersion = "1.1.5" +espressoCore = "3.5.1" [libraries] android-tools-build-gradle = { group = "com.android.tools.build", name = "gradle", version.ref = "agp" } @@ -112,6 +115,8 @@ shifthackz-catppuccin-legacy = { group = "com.github.ShiftHackZ.Catppuccin-Andro shifthackz-catppuccin-compose = { group = "com.github.ShiftHackZ.Catppuccin-Android-Library", name = "compose", version.ref = "catppuccin" } shifthackz-catppuccin-splash = { group = "com.github.ShiftHackZ.Catppuccin-Android-Library", name = "splashscreen", version.ref = "catppuccin" } kotlinx-serialization-json = { group = "org.jetbrains.kotlinx", name = "kotlinx-serialization-json", version.ref = "serialization" } +androidx-junit = { group = "androidx.test.ext", name = "junit", version.ref = "junitVersion" } +androidx-espresso-core = { group = "androidx.test.espresso", name = "espresso-core", version.ref = "espressoCore" } [plugins] android-application = { id = "com.android.application", version.ref = "agp" } diff --git a/network/src/main/java/com/shifthackz/aisdv1/network/api/sdai/DownloadableModelsApi.kt b/network/src/main/java/com/shifthackz/aisdv1/network/api/sdai/DownloadableModelsApi.kt index 78a35382..6dd4bda7 100644 --- a/network/src/main/java/com/shifthackz/aisdv1/network/api/sdai/DownloadableModelsApi.kt +++ b/network/src/main/java/com/shifthackz/aisdv1/network/api/sdai/DownloadableModelsApi.kt @@ -13,6 +13,8 @@ interface DownloadableModelsApi { fun fetchOnnxModels(): Single> + fun fetchCppModels(): Single> + fun fetchMediaPipeModels(): Single> fun downloadModel( @@ -27,6 +29,9 @@ interface DownloadableModelsApi { @GET("/models.json") fun fetchOnnxModels(): Single> + @GET("/cpp.json") + fun fetchCppModels(): Single> + @GET("/mediapipe.json") fun fetchMediaPipeModels(): Single> diff --git a/network/src/main/java/com/shifthackz/aisdv1/network/api/sdai/DownloadableModelsApiImpl.kt b/network/src/main/java/com/shifthackz/aisdv1/network/api/sdai/DownloadableModelsApiImpl.kt index 04fd71c1..eefcaa26 100644 --- a/network/src/main/java/com/shifthackz/aisdv1/network/api/sdai/DownloadableModelsApiImpl.kt +++ b/network/src/main/java/com/shifthackz/aisdv1/network/api/sdai/DownloadableModelsApiImpl.kt @@ -12,6 +12,8 @@ internal class DownloadableModelsApiImpl( override fun fetchOnnxModels() = rawApi.fetchOnnxModels() + override fun fetchCppModels() = rawApi.fetchCppModels() + override fun fetchMediaPipeModels() = rawApi.fetchMediaPipeModels() override fun downloadModel( diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/di/ViewModelModule.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/di/ViewModelModule.kt index 054bdb33..3abf5dd5 100755 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/di/ViewModelModule.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/di/ViewModelModule.kt @@ -75,6 +75,7 @@ val viewModelModule = module { dispatchersProvider = get(), getConfigurationUseCase = get(), getLocalOnnxModelsUseCase = get(), + getLocalCppModelsUseCase = get(), getLocalMediaPipeModelsUseCase = get(), fetchAndGetHuggingFaceModelsUseCase = get(), urlValidator = get(), diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/model/Modal.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/model/Modal.kt index 98983733..95861abc 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/model/Modal.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/model/Modal.kt @@ -8,7 +8,6 @@ import com.shifthackz.aisdv1.domain.entity.AiGenerationResult import com.shifthackz.aisdv1.domain.entity.Grid import com.shifthackz.aisdv1.domain.entity.HordeProcessStatus import com.shifthackz.aisdv1.domain.entity.LocalDiffusionStatus -import com.shifthackz.aisdv1.domain.feature.diffusion.LocalDiffusion import com.shifthackz.aisdv1.presentation.screen.setup.ServerSetupState sealed interface Modal { diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/settings/SettingsScreen.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/settings/SettingsScreen.kt index 1942f01c..82ceb74e 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/settings/SettingsScreen.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/settings/SettingsScreen.kt @@ -240,6 +240,7 @@ private fun ContentSettingsState( ServerSource.LOCAL_MICROSOFT_ONNX -> LocalizationR.string.srv_type_local_short ServerSource.LOCAL_GOOGLE_MEDIA_PIPE -> LocalizationR.string.srv_type_media_pipe_short ServerSource.SWARM_UI -> LocalizationR.string.srv_type_swarm_ui + ServerSource.LOCAL_CPP -> LocalizationR.string.srv_type_local_cpp_short }.asUiText(), onClick = { processIntent(SettingsIntent.NavigateConfiguration) }, ) diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupScreen.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupScreen.kt index 0a77f3ad..ebabd65a 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupScreen.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupScreen.kt @@ -159,6 +159,10 @@ fun ServerSetupScreenContent( it.downloaded && it.selected } + ServerSource.LOCAL_CPP -> state.localCppModels.any { + it.downloaded && it.selected + } + ServerSource.LOCAL_GOOGLE_MEDIA_PIPE -> state.localMediaPipeModels.any { it.downloaded && it.selected } diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupState.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupState.kt index ca65ed90..f99c3202 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupState.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupState.kt @@ -39,6 +39,9 @@ data class ServerSetupState( val localOnnxModels: List = emptyList(), val localOnnxCustomModel: Boolean = false, val localOnnxCustomModelPath: String = "", + val localCppModels: List = emptyList(), + val localCppCustomModel: Boolean = false, + val localCppCustomModelPath: String = "", val localMediaPipeModels: List = emptyList(), val localMediaPipeCustomModel: Boolean = false, val localMediaPipeCustomModelPath: String = "", @@ -52,35 +55,40 @@ data class ServerSetupState( val openAiApiKeyValidationError: UiText? = null, val stabilityAiApiKeyValidationError: UiText? = null, val localCustomOnnxPathValidationError: UiText? = null, + val localCustomCppPathValidationError: UiText? = null, val localCustomMediaPipePathValidationError: UiText? = null, ) : MviState, KoinComponent { val localCustomModel: Boolean - get() = if (mode == ServerSource.LOCAL_MICROSOFT_ONNX) { - localOnnxCustomModel - } else { - localMediaPipeCustomModel + get() = when (mode) { + ServerSource.LOCAL_MICROSOFT_ONNX -> localOnnxCustomModel + ServerSource.LOCAL_CPP -> localCppCustomModel + ServerSource.LOCAL_GOOGLE_MEDIA_PIPE -> localMediaPipeCustomModel + else -> false } val localCustomModelPath: String - get() = if (mode == ServerSource.LOCAL_MICROSOFT_ONNX) { - localOnnxCustomModelPath - } else { - localMediaPipeCustomModelPath + get() = when (mode) { + ServerSource.LOCAL_MICROSOFT_ONNX -> localOnnxCustomModelPath + ServerSource.LOCAL_CPP -> localCppCustomModelPath + ServerSource.LOCAL_GOOGLE_MEDIA_PIPE -> localMediaPipeCustomModelPath + else -> "" } val localModels: List - get() = if (mode == ServerSource.LOCAL_MICROSOFT_ONNX) { - localOnnxModels - } else { - localMediaPipeModels + get() = when (mode) { + ServerSource.LOCAL_MICROSOFT_ONNX -> localOnnxModels + ServerSource.LOCAL_CPP -> localCppModels + ServerSource.LOCAL_GOOGLE_MEDIA_PIPE -> localMediaPipeModels + else -> emptyList() } val localCustomModelPathValidationError: UiText? - get() = if (mode == ServerSource.LOCAL_MICROSOFT_ONNX) { - localCustomOnnxPathValidationError - } else { - localCustomMediaPipePathValidationError + get() = when (mode) { + ServerSource.LOCAL_MICROSOFT_ONNX -> localCustomOnnxPathValidationError + ServerSource.LOCAL_CPP -> localCustomCppPathValidationError + ServerSource.LOCAL_GOOGLE_MEDIA_PIPE -> localCustomMediaPipePathValidationError + else -> null } val demoModeUrl: String @@ -109,6 +117,11 @@ data class ServerSetupState( localCustomOnnxPathValidationError = null, ) + ServerSource.LOCAL_CPP -> copy( + localCppCustomModelPath = value, + localCustomCppPathValidationError = null, + ) + ServerSource.LOCAL_GOOGLE_MEDIA_PIPE -> copy( localMediaPipeCustomModelPath = value, localCustomMediaPipePathValidationError = null, @@ -121,6 +134,9 @@ data class ServerSetupState( ServerSource.LOCAL_MICROSOFT_ONNX -> copy( localOnnxModels = localOnnxModels.withNewState(value) ) + ServerSource.LOCAL_CPP -> copy( + localCppModels = localCppModels.withNewState(value) + ) ServerSource.LOCAL_GOOGLE_MEDIA_PIPE -> copy( localMediaPipeModels = localMediaPipeModels.withNewState(value) ) @@ -135,7 +151,17 @@ data class ServerSetupState( downloadState = DownloadState.Unknown, downloaded = false, ), - ) + ), + ) + + ServerSource.LOCAL_CPP -> copy( + screenModal = Modal.None, + localCppModels = localCppModels.withNewState( + value.copy( + downloadState = DownloadState.Unknown, + downloaded = false, + ), + ), ) ServerSource.LOCAL_GOOGLE_MEDIA_PIPE -> copy( @@ -145,7 +171,7 @@ data class ServerSetupState( downloadState = DownloadState.Unknown, downloaded = false, ), - ) + ), ) else -> copy(screenModal = Modal.None) @@ -164,6 +190,12 @@ data class ServerSetupState( ), ) + ServerSource.LOCAL_CPP -> copy( + localCppModels = localCppModels.withNewState( + value.copy(selected = true) + ) + ) + else -> this } @@ -179,6 +211,13 @@ data class ServerSetupState( ), ) + ServerSource.LOCAL_CPP -> this.copy( + localCppCustomModel = value, + localCppModels = localCppModels.updateCustomModelSelection( + id = LocalAiModel.CustomCpp.id + ) + ) + ServerSource.LOCAL_GOOGLE_MEDIA_PIPE -> this.copy( localMediaPipeCustomModel = value, localMediaPipeModels = localMediaPipeModels.updateCustomModelSelection( diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupViewModel.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupViewModel.kt index 0e560195..6b65969d 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupViewModel.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupViewModel.kt @@ -5,6 +5,7 @@ import com.shifthackz.aisdv1.core.common.appbuild.BuildType import com.shifthackz.aisdv1.core.common.log.errorLog import com.shifthackz.aisdv1.core.common.schedulers.DispatchersProvider import com.shifthackz.aisdv1.core.common.model.Quadruple +import com.shifthackz.aisdv1.core.common.model.Quintuple import com.shifthackz.aisdv1.core.common.schedulers.SchedulersProvider import com.shifthackz.aisdv1.core.common.schedulers.subscribeOnMainThread import com.shifthackz.aisdv1.core.model.asUiText @@ -21,6 +22,7 @@ import com.shifthackz.aisdv1.domain.interactor.wakelock.WakeLockInterActor import com.shifthackz.aisdv1.domain.preference.PreferenceManager import com.shifthackz.aisdv1.domain.usecase.downloadable.DeleteModelUseCase import com.shifthackz.aisdv1.domain.usecase.downloadable.DownloadModelUseCase +import com.shifthackz.aisdv1.domain.usecase.downloadable.GetLocalCppModelsUseCase import com.shifthackz.aisdv1.domain.usecase.downloadable.GetLocalMediaPipeModelsUseCase import com.shifthackz.aisdv1.domain.usecase.downloadable.GetLocalOnnxModelsUseCase import com.shifthackz.aisdv1.domain.usecase.huggingface.FetchAndGetHuggingFaceModelsUseCase @@ -29,6 +31,7 @@ import com.shifthackz.aisdv1.presentation.model.LaunchSource import com.shifthackz.aisdv1.presentation.model.Modal import com.shifthackz.aisdv1.presentation.navigation.router.main.MainRouter import com.shifthackz.aisdv1.presentation.screen.setup.mappers.allowedModes +import com.shifthackz.aisdv1.presentation.screen.setup.mappers.mapLocalCustomCppSwitchState import com.shifthackz.aisdv1.presentation.screen.setup.mappers.mapLocalCustomMediaPipeSwitchState import com.shifthackz.aisdv1.presentation.screen.setup.mappers.mapLocalCustomOnnxSwitchState import com.shifthackz.aisdv1.presentation.screen.setup.mappers.mapToUi @@ -42,6 +45,7 @@ class ServerSetupViewModel( dispatchersProvider: DispatchersProvider, getConfigurationUseCase: GetConfigurationUseCase, getLocalOnnxModelsUseCase: GetLocalOnnxModelsUseCase, + getLocalCppModelsUseCase: GetLocalCppModelsUseCase, getLocalMediaPipeModelsUseCase: GetLocalMediaPipeModelsUseCase, fetchAndGetHuggingFaceModelsUseCase: FetchAndGetHuggingFaceModelsUseCase, private val urlValidator: UrlValidator, @@ -81,12 +85,13 @@ class ServerSetupViewModel( !Single.zip( getConfigurationUseCase(), getLocalOnnxModelsUseCase(), + getLocalCppModelsUseCase(), getLocalMediaPipeModelsUseCase(), fetchAndGetHuggingFaceModelsUseCase(), - ::Quadruple, + ::Quintuple, ) .subscribeOnMainThread(schedulersProvider) - .subscribeBy(::errorLog) { (configuration, onnxModels, mpModels, hfModels) -> + .subscribeBy(::errorLog) { (configuration, onnxModels, cppModels, mpModels, hfModels) -> updateState { state -> state.copy( huggingFaceModels = hfModels.map(HuggingFaceModel::alias), @@ -94,6 +99,8 @@ class ServerSetupViewModel( huggingFaceApiKey = configuration.huggingFaceApiKey, openAiApiKey = configuration.openAiApiKey, stabilityAiApiKey = configuration.stabilityAiApiKey, + localCppModels = cppModels.mapToUi(), + localCppCustomModel = cppModels.mapLocalCustomCppSwitchState(), localOnnxModels = onnxModels.mapToUi(), localOnnxCustomModel = onnxModels.mapLocalCustomOnnxSwitchState(), localOnnxCustomModelPath = configuration.localOnnxModelPath, @@ -241,13 +248,14 @@ class ServerSetupViewModel( emitEffect(ServerSetupEffect.HideKeyboard) !when (currentState.mode) { ServerSource.HORDE -> connectToHorde() - ServerSource.LOCAL_MICROSOFT_ONNX -> connectToLocalDiffusion() + ServerSource.LOCAL_MICROSOFT_ONNX -> connectToLocalDiffusionOnnx() ServerSource.AUTOMATIC1111 -> connectToAutomaticInstance() ServerSource.HUGGING_FACE -> connectToHuggingFace() ServerSource.OPEN_AI -> connectToOpenAi() ServerSource.STABILITY_AI -> connectToStabilityAi() ServerSource.SWARM_UI -> connectToSwarmUi() ServerSource.LOCAL_GOOGLE_MEDIA_PIPE -> connectToMediaPipe() + ServerSource.LOCAL_CPP -> connectToLocalDiffusionCpp() } .doOnSubscribe { setScreenModal(Modal.Communicating(canCancel = false)) } .subscribeOnMainThread(schedulersProvider) @@ -291,6 +299,16 @@ class ServerSetupViewModel( currentState.localOnnxModels.find { it.selected && it.downloaded } != null } + ServerSource.LOCAL_CPP -> if (currentState.localCppCustomModel) { + val validation = filePathValidator(currentState.localCppCustomModelPath) + updateState { + it.copy(localCustomCppPathValidationError = validation.mapToUi()) + } + validation.isValid + } else { + currentState.localCppModels.find { it.selected && it.downloaded } != null + } + ServerSource.LOCAL_GOOGLE_MEDIA_PIPE -> when { buildInfoProvider.type == BuildType.FOSS -> false currentState.localMediaPipeCustomModel -> { @@ -406,10 +424,16 @@ class ServerSetupViewModel( return setupConnectionInterActor.connectToHorde(testApiKey) } - private fun connectToLocalDiffusion(): Single> { + private fun connectToLocalDiffusionOnnx(): Single> { preferenceManager.localOnnxCustomModelPath = currentState.localOnnxCustomModelPath val localModelId = currentState.localOnnxModels.find { it.selected }?.id ?: "" - return setupConnectionInterActor.connectToLocal(localModelId) + return setupConnectionInterActor.connectToLocalOnnx(localModelId) + } + + private fun connectToLocalDiffusionCpp(): Single> { + preferenceManager.localCppModelId = currentState.localCppCustomModelPath + val localModelId = currentState.localCppModels.find { it.selected }?.id ?: "" + return setupConnectionInterActor.connectToLocalCpp(localModelId) } private fun connectToMediaPipe(): Single> { diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/components/ConfigurationModeButton.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/components/ConfigurationModeButton.kt index f8626a6e..c326b71f 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/components/ConfigurationModeButton.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/components/ConfigurationModeButton.kt @@ -15,7 +15,6 @@ import androidx.compose.material.icons.Icons import androidx.compose.material.icons.filled.Android import androidx.compose.material.icons.filled.Cloud import androidx.compose.material.icons.filled.Computer -import androidx.compose.material.icons.filled.QuestionMark import androidx.compose.material3.Icon import androidx.compose.material3.MaterialTheme import androidx.compose.material3.Text @@ -55,11 +54,11 @@ fun ConfigurationModeButton( cornerRadius = CornerRadius(16.dp.toPx()), ) if (state.mode != mode) return@drawBehind - drawRoundRect( - color = borderColor, - style = Stroke(2.dp.toPx()), - cornerRadius = CornerRadius(16.dp.toPx()), - ) + drawRoundRect( + color = borderColor, + style = Stroke(2.dp.toPx()), + cornerRadius = CornerRadius(16.dp.toPx()), + ) } .clickable { onClick(mode) } .padding(horizontal = 4.dp) @@ -80,7 +79,8 @@ fun ConfigurationModeButton( ServerSource.HUGGING_FACE -> Icons.Default.Cloud ServerSource.LOCAL_MICROSOFT_ONNX, - ServerSource.LOCAL_GOOGLE_MEDIA_PIPE -> Icons.Default.Android + ServerSource.LOCAL_GOOGLE_MEDIA_PIPE, + ServerSource.LOCAL_CPP -> Icons.Default.Android }, contentDescription = null, ) @@ -102,6 +102,7 @@ fun ConfigurationModeButton( ServerSource.STABILITY_AI -> LocalizationR.string.hint_stability_ai_sub_title ServerSource.SWARM_UI -> LocalizationR.string.hint_swarm_ui_sub_title ServerSource.LOCAL_GOOGLE_MEDIA_PIPE -> LocalizationR.string.hint_mediapipe_sub_title + ServerSource.LOCAL_CPP -> LocalizationR.string.hint_local_diffusion_cpp_title } descriptionId?.let { resId -> Text( diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/forms/LocalDiffusionForm.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/forms/LocalDiffusionForm.kt index dac9efb9..80787b03 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/forms/LocalDiffusionForm.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/forms/LocalDiffusionForm.kt @@ -94,6 +94,7 @@ fun LocalDiffusionForm( is DownloadState.Downloading -> Icons.Outlined.FileDownload else -> when { model.id == LocalAiModel.CustomOnnx.id -> Icons.Outlined.Landslide + model.id == LocalAiModel.CustomCpp.id -> Icons.Outlined.Landslide model.id == LocalAiModel.CustomMediaPipe.id -> Icons.Outlined.Landslide model.downloaded -> Icons.Outlined.FileDownloadDone else -> Icons.Outlined.FileDownloadOff @@ -118,6 +119,7 @@ fun LocalDiffusionForm( ) when (model.id) { LocalAiModel.CustomOnnx.id, + LocalAiModel.CustomCpp.id, LocalAiModel.CustomMediaPipe.id -> Unit else -> Text( @@ -272,10 +274,10 @@ fun LocalDiffusionForm( .fillMaxWidth() .padding(top = 32.dp, bottom = 8.dp), text = stringResource( - id = if (state.mode == ServerSource.LOCAL_MICROSOFT_ONNX) { - LocalizationR.string.hint_local_diffusion_title - } else { - LocalizationR.string.hint_mediapipe_title + id = when (state.mode) { + ServerSource.LOCAL_CPP -> LocalizationR.string.hint_local_diffusion_cpp_title + ServerSource.LOCAL_MICROSOFT_ONNX -> LocalizationR.string.hint_local_diffusion_title + else -> LocalizationR.string.hint_mediapipe_title }, ), style = MaterialTheme.typography.bodyLarge, @@ -285,10 +287,10 @@ fun LocalDiffusionForm( Text( modifier = Modifier.padding(top = 16.dp, bottom = 16.dp), text = stringResource( - id = if (state.mode == ServerSource.LOCAL_MICROSOFT_ONNX) { - LocalizationR.string.hint_local_diffusion_sub_title - } else { - LocalizationR.string.hint_mediapipe_sub_title + id = when (state.mode) { + ServerSource.LOCAL_CPP -> LocalizationR.string.hint_local_diffusion_cpp_sub_title + ServerSource.LOCAL_MICROSOFT_ONNX -> LocalizationR.string.hint_local_diffusion_sub_title + else -> LocalizationR.string.hint_mediapipe_sub_title }, ), style = MaterialTheme.typography.bodyMedium, @@ -416,8 +418,10 @@ fun LocalDiffusionForm( } state.localModels .filter { - val customPredicate = - it.id == LocalAiModel.CustomOnnx.id || it.id == LocalAiModel.CustomMediaPipe.id + val customPredicate = it.id == LocalAiModel.CustomOnnx.id + || it.id == LocalAiModel.CustomMediaPipe.id + || it.id == LocalAiModel.CustomCpp.id + if (state.localCustomModel) customPredicate else !customPredicate } .forEach { localModel -> modelItemUi(localModel) } diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/mappers/LocalModelMappers.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/mappers/LocalModelMappers.kt index 3cd1e443..65a2c435 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/mappers/LocalModelMappers.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/mappers/LocalModelMappers.kt @@ -8,6 +8,9 @@ fun List.mapToUi(): List = map(LocalA fun List.mapLocalCustomOnnxSwitchState(): Boolean = find { it.selected && it.id == LocalAiModel.CustomOnnx.id } != null +fun List.mapLocalCustomCppSwitchState(): Boolean = + find { it.selected && it.id == LocalAiModel.CustomCpp.id } != null + fun List.mapLocalCustomMediaPipeSwitchState(): Boolean = find { it.selected && it.id == LocalAiModel.CustomMediaPipe.id } != null diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/steps/ConfigurationStep.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/steps/ConfigurationStep.kt index d25e8237..4823848e 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/steps/ConfigurationStep.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/steps/ConfigurationStep.kt @@ -65,6 +65,12 @@ fun ConfigurationStep( buildInfoProvider = buildInfoProvider, processIntent = processIntent, ) + + ServerSource.LOCAL_CPP -> LocalDiffusionForm( + state = state, + buildInfoProvider = buildInfoProvider, + processIntent = processIntent, + ) } } } diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/txt2img/TextToImageViewModel.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/txt2img/TextToImageViewModel.kt index 06b3520f..bd1fc6dd 100755 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/txt2img/TextToImageViewModel.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/txt2img/TextToImageViewModel.kt @@ -12,7 +12,6 @@ import com.shifthackz.aisdv1.core.validation.dimension.DimensionValidator import com.shifthackz.aisdv1.domain.entity.HordeProcessStatus import com.shifthackz.aisdv1.domain.entity.LocalDiffusionStatus import com.shifthackz.aisdv1.domain.entity.ServerSource -import com.shifthackz.aisdv1.domain.feature.diffusion.LocalDiffusion import com.shifthackz.aisdv1.domain.feature.work.BackgroundTaskManager import com.shifthackz.aisdv1.domain.feature.work.BackgroundWorkObserver import com.shifthackz.aisdv1.domain.interactor.wakelock.WakeLockInterActor diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/engine/EngineSelectionComponent.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/engine/EngineSelectionComponent.kt index 9799f9c8..da11eb0e 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/engine/EngineSelectionComponent.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/engine/EngineSelectionComponent.kt @@ -66,6 +66,7 @@ fun EngineSelectionComponent( ServerSource.LOCAL_GOOGLE_MEDIA_PIPE -> Unit ServerSource.HORDE -> Unit ServerSource.OPEN_AI -> Unit + ServerSource.LOCAL_CPP -> Unit } } } diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/source/ServerSourceLabel.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/source/ServerSourceLabel.kt index 4ee15fd4..2854d125 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/source/ServerSourceLabel.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/source/ServerSourceLabel.kt @@ -21,4 +21,5 @@ fun ServerSource.getNameUiText(): UiText = when (this) { ServerSource.OPEN_AI -> LocalizationR.string.srv_type_open_ai ServerSource.STABILITY_AI -> LocalizationR.string.srv_type_stability_ai ServerSource.SWARM_UI -> LocalizationR.string.srv_type_swarm_ui + ServerSource.LOCAL_CPP -> LocalizationR.string.srv_type_local_cpp }.asUiText() diff --git a/settings.gradle.kts b/settings.gradle.kts index fca65518..8990bb94 100755 --- a/settings.gradle.kts +++ b/settings.gradle.kts @@ -34,7 +34,8 @@ val modules = listOf( ":demo", ":domain", ":feature:auth", - ":feature:diffusion", + ":feature:local-diffusion-onnx", + ":feature:local-diffusion-cpp", ":feature:mediapipe", ":feature:work", ":network",