package actions

import (
	"encoding/json"
	"log"
	"testing"

	"github.com/aws/aws-sdk-go-v2/aws"
	"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
	"github.com/awsdocs/aws-doc-sdk-examples/gov2/bedrock-runtime/stubs"
	"github.com/awsdocs/aws-doc-sdk-examples/gov2/testtools"
)

const CLAUDE_MODEL_ID = "anthropic.claude-v2"
@@ -24,7 +24,7 @@ const TITAN_IMAGE_MODEL_ID = "amazon.titan-image-generator-v1"
const prompt = "A test prompt"
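
// CallInvokeModelActions invokes each InvokeModel wrapper action with a test
// prompt and logs the completion returned by each model. Errors are raised as
// panics, which the deferred recover catches and logs.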
func CallInvokeModelActions(sdkConfig aws.Config) {
	defer func() {
		if r := recover(); r != nil {
			log.Println(r)
		}
@@ -34,110 +34,117 @@ func CallInvokeModelActions(sdkConfig aws.Config) {
	wrapper := InvokeModelWrapper{client}

	claudeCompletion, err := wrapper.InvokeClaude(prompt)
	if err != nil {
		panic(err)
	}
	log.Println(claudeCompletion)

	jurassic2Completion, err := wrapper.InvokeJurassic2(prompt)
	if err != nil {
		panic(err)
	}
	log.Println(jurassic2Completion)

	llama2Completion, err := wrapper.InvokeLlama2(prompt)
	if err != nil {
		panic(err)
	}
	log.Println(llama2Completion)

	seed := int64(0)
	titanImageCompletion, err := wrapper.InvokeTitanImage(prompt, seed)
	if err != nil {
		panic(err)
	}
	log.Println(titanImageCompletion)

	log.Printf("Thanks for watching!")
}

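// TestInvokeModels runs CallInvokeModelActions as a scenario test, using the
// stubs returned by SetupDataAndStubs so that no real Amazon Bedrock Runtime
// calls are made. Run it with `go test` from this package directory.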
func TestInvokeModels(t *testing.T) {
	scenTest := InvokeModelActionsTest{}
	testtools.RunScenarioTests(&scenTest, t)
}

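// InvokeModelActionsTest implements the scenario test interface used by
// testtools.RunScenarioTests: it supplies request/response stubs, runs the
// action code against the stubbed SDK configuration, and cleans up afterwards.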
type InvokeModelActionsTest struct{}

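// SetupDataAndStubs returns one InvokeModel stub per model ID so that every
// wrapper call made by CallInvokeModelActions has a matching canned response.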
func (scenTest *InvokeModelActionsTest) SetupDataAndStubs() []testtools.Stub {
	var stubList []testtools.Stub
	stubList = append(stubList, stubInvokeModel(CLAUDE_MODEL_ID))
	stubList = append(stubList, stubInvokeModel(JURASSIC2_MODEL_ID))
	stubList = append(stubList, stubInvokeModel(LLAMA2_MODEL_ID))
	stubList = append(stubList, stubInvokeModel(TITAN_IMAGE_MODEL_ID))
	return stubList
}

func (scenTest *InvokeModelActionsTest) RunSubTest(stubber *testtools.AwsmStubber) {
	CallInvokeModelActions(*stubber.SdkConfig)
}

func (scenTest *InvokeModelActionsTest) Cleanup() {}

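// stubInvokeModel builds the testtools.Stub for a single InvokeModel call,
// pairing the JSON request body expected for the given model ID with a fake
// response body. Unrecognized model IDs return an empty stub.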
func stubInvokeModel(modelId string) testtools.Stub {
	var request []byte
	var response []byte

	switch modelId {
	case CLAUDE_MODEL_ID:
		request, _ = json.Marshal(ClaudeRequest{
			Prompt:            "Human: " + prompt + "\n\nAssistant:",
			MaxTokensToSample: 200,
			Temperature:       0.5,
			StopSequences:     []string{"\n\nHuman:"},
		})
		response, _ = json.Marshal(ClaudeResponse{
			Completion: "A fake response",
		})

	case JURASSIC2_MODEL_ID:
		request, _ = json.Marshal(Jurassic2Request{
			Prompt:      prompt,
			MaxTokens:   200,
			Temperature: 0.5,
		})
		response, _ = json.Marshal(Jurassic2Response{
			Completions: []Completion{
				{Data: Data{Text: "A fake response"}},
			},
		})

	case LLAMA2_MODEL_ID:
		request, _ = json.Marshal(Llama2Request{
			Prompt:       prompt,
			MaxGenLength: 512,
			Temperature:  0.5,
		})
		response, _ = json.Marshal(Llama2Response{
			Generation: "A fake response",
		})

	case TITAN_IMAGE_MODEL_ID:
		request, _ = json.Marshal(TitanImageRequest{
			TaskType: "TEXT_IMAGE",
			TextToImageParams: TextToImageParams{
				Text: prompt,
			},
			ImageGenerationConfig: ImageGenerationConfig{
				NumberOfImages: 1,
				Quality:        "standard",
				CfgScale:       8.0,
				Height:         512,
				Width:          512,
				Seed:           0,
			},
		})
		response, _ = json.Marshal(TitanImageResponse{
			Images: []string{"FakeBase64String=="},
		})

	default:
		return testtools.Stub{}
	}

	return stubs.StubInvokeModel(stubs.StubInvokeModelParams{
		Request: request, Response: response, ModelId: modelId, RaiseErr: nil,
	})
}