Skip to content

Commit 0bbb393

Browse files
committed
Add context support
This commit contains an API breaking change. LState.NewThread now also returns a child context.Context.
1 parent 33ebc07 commit 0bbb393

9 files changed

+279
-14
lines changed

README.rst

+66
Original file line numberDiff line numberDiff line change
@@ -459,6 +459,72 @@ You can extend GopherLua with new types written in Go.
459459
}
460460
}
461461
462+
+++++++++++++++++++++++++++++++++++++++++
463+
Terminating a running LState
464+
+++++++++++++++++++++++++++++++++++++++++
465+
GopherLua supports the `Go Concurrency Patterns: Context <https://blog.golang.org/context>`_ .
466+
467+
468+
.. code-block:: go
469+
470+
L := lua.NewState()
471+
defer L.Close()
472+
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
473+
defer cancel()
474+
// set the context to our LState
475+
L.SetContext(ctx)
476+
err := L.DoString(`
477+
local clock = os.clock
478+
function sleep(n) -- seconds
479+
local t0 = clock()
480+
while clock() - t0 <= n do end
481+
end
482+
sleep(3)
483+
`)
484+
// err.Error() contains "context deadline exceeded"
485+
486+
With coroutines
487+
488+
.. code-block:: go
489+
490+
L := lua.NewState()
491+
defer L.Close()
492+
ctx, cancel := context.WithCancel(context.Background())
493+
L.SetContext(ctx)
494+
defer cancel()
495+
L.DoString(`
496+
function coro()
497+
local i = 0
498+
while true do
499+
coroutine.yield(i)
500+
i = i+1
501+
end
502+
return i
503+
end
504+
`)
505+
co, cocancel := L.NewThread()
506+
defer cocancel()
507+
fn := L.GetGlobal("coro").(*LFunction)
508+
509+
_, err, values := L.Resume(co, fn) // err is nil
510+
511+
cancel() // cancel the parent context
512+
513+
_, err, values = L.Resume(co, fn) // err is NOT nil : child context was canceled
514+
515+
**Note that using a context causes performance degradation.**
516+
517+
.. code-block::
518+
519+
time ./glua-with-context.exe fib.lua
520+
9227465
521+
0.01s user 0.11s system 1% cpu 7.505 total
522+
523+
time ./glua-without-context.exe fib.lua
524+
9227465
525+
0.01s user 0.01s system 0% cpu 5.306 total
526+
527+
462528
+++++++++++++++++++++++++++++++++++++++++
463529
Goroutines
464530
+++++++++++++++++++++++++++++++++++++++++

_state.go

+33-4
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package lua
33
import (
44
"fmt"
55
"github.com/yuin/gopher-lua/parse"
6+
"golang.org/x/net/context"
67
"io"
78
"math"
89
"os"
@@ -331,6 +332,8 @@ func newLState(options Options) *LState {
331332
wrapped: false,
332333
uvcache: nil,
333334
hasErrorFunc: false,
335+
mainLoop: mainLoop,
336+
ctx: nil,
334337
}
335338
ls.Env = ls.G.Global
336339
return ls
@@ -783,9 +786,9 @@ func (ls *LState) callR(nargs, nret, rbase int) {
783786
if ls.G.MainThread == nil {
784787
ls.G.MainThread = ls
785788
ls.G.CurrentThread = ls
786-
mainLoop(ls, nil)
789+
ls.mainLoop(ls, nil)
787790
} else {
788-
mainLoop(ls, ls.currentFrame)
791+
ls.mainLoop(ls, ls.currentFrame)
789792
}
790793
if nret != MultRet {
791794
ls.reg.SetTop(rbase + nret)
@@ -1115,11 +1118,18 @@ func (ls *LState) CreateTable(acap, hcap int) *LTable {
11151118
return newLTable(acap, hcap)
11161119
}
11171120

1118-
func (ls *LState) NewThread() *LState {
1121+
// NewThread returns a new LState that shares with the original state all global objects.
1122+
// If the original state has context.Context, the new state has a new child context of the original state and this function returns its cancel function.
1123+
func (ls *LState) NewThread() (*LState, context.CancelFunc) {
11191124
thread := newLState(ls.Options)
11201125
thread.G = ls.G
11211126
thread.Env = ls.Env
1122-
return thread
1127+
var f context.CancelFunc = nil
1128+
if ls.ctx != nil {
1129+
thread.mainLoop = mainLoopWithContext
1130+
thread.ctx, f = context.WithCancel(ls.ctx)
1131+
}
1132+
return thread, f
11231133
}
11241134

11251135
func (ls *LState) NewUserData() *LUserData {
@@ -1742,6 +1752,25 @@ func (ls *LState) SetMx(mx int) {
17421752
}()
17431753
}
17441754

1755+
// SetContext set a context ctx to this LState. The provided ctx must be non-nil.
1756+
func (ls *LState) SetContext(ctx context.Context) {
1757+
ls.mainLoop = mainLoopWithContext
1758+
ls.ctx = ctx
1759+
}
1760+
1761+
// Context returns the LState's context. To change the context, use WithContext.
1762+
func (ls *LState) Context() context.Context {
1763+
return ls.ctx
1764+
}
1765+
1766+
// RemoveContext removes the context associated with this LState and returns this context.
1767+
func (ls *LState) RemoveContext() context.Context {
1768+
oldctx := ls.ctx
1769+
ls.mainLoop = mainLoop
1770+
ls.ctx = nil
1771+
return oldctx
1772+
}
1773+
17451774
// Converts the Lua value at the given acceptable index to the chan LValue.
17461775
func (ls *LState) ToChannel(n int) chan LValue {
17471776
if lv, ok := ls.Get(n).(LChannel); ok {

_vm.go

+31-1
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,36 @@ func mainLoop(L *LState, baseframe *callFrame) {
3030
}
3131
}
3232

33+
func mainLoopWithContext(L *LState, baseframe *callFrame) {
34+
var inst uint32
35+
var cf *callFrame
36+
37+
if L.stack.IsEmpty() {
38+
return
39+
}
40+
41+
L.currentFrame = L.stack.Last()
42+
if L.currentFrame.Fn.IsG {
43+
callGFunction(L, false)
44+
return
45+
}
46+
47+
for {
48+
cf = L.currentFrame
49+
inst = cf.Fn.Proto.Code[cf.Pc]
50+
cf.Pc++
51+
select {
52+
case <-L.ctx.Done():
53+
L.RaiseError(L.ctx.Err().Error())
54+
return
55+
default:
56+
if jumpTable[int(inst>>26)](L, inst, baseframe) == 1 {
57+
return
58+
}
59+
}
60+
}
61+
}
62+
3363
func copyReturnValues(L *LState, regv, start, n, b int) { // +inline-start
3464
if b == 1 {
3565
// +inline-call L.reg.FillNil regv n
@@ -118,7 +148,7 @@ func threadRun(L *LState) {
118148
}
119149
}
120150
}()
121-
mainLoop(L, nil)
151+
L.mainLoop(L, nil)
122152
}
123153

124154
type instFunc func(*LState, uint32, *callFrame) int

auxlib_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ func TestCheckThread(t *testing.T) {
109109
L := NewState()
110110
defer L.Close()
111111
errorIfGFuncNotFail(t, L, func(L *LState) int {
112-
th := L.NewThread()
112+
th, _ := L.NewThread()
113113
L.Push(th)
114114
errorIfNotEqual(t, th, L.CheckThread(2))
115115
L.Push(LNumber(10))

coroutinelib.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ var coFuncs = map[string]LGFunction{
1818

1919
func coCreate(L *LState) int {
2020
fn := L.CheckFunction(1)
21-
newthread := L.NewThread()
21+
newthread, _ := L.NewThread()
2222
base := 0
2323
newthread.stack.Push(callFrame{
2424
Fn: fn,

state.go

+33-4
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ package lua
77
import (
88
"fmt"
99
"github.com/yuin/gopher-lua/parse"
10+
"golang.org/x/net/context"
1011
"io"
1112
"math"
1213
"os"
@@ -335,6 +336,8 @@ func newLState(options Options) *LState {
335336
wrapped: false,
336337
uvcache: nil,
337338
hasErrorFunc: false,
339+
mainLoop: mainLoop,
340+
ctx: nil,
338341
}
339342
ls.Env = ls.G.Global
340343
return ls
@@ -868,9 +871,9 @@ func (ls *LState) callR(nargs, nret, rbase int) {
868871
if ls.G.MainThread == nil {
869872
ls.G.MainThread = ls
870873
ls.G.CurrentThread = ls
871-
mainLoop(ls, nil)
874+
ls.mainLoop(ls, nil)
872875
} else {
873-
mainLoop(ls, ls.currentFrame)
876+
ls.mainLoop(ls, ls.currentFrame)
874877
}
875878
if nret != MultRet {
876879
ls.reg.SetTop(rbase + nret)
@@ -1200,11 +1203,18 @@ func (ls *LState) CreateTable(acap, hcap int) *LTable {
12001203
return newLTable(acap, hcap)
12011204
}
12021205

1203-
func (ls *LState) NewThread() *LState {
1206+
// NewThread returns a new LState that shares with the original state all global objects.
1207+
// If the original state has context.Context, the new state has a new child context of the original state and this function returns its cancel function.
1208+
func (ls *LState) NewThread() (*LState, context.CancelFunc) {
12041209
thread := newLState(ls.Options)
12051210
thread.G = ls.G
12061211
thread.Env = ls.Env
1207-
return thread
1212+
var f context.CancelFunc = nil
1213+
if ls.ctx != nil {
1214+
thread.mainLoop = mainLoopWithContext
1215+
thread.ctx, f = context.WithCancel(ls.ctx)
1216+
}
1217+
return thread, f
12081218
}
12091219

12101220
func (ls *LState) NewUserData() *LUserData {
@@ -1827,6 +1837,25 @@ func (ls *LState) SetMx(mx int) {
18271837
}()
18281838
}
18291839

1840+
// SetContext set a context ctx to this LState. The provided ctx must be non-nil.
1841+
func (ls *LState) SetContext(ctx context.Context) {
1842+
ls.mainLoop = mainLoopWithContext
1843+
ls.ctx = ctx
1844+
}
1845+
1846+
// Context returns the LState's context. To change the context, use WithContext.
1847+
func (ls *LState) Context() context.Context {
1848+
return ls.ctx
1849+
}
1850+
1851+
// RemoveContext removes the context associated with this LState and returns this context.
1852+
func (ls *LState) RemoveContext() context.Context {
1853+
oldctx := ls.ctx
1854+
ls.mainLoop = mainLoop
1855+
ls.ctx = nil
1856+
return oldctx
1857+
}
1858+
18301859
// Converts the Lua value at the given acceptable index to the chan LValue.
18311860
func (ls *LState) ToChannel(n int) chan LValue {
18321861
if lv, ok := ls.Get(n).(LChannel); ok {

state_test.go

+80-2
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
package lua
22

33
import (
4+
"golang.org/x/net/context"
45
"strings"
56
"testing"
7+
"time"
68
)
79

810
func TestCallStackOverflow(t *testing.T) {
@@ -265,7 +267,7 @@ func TestPCall(t *testing.T) {
265267
func TestCoroutineApi1(t *testing.T) {
266268
L := NewState()
267269
defer L.Close()
268-
co := L.NewThread()
270+
co, _ := L.NewThread()
269271
errorIfScriptFail(t, L, `
270272
function coro(v)
271273
assert(v == 10)
@@ -308,7 +310,7 @@ func TestCoroutineApi1(t *testing.T) {
308310
end
309311
`)
310312
fn = L.GetGlobal("coro_error").(*LFunction)
311-
co = L.NewThread()
313+
co, _ = L.NewThread()
312314
st, err, values = L.Resume(co, fn)
313315
errorIfNotEqual(t, ResumeYield, st)
314316
errorIfNotNil(t, err)
@@ -333,3 +335,79 @@ func TestCoroutineApi1(t *testing.T) {
333335
errorIfFalse(t, strings.Contains(err.Error(), "can not resume a dead thread"), "can not resume a dead thread")
334336

335337
}
338+
339+
func TestContextTimeout(t *testing.T) {
340+
L := NewState()
341+
defer L.Close()
342+
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
343+
defer cancel()
344+
L.SetContext(ctx)
345+
errorIfNotEqual(t, ctx, L.Context())
346+
err := L.DoString(`
347+
local clock = os.clock
348+
function sleep(n) -- seconds
349+
local t0 = clock()
350+
while clock() - t0 <= n do end
351+
end
352+
sleep(3)
353+
`)
354+
errorIfNil(t, err)
355+
errorIfFalse(t, strings.Contains(err.Error(), "context deadline exceeded"), "execution must be canceled")
356+
357+
oldctx := L.RemoveContext()
358+
errorIfNotEqual(t, ctx, oldctx)
359+
errorIfNotNil(t, L.ctx)
360+
}
361+
362+
func TestContextCancel(t *testing.T) {
363+
L := NewState()
364+
defer L.Close()
365+
ctx, cancel := context.WithCancel(context.Background())
366+
errch := make(chan error, 1)
367+
L.SetContext(ctx)
368+
go func() {
369+
errch <- L.DoString(`
370+
local clock = os.clock
371+
function sleep(n) -- seconds
372+
local t0 = clock()
373+
while clock() - t0 <= n do end
374+
end
375+
sleep(3)
376+
`)
377+
}()
378+
time.Sleep(1 * time.Second)
379+
cancel()
380+
err := <-errch
381+
errorIfNil(t, err)
382+
errorIfFalse(t, strings.Contains(err.Error(), "context canceled"), "execution must be canceled")
383+
}
384+
385+
func TestContextWithCroutine(t *testing.T) {
386+
L := NewState()
387+
defer L.Close()
388+
ctx, cancel := context.WithCancel(context.Background())
389+
L.SetContext(ctx)
390+
defer cancel()
391+
L.DoString(`
392+
function coro()
393+
local i = 0
394+
while true do
395+
coroutine.yield(i)
396+
i = i+1
397+
end
398+
return i
399+
end
400+
`)
401+
co, cocancel := L.NewThread()
402+
defer cocancel()
403+
fn := L.GetGlobal("coro").(*LFunction)
404+
_, err, values := L.Resume(co, fn)
405+
errorIfNotNil(t, err)
406+
errorIfNotEqual(t, LNumber(0), values[0])
407+
// cancel the parent context
408+
cancel()
409+
_, err, values = L.Resume(co, fn)
410+
errorIfNil(t, err)
411+
errorIfFalse(t, strings.Contains(err.Error(), "context canceled"), "coroutine execution must be canceled when the parent context is canceled")
412+
413+
}

0 commit comments

Comments
 (0)