Skip to content

Stable Diffusion C++ Implementation #459

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion app/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,8 @@ dependencies {
implementation(project(":storage"))
implementation(project(":domain"))
implementation(project(":feature:auth"))
implementation(project(":feature:diffusion"))
implementation(project(":feature:local-diffusion-onnx"))
implementation(project(":feature:local-diffusion-cpp"))
implementation(project(":feature:mediapipe"))
implementation(project(":feature:work"))
implementation(project(":data"))
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
package com.shifthackz.aisdv1.app.di

import com.shifthackz.aisdv1.feature.auth.di.authModule
import com.shifthackz.aisdv1.feature.diffusion.di.diffusionModule
import com.shifthackz.aisdv1.feature.localdiffusion.cpp.di.cppLocalDiffusionModule
import com.shifthackz.aisdv1.feature.localdiffusion.onnx.di.onnxLocalDiffusionModule
import com.shifthackz.aisdv1.feature.mediapipe.di.mediaPipeModule

val featureModule = arrayOf(
authModule,
diffusionModule,
cppLocalDiffusionModule,
onnxLocalDiffusionModule,
mediaPipeModule,
)
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,10 @@ import com.shifthackz.aisdv1.domain.entity.ServerSource
import com.shifthackz.aisdv1.domain.feature.auth.AuthorizationCredentials
import com.shifthackz.aisdv1.domain.feature.auth.AuthorizationStore
import com.shifthackz.aisdv1.domain.preference.PreferenceManager
import com.shifthackz.aisdv1.feature.diffusion.entity.LocalDiffusionFlag
import com.shifthackz.aisdv1.feature.diffusion.environment.DeviceNNAPIFlagProvider
import com.shifthackz.aisdv1.feature.diffusion.environment.LocalModelIdProvider
import com.shifthackz.aisdv1.feature.localdiffusion.cpp.environment.LocalCppModelIdProvider
import com.shifthackz.aisdv1.feature.localdiffusion.onnx.entity.LocalDiffusionFlag
import com.shifthackz.aisdv1.feature.localdiffusion.onnx.environment.DeviceNNAPIFlagProvider
import com.shifthackz.aisdv1.feature.localdiffusion.onnx.environment.LocalOnnxModelIdProvider
import com.shifthackz.aisdv1.network.qualifiers.ApiKeyProvider
import com.shifthackz.aisdv1.network.qualifiers.ApiUrlProvider
import com.shifthackz.aisdv1.network.qualifiers.CredentialsProvider
Expand Down Expand Up @@ -184,7 +185,11 @@ val providersModule = module {
}

single {
LocalModelIdProvider { get<PreferenceManager>().localOnnxModelId }
LocalOnnxModelIdProvider { get<PreferenceManager>().localOnnxModelId }
}

single {
LocalCppModelIdProvider { get<PreferenceManager>().localCppModelId }
}

single {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ internal fun Project.configureKotlinAndroid(
commonExtension: CommonExtension<*, *, *, *, *, *>,
) {
commonExtension.apply {
compileSdk = libs.findVersion("compileSdk").get().toString().toInt()
compileSdk = libs.findVersion("compileSdk").get().toString().toInt()
ndkVersion = libs.findVersion("ndk").get().toString()

defaultConfig {
minSdk = libs.findVersion("minSdk").get().toString().toInt()
Expand Down
5 changes: 5 additions & 0 deletions core/localization/src/main/res/values/strings.xml
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,9 @@
<string name="srv_type_horde" translatable="false">Horde AI Cloud</string>
<string name="srv_type_horde_short" translatable="false">Horde</string>
<string name="srv_type_local" translatable="false">Local Diffusion Microsoft ONNX (Beta)</string>
<string name="srv_type_local_cpp" translatable="false">Local Diffusion C++ (Beta)</string>
<string name="srv_type_local_short" translatable="false">ONNX</string>
<string name="srv_type_local_cpp_short" translatable="false">CPP</string>
<string name="srv_type_media_pipe" translatable="false">Local Diffusion Google AI MediaPipe (Beta)</string>
<string name="srv_type_media_pipe_short" translatable="false">MediaPipe</string>
<string name="srv_type_hugging_face" translatable="false">Hugging Face Inference</string>
Expand Down Expand Up @@ -156,6 +158,9 @@
<string name="hint_local_diffusion_sub_title">This configuration uses Microsoft ONNX runtime and allows to run Stable Diffusion AI generations on your phone, with no need to connect to remote server/cloud.</string>
<string name="hint_local_diffusion_warning">Warning! Local Diffusion functionality is in beta-test. Don\'t expect for high quality images using local mode. \n\nThis implementation may not work well on non-powerful phones. Generation performance and speed depends on your phone resources (CPU, RAM) and the size of generated image (the smaller the image size, the faster the generation).</string>

<string name="hint_local_diffusion_cpp_title" translatable="false">Local Diffusion C++</string>
<string name="hint_local_diffusion_cpp_sub_title" translatable="false">Changeme</string>

<string name="hint_mediapipe_title" translatable="false">Local Diffusion Google AI MediaPipe</string>
<string name="hint_mediapipe_sub_title">This configuration uses Google AI MediaPipe and allows to run Stable Diffusion AI generations on your phone, with no need to connect to remote server/cloud.</string>

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ import com.shifthackz.aisdv1.data.repository.GenerationResultRepositoryImpl
import com.shifthackz.aisdv1.data.repository.HordeGenerationRepositoryImpl
import com.shifthackz.aisdv1.data.repository.HuggingFaceGenerationRepositoryImpl
import com.shifthackz.aisdv1.data.repository.HuggingFaceModelsRepositoryImpl
import com.shifthackz.aisdv1.data.repository.LocalDiffusionGenerationRepositoryImpl
import com.shifthackz.aisdv1.data.repository.LocalDiffusionCppGenerationRepositoryImpl
import com.shifthackz.aisdv1.data.repository.LocalDiffusionOnnxGenerationRepositoryImpl
import com.shifthackz.aisdv1.data.repository.LorasRepositoryImpl
import com.shifthackz.aisdv1.data.repository.MediaPipeGenerationRepositoryImpl
import com.shifthackz.aisdv1.data.repository.OpenAiGenerationRepositoryImpl
Expand All @@ -33,7 +34,8 @@ import com.shifthackz.aisdv1.domain.repository.GenerationResultRepository
import com.shifthackz.aisdv1.domain.repository.HordeGenerationRepository
import com.shifthackz.aisdv1.domain.repository.HuggingFaceGenerationRepository
import com.shifthackz.aisdv1.domain.repository.HuggingFaceModelsRepository
import com.shifthackz.aisdv1.domain.repository.LocalDiffusionGenerationRepository
import com.shifthackz.aisdv1.domain.repository.LocalDiffusionCppGenerationRepository
import com.shifthackz.aisdv1.domain.repository.LocalDiffusionOnnxGenerationRepository
import com.shifthackz.aisdv1.domain.repository.LorasRepository
import com.shifthackz.aisdv1.domain.repository.MediaPipeGenerationRepository
import com.shifthackz.aisdv1.domain.repository.OpenAiGenerationRepository
Expand Down Expand Up @@ -66,7 +68,8 @@ val repositoryModule = module {
}

singleOf(::TemporaryGenerationResultRepositoryImpl) bind TemporaryGenerationResultRepository::class
factoryOf(::LocalDiffusionGenerationRepositoryImpl) bind LocalDiffusionGenerationRepository::class
factoryOf(::LocalDiffusionOnnxGenerationRepositoryImpl) bind LocalDiffusionOnnxGenerationRepository::class
factoryOf(::LocalDiffusionCppGenerationRepositoryImpl) bind LocalDiffusionCppGenerationRepository::class
factoryOf(::MediaPipeGenerationRepositoryImpl) bind MediaPipeGenerationRepository::class
factoryOf(::HordeGenerationRepositoryImpl) bind HordeGenerationRepository::class
factoryOf(::HuggingFaceGenerationRepositoryImpl) bind HuggingFaceGenerationRepository::class
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,19 @@ internal class DownloadableModelLocalDataSource(
}
.flatMap { models -> models.withLocalData() }

override fun getAllCpp(): Single<List<LocalAiModel>> = dao
.queryByType(LocalAiModel.Type.Cpp.key)
.map(List<LocalModelEntity>::mapEntityToDomain)
.map { models ->
buildList {
addAll(models)
if (buildInfoProvider.type != BuildType.PLAY) {
add(LocalAiModel.CustomCpp)
}
}
}
.flatMap { models -> models.withLocalData() }

override fun getAllMediaPipe(): Single<List<LocalAiModel>> = dao
.queryByType(LocalAiModel.Type.MediaPipe.key)
.map(List<LocalModelEntity>::mapEntityToDomain)
Expand Down Expand Up @@ -89,20 +102,19 @@ internal class DownloadableModelLocalDataSource(
try {
when (model.id) {
LocalAiModel.CustomOnnx.id,
LocalAiModel.CustomCpp.id,
LocalAiModel.CustomMediaPipe.id -> emitter.onSuccess(true)

else -> {

when (model.type) {
LocalAiModel.Type.ONNX -> {
val files = getLocalModelFiles(model.id).filter { it.isDirectory }
emitter.onSuccess(files.size == 4)
}
else -> when (model.type) {
LocalAiModel.Type.ONNX -> {
val files = getLocalModelFiles(model.id).filter { it.isDirectory }
emitter.onSuccess(files.size == 4)
}

LocalAiModel.Type.MediaPipe -> {
val files = getLocalModelFiles(model.id)
emitter.onSuccess(files.isNotEmpty())
}
LocalAiModel.Type.MediaPipe,
LocalAiModel.Type.Cpp -> {
val files = getLocalModelFiles(model.id)
emitter.onSuccess(files.isNotEmpty())
}
}
}
Expand Down Expand Up @@ -133,6 +145,7 @@ internal class DownloadableModelLocalDataSource(
selected = when (this.type) {
LocalAiModel.Type.ONNX -> preferenceManager.localOnnxModelId == id
LocalAiModel.Type.MediaPipe -> preferenceManager.localMediaPipeModelId == id
LocalAiModel.Type.Cpp -> preferenceManager.localCppModelId == id
},
)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,11 @@ class PreferenceManagerImpl(
default = LOCAL_DIFFUSION_CUSTOM_PATH,
)

override var localCppCustomModelPath: String by preferences.delegates.string(
key = KEY_LOCAL_DIFFUSION_CPP_CUSTOM_MODEL_PATH,
default = LOCAL_DIFFUSION_CUSTOM_PATH,
)

override var localOnnxCustomModelPath: String by preferences.delegates.string(
key = KEY_LOCAL_DIFFUSION_CUSTOM_MODEL_PATH,
default = LOCAL_DIFFUSION_CUSTOM_PATH,
Expand Down Expand Up @@ -166,7 +171,12 @@ class PreferenceManagerImpl(
)

override var localOnnxModelId: String by preferences.delegates.string(
key = KEY_LOCAL_MODEL_ID,
key = KEY_LOCAL_ONNX_MODEL_ID,
onChanged = ::onPreferencesChanged,
)

override var localCppModelId: String by preferences.delegates.string(
key = KEY_LOCAL_CPP_MODEL_ID,
onChanged = ::onPreferencesChanged,
)

Expand Down Expand Up @@ -266,6 +276,7 @@ class PreferenceManagerImpl(
const val KEY_SWARM_MODEL = "key_swarm_model"
const val KEY_DEMO_MODE = "key_demo_mode"
const val KEY_DEVELOPER_MODE = "key_developer_mode"
const val KEY_LOCAL_DIFFUSION_CPP_CUSTOM_MODEL_PATH = "key_local_diffusion_cpp_custom_model_path"
const val KEY_LOCAL_DIFFUSION_CUSTOM_MODEL_PATH = "key_local_diffusion_custom_model_path"
const val KEY_MEDIA_PIPE_CUSTOM_MODEL_PATH = "key_mediapipe_custom_model_path"
const val KEY_ALLOW_LOCAL_DIFFUSION_CANCEL = "key_allow_local_diffusion_cancel"
Expand All @@ -286,7 +297,8 @@ class PreferenceManagerImpl(
const val KEY_ON_BOARDING_COMPLETE = "key_on_boarding_complete"
const val KEY_FORCE_SETUP_AFTER_UPDATE = "force_upd_setup_v0.x.x-v0.6.2"
const val KEY_MEDIA_PIPE_MODEL_ID = "key_mediapipe_model_id"
const val KEY_LOCAL_MODEL_ID = "key_local_model_id"
const val KEY_LOCAL_ONNX_MODEL_ID = "key_local_model_id"
const val KEY_LOCAL_CPP_MODEL_ID = "key_local_cpp_model_id"
const val KEY_LOCAL_NN_API = "key_local_nn_api"
const val KEY_DESIGN_DYNAMIC_COLORS = "key_design_dynamic_colors"
const val KEY_DESIGN_SYSTEM_DARK_THEME = "key_design_system_dark_theme"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,28 @@ internal class DownloadableModelRemoteDataSource(
api
.fetchOnnxModels()
.map { it.mapRawToCheckpointDomain(LocalAiModel.Type.ONNX) },
// api
// .fetchCppModels()
// .map { it.mapRawToCheckpointDomain(LocalAiModel.Type.Cpp) },
Single.just(
listOf(
LocalAiModel(
id = "abb51642-58f9-4401-9ca1-c0eac29e5c1d",
type = LocalAiModel.Type.Cpp,
name = "Test",
size = "Unknown",
sources = listOf(
"https://share.moroz.cc/SDAI/cpp/epicphotogasmLCM_ultimatefidelity.zip"
),
)
)
),
api
.fetchMediaPipeModels()
.map { it.mapRawToCheckpointDomain(LocalAiModel.Type.MediaPipe) },
::Pair,
::Triple,
)
.map { (onnx, mediapipe) -> listOf(onnx, mediapipe).flatten() }
.map { (onnx, cpp, mediapipe) -> listOf(onnx, cpp, mediapipe).flatten() }

override fun download(id: String, url: String): Observable<DownloadState> = Completable
.fromAction {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,12 @@ internal class DownloadableModelRepositoryImpl(
.andThen(localDataSource.getAllOnnx())
.onErrorResumeNext { localDataSource.getAllOnnx() }

override fun getAllCpp(): Single<List<LocalAiModel>> = remoteDataSource
.fetch()
.flatMapCompletable(localDataSource::save)
.andThen(localDataSource.getAllCpp())
.onErrorResumeNext { localDataSource.getAllCpp() }

override fun getAllMediaPipe(): Single<List<LocalAiModel>> {
if (buildInfoProvider.type == BuildType.FOSS) {
return Single.just(emptyList())
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
package com.shifthackz.aisdv1.data.repository

import com.shifthackz.aisdv1.core.common.schedulers.SchedulersProvider
import com.shifthackz.aisdv1.core.imageprocessing.Base64ToBitmapConverter
import com.shifthackz.aisdv1.core.imageprocessing.BitmapToBase64Converter
import com.shifthackz.aisdv1.data.core.CoreGenerationRepository
import com.shifthackz.aisdv1.data.mappers.mapLocalDiffusionToAiGenResult
import com.shifthackz.aisdv1.domain.datasource.GenerationResultDataSource
import com.shifthackz.aisdv1.domain.entity.TextToImagePayload
import com.shifthackz.aisdv1.domain.feature.diffusion.LocalDiffusionCpp
import com.shifthackz.aisdv1.domain.feature.work.BackgroundWorkObserver
import com.shifthackz.aisdv1.domain.gateway.MediaStoreGateway
import com.shifthackz.aisdv1.domain.preference.PreferenceManager
import com.shifthackz.aisdv1.domain.repository.LocalDiffusionCppGenerationRepository

internal class LocalDiffusionCppGenerationRepositoryImpl(
mediaStoreGateway: MediaStoreGateway,
base64ToBitmapConverter: Base64ToBitmapConverter,
localDataSource: GenerationResultDataSource.Local,
backgroundWorkObserver: BackgroundWorkObserver,
private val preferenceManager: PreferenceManager,
private val localDiffusion: LocalDiffusionCpp,
private val bitmapToBase64Converter: BitmapToBase64Converter,
private val schedulersProvider: SchedulersProvider,
) : CoreGenerationRepository(
mediaStoreGateway = mediaStoreGateway,
base64ToBitmapConverter = base64ToBitmapConverter,
localDataSource = localDataSource,
preferenceManager = preferenceManager,
backgroundWorkObserver = backgroundWorkObserver,
), LocalDiffusionCppGenerationRepository {

override fun generateFromText(payload: TextToImagePayload) = localDiffusion
.process(payload)
.subscribeOn(schedulersProvider.byToken(preferenceManager.localOnnxSchedulerThread))
.map(BitmapToBase64Converter::Input)
.flatMap(bitmapToBase64Converter::invoke)
.map(BitmapToBase64Converter.Output::base64ImageString)
.map { base64 -> payload to base64 }
.map(Pair<TextToImagePayload, String>::mapLocalDiffusionToAiGenResult)
.flatMap(::insertGenerationResult)
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,20 +8,20 @@ import com.shifthackz.aisdv1.data.mappers.mapLocalDiffusionToAiGenResult
import com.shifthackz.aisdv1.domain.datasource.DownloadableModelDataSource
import com.shifthackz.aisdv1.domain.datasource.GenerationResultDataSource
import com.shifthackz.aisdv1.domain.entity.TextToImagePayload
import com.shifthackz.aisdv1.domain.feature.diffusion.LocalDiffusion
import com.shifthackz.aisdv1.domain.feature.diffusion.LocalDiffusionONNX
import com.shifthackz.aisdv1.domain.feature.work.BackgroundWorkObserver
import com.shifthackz.aisdv1.domain.gateway.MediaStoreGateway
import com.shifthackz.aisdv1.domain.preference.PreferenceManager
import com.shifthackz.aisdv1.domain.repository.LocalDiffusionGenerationRepository
import com.shifthackz.aisdv1.domain.repository.LocalDiffusionOnnxGenerationRepository
import io.reactivex.rxjava3.core.Single

internal class LocalDiffusionGenerationRepositoryImpl(
internal class LocalDiffusionOnnxGenerationRepositoryImpl(
mediaStoreGateway: MediaStoreGateway,
base64ToBitmapConverter: Base64ToBitmapConverter,
localDataSource: GenerationResultDataSource.Local,
backgroundWorkObserver: BackgroundWorkObserver,
private val preferenceManager: PreferenceManager,
private val localDiffusion: LocalDiffusion,
private val localDiffusion: LocalDiffusionONNX,
private val downloadableLocalDataSource: DownloadableModelDataSource.Local,
private val bitmapToBase64Converter: BitmapToBase64Converter,
private val schedulersProvider: SchedulersProvider,
Expand All @@ -31,7 +31,7 @@ internal class LocalDiffusionGenerationRepositoryImpl(
localDataSource = localDataSource,
preferenceManager = preferenceManager,
backgroundWorkObserver = backgroundWorkObserver,
), LocalDiffusionGenerationRepository {
), LocalDiffusionOnnxGenerationRepository {

override fun observeStatus() = localDiffusion.observeStatus()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ import com.shifthackz.aisdv1.data.preference.PreferenceManagerImpl.Companion.KEY
import com.shifthackz.aisdv1.data.preference.PreferenceManagerImpl.Companion.KEY_HORDE_API_KEY
import com.shifthackz.aisdv1.data.preference.PreferenceManagerImpl.Companion.KEY_HUGGING_FACE_API_KEY
import com.shifthackz.aisdv1.data.preference.PreferenceManagerImpl.Companion.KEY_HUGGING_FACE_MODEL_KEY
import com.shifthackz.aisdv1.data.preference.PreferenceManagerImpl.Companion.KEY_LOCAL_MODEL_ID
import com.shifthackz.aisdv1.data.preference.PreferenceManagerImpl.Companion.KEY_LOCAL_ONNX_MODEL_ID
import com.shifthackz.aisdv1.data.preference.PreferenceManagerImpl.Companion.KEY_LOCAL_NN_API
import com.shifthackz.aisdv1.data.preference.PreferenceManagerImpl.Companion.KEY_MONITOR_CONNECTIVITY
import com.shifthackz.aisdv1.data.preference.PreferenceManagerImpl.Companion.KEY_OPEN_AI_API_KEY
Expand Down Expand Up @@ -372,12 +372,12 @@ class PreferenceManagerImplTest {

@Test
fun `given user reads default localModelId, changes it, expected default value, then changed value`() {
whenever(stubPreference.getString(eq(KEY_LOCAL_MODEL_ID), any()))
whenever(stubPreference.getString(eq(KEY_LOCAL_ONNX_MODEL_ID), any()))
.thenReturn("")

Assert.assertEquals("", preferenceManager.localOnnxModelId)

whenever(stubPreference.getString(eq(KEY_LOCAL_MODEL_ID), any()))
whenever(stubPreference.getString(eq(KEY_LOCAL_ONNX_MODEL_ID), any()))
.thenReturn("key")

preferenceManager.localOnnxModelId = "key"
Expand Down
Loading