Skip to content

Commit dcfda09

Browse files
authored
feat: support nai-diffusion-3 model (#221)
1 parent 4d4e222 commit dcfda09

File tree

3 files changed

+66
-14
lines changed

3 files changed

+66
-14
lines changed

src/config.ts

+33-6
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ type Orient = keyof typeof orientMap
3232

3333
export const models = Object.keys(modelMap) as Model[]
3434
export const orients = Object.keys(orientMap) as Orient[]
35+
export const scheduler = ['native', 'karras', 'exponential', 'polyexponential'] as const
3536

3637
export namespace sampler {
3738
export const nai = {
@@ -42,6 +43,15 @@ export namespace sampler {
4243
'plms': 'PLMS',
4344
}
4445

46+
export const nai3 = {
47+
'k_euler': 'Euler',
48+
'k_euler_a': 'Euler ancestral',
49+
'k_dpmpp_2s_ancestral': 'DPM++ 2S ancestral',
50+
'k_dpmpp_2m': 'DPM++ 2M',
51+
'k_dpmpp_sde': 'DPM++ SDE',
52+
'ddim_v3': 'DDIM V3',
53+
}
54+
4555
// samplers in stable-diffusion-webui
4656
// https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/master/modules/sd_samplers_compvis.py#L12
4757
// https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/master/modules/sd_samplers_kdiffusion.py#L12
@@ -103,9 +113,10 @@ export namespace sampler {
103113
})).loose().description('默认的采样器。').default('k_euler_a')
104114
}
105115

106-
export function sd2nai(sampler: string): string {
116+
export function sd2nai(sampler: string, model: string): string {
107117
if (sampler === 'k_euler_a') return 'k_euler_ancestral'
108-
if (sampler in nai) return sampler
118+
if (model === 'nai-v3' && sampler in nai3) return sampler
119+
else if (sampler in nai) return sampler
109120
return 'k_euler_ancestral'
110121
}
111122
}
@@ -193,6 +204,10 @@ const features = Schema.object({
193204
interface ParamConfig {
194205
model?: Model
195206
sampler?: string
207+
smea?: boolean
208+
smeaDyn?: boolean
209+
scheduler?: string
210+
decrisper?: boolean
196211
upscaler?: string
197212
restoreFaces?: boolean
198213
hiresFix?: boolean
@@ -325,10 +340,22 @@ export const Config = Schema.intersect([
325340
type: Schema.const('naifu').required(),
326341
sampler: sampler.createSchema(sampler.nai),
327342
}),
328-
Schema.object({
329-
sampler: sampler.createSchema(sampler.nai),
330-
model: Schema.union(models).loose().description('默认的生成模型。').default('nai'),
331-
}),
343+
Schema.intersect([
344+
Schema.object({
345+
model: Schema.union(models).loose().description('默认的生成模型。').default('nai'),
346+
}),
347+
Schema.union([
348+
Schema.object({
349+
model: Schema.const('nai-v3').required(),
350+
sampler: sampler.createSchema(sampler.nai3),
351+
smea: Schema.boolean().description('默认启用 SMEA。'),
352+
smeaDyn: Schema.boolean().description('默认启用 SMEA 采样器的 DYN 变体。'),
353+
scheduler: Schema.union(scheduler).description('默认的调度器。').default('native'),
354+
}),
355+
Schema.object({ sampler: sampler.createSchema(sampler.nai) }),
356+
]),
357+
Schema.object({ decrisper: Schema.boolean().description('默认启用 decrisper') }),
358+
]),
332359
] as const),
333360

334361
Schema.object({

src/index.ts

+29-8
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
import { Computed, Context, Dict, h, Logger, omit, Quester, Session, SessionError, trimSlash } from 'koishi'
2-
import { Config, modelMap, models, orientMap, parseInput, sampler, upscalers } from './config'
2+
import { Config, modelMap, models, orientMap, parseInput, sampler, upscalers, scheduler } from './config'
33
import { ImageData, StableDiffusionWebUI } from './types'
44
import { closestMultiple, download, forceDataPrefix, getImageSize, login, NetworkError, project, resizeInput, Size } from './utils'
5-
import {} from '@koishijs/translator'
6-
import {} from '@koishijs/plugin-help'
5+
import { } from '@koishijs/translator'
6+
import { } from '@koishijs/plugin-help'
77
import AdmZip from 'adm-zip'
88

99
export * from './config'
@@ -107,6 +107,10 @@ export function apply(ctx: Context, config: Config) {
107107
.option('noise', '-n <noise:number>', { hidden: some(restricted, thirdParty) })
108108
.option('strength', '-N <strength:number>', { hidden: restricted })
109109
.option('hiresFix', '-H', { hidden: () => config.type !== 'sd-webui' })
110+
.option('smea', '-S', { hidden: () => config.model !== 'nai-v3' })
111+
.option('smeaDyn', '-d', { hidden: () => config.model !== 'nai-v3' })
112+
.option('scheduler', '-C <scheduler> ', { hidden: () => config.model !== 'nai-v3', type: scheduler })
113+
.option('decrisper', '-D', { hidden: thirdParty })
110114
.option('undesired', '-u <undesired>')
111115
.option('noTranslator', '-T', { hidden: () => !ctx.translator || !config.translator })
112116
.option('iterations', '-i <iterations:posint>', { fallback: 1, hidden: () => config.maxIterations <= 1 })
@@ -306,12 +310,29 @@ export function apply(ctx: Context, config: Config) {
306310
case 'login':
307311
case 'token':
308312
case 'naifu': {
309-
parameters.sampler = sampler.sd2nai(options.sampler)
313+
parameters.sampler = sampler.sd2nai(options.sampler, model)
310314
parameters.image = image?.base64 // NovelAI / NAIFU accepts bare base64 encoded image
311315
if (config.type === 'naifu') return parameters
312316
// The latest interface changes uc to negative_prompt, so that needs to be changed here as well
313-
parameters.negative_prompt = parameters.uc
314-
delete parameters.uc
317+
if (parameters.uc) {
318+
parameters.negative_prompt = parameters.uc
319+
delete parameters.uc
320+
}
321+
parameters.dynamic_thresholding = options.decrisper ?? config.decrisper
322+
if (model === 'nai-diffusion-3') {
323+
parameters.sm_dyn = options.smeaDyn ?? config.smeaDyn
324+
parameters.sm = (options.smea ?? config.smea) || parameters.sm_dyn
325+
parameters.noise_schedule = options.scheduler ?? config.scheduler
326+
if (['k_euler_ancestral', 'k_dpmpp_2s_ancestral'].includes(parameters.sampler)
327+
&& parameters.noise_schedule === 'karras') {
328+
parameters.noise_schedule = 'native'
329+
}
330+
if (parameters.sampler === 'ddim_v3') {
331+
parameters.sm = false
332+
parameters.sm_dyn = false
333+
delete parameters.noise_schedule
334+
}
335+
}
315336
return { model, input: prompt, parameters: omit(parameters, ['prompt']) }
316337
}
317338
case 'sd-webui': {
@@ -436,7 +457,7 @@ export function apply(ctx: Context, config: Config) {
436457
const b64 = Buffer.from(firstImageBuffer).toString('base64')
437458
return forceDataPrefix(b64, 'image/png')
438459
}
439-
460+
440461
return forceDataPrefix(res.data?.slice(27))
441462
}
442463

@@ -522,7 +543,7 @@ export function apply(ctx: Context, config: Config) {
522543
case 'stable-horde':
523544
return sampler.horde
524545
default:
525-
return sampler.nai
546+
return { ...sampler.nai, ...sampler.nai3 }
526547
}
527548
}
528549

src/locales/zh-CN.yml

+4
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,10 @@ commands:
2626
noTranslator: 禁用自动翻译
2727
iterations: 设置绘制次数
2828
batch: 设置绘制批次大小
29+
smea: 启用 SMEA
30+
smeaDyn: 启用 DYN
31+
scheduler: 设置调度器
32+
decrisper: 启用动态阈值
2933

3034
messages:
3135
exceed-max-iteration: 超过最大绘制次数。

0 commit comments

Comments
 (0)