Skip to content

Commit 89277c0

Browse files
authored
Merge pull request #41 from jozic/fix-update
fix connection leaking on update and updateAndReturnGeneratedKey
2 parents a6648f3 + d960818 commit 89277c0

21 files changed

+133
-221
lines changed

build.sbt

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
lazy val _version = "0.6.0"
1+
lazy val _version = "0.6.1-SNAPSHOT"
22
lazy val scalikejdbcVersion = "2.4.2"
33
lazy val mauricioVersion = "0.2.20" // provided
44
lazy val postgresqlVersion = "9.4-1201-jdbc41"

core/src/main/scala/scalikejdbc/async/AsyncConnectionPool.scala

+14-25
Original file line numberDiff line numberDiff line change
@@ -21,23 +21,8 @@ import scalikejdbc._
2121
* Asynchronous Connection Pool
2222
*/
2323
abstract class AsyncConnectionPool(
24-
val url: String,
25-
val user: String,
26-
password: String,
2724
val settings: AsyncConnectionPoolSettings = AsyncConnectionPoolSettings()) {
2825

29-
type MauricioConfiguration = com.github.mauricio.async.db.Configuration
30-
31-
private[this] val jdbcUrl = JDBCUrl(url)
32-
33-
protected val config = new MauricioConfiguration(
34-
username = user,
35-
host = jdbcUrl.host,
36-
port = jdbcUrl.port,
37-
password = Option(password).filterNot(_.trim.isEmpty),
38-
database = Option(jdbcUrl.database).filterNot(_.trim.isEmpty)
39-
)
40-
4126
/**
4227
* Borrows a connection from pool.
4328
* @return connection
@@ -70,24 +55,23 @@ object AsyncConnectionPool extends LogSupport {
7055

7156
private[this] val pools = new ConcurrentMap[Any, AsyncConnectionPool]()
7257

73-
private[this] def ensureInitialized(name: Any): Unit = {
74-
if (!isInitialized(name)) {
58+
def isInitialized(name: Any = DEFAULT_NAME): Boolean = pools.contains(name)
59+
60+
def get(name: Any = DEFAULT_NAME): AsyncConnectionPool = {
61+
pools.getOrElse(name, {
7562
val message = ErrorMessage.CONNECTION_POOL_IS_NOT_YET_INITIALIZED + "(name:" + name + ")"
7663
throw new IllegalStateException(message)
77-
}
64+
})
7865
}
7966

80-
def isInitialized(name: Any = DEFAULT_NAME) = pools.get(name).isDefined
81-
82-
def get(name: Any = DEFAULT_NAME): AsyncConnectionPool = pools.get(name).orNull
83-
8467
def apply(name: Any = DEFAULT_NAME): AsyncConnectionPool = get(name)
8568

8669
def add(name: Any, url: String, user: String, password: String, settings: CPSettings = AsyncConnectionPoolSettings())(
8770
implicit factory: CPFactory = AsyncConnectionPoolFactory): Unit = {
8871
val newPool: AsyncConnectionPool = factory.apply(url, user, password, settings)
8972
log.debug(s"Registered connection pool (url: ${url}, user: ${user}, settings: ${settings}")
90-
pools.put(name, newPool)
73+
val replaced = pools.put(name, newPool)
74+
replaced.foreach(_.close())
9175
}
9276

9377
def singleton(url: String, user: String, password: String, settings: CPSettings = AsyncConnectionPoolSettings())(
@@ -96,12 +80,17 @@ object AsyncConnectionPool extends LogSupport {
9680
}
9781

9882
def borrow(name: Any = DEFAULT_NAME): AsyncConnection = {
99-
ensureInitialized(name)
10083
val pool = get(name)
101-
log.debug("Borrowed a new connection from " + pool.toString())
84+
log.debug(s"Borrowed a new connection from pool $name")
10285
pool.borrow()
10386
}
10487

88+
def giveBack(connection: NonSharedAsyncConnection, name: Any = DEFAULT_NAME): Unit = {
89+
val pool = get(name)
90+
log.debug(s"Gave back previously borrowed connection from pool $name")
91+
pool.giveBack(connection)
92+
}
93+
10594
def close(name: Any = DEFAULT_NAME): Unit = pools.remove(name).foreach(_.close())
10695

10796
def closeAll(): Unit = pools.keys.foreach(name => close(name))

core/src/main/scala/scalikejdbc/async/AsyncDB.scala

+1-41
Original file line numberDiff line numberDiff line change
@@ -15,47 +15,7 @@
1515
*/
1616
package scalikejdbc.async
1717

18-
import scala.concurrent._
19-
import scala.util.{ Failure, Success }
20-
import scalikejdbc.async.ShortenedNames._
21-
import scalikejdbc.async.internal.AsyncConnectionCommonImpl
22-
2318
/**
2419
* Basic Database Accessor
2520
*/
26-
object AsyncDB {
27-
28-
/**
29-
* Provides a code block which have a connection from ConnectionPool and passes it to the operation.
30-
*
31-
* @param op operation
32-
* @tparam A return type
33-
* @return a future value
34-
*/
35-
def withPool[A](op: (SharedAsyncDBSession) => Future[A]): Future[A] = {
36-
op.apply(sharedSession)
37-
}
38-
39-
/**
40-
* Provides a shared session.
41-
*
42-
* @return shared session
43-
*/
44-
def sharedSession: SharedAsyncDBSession = SharedAsyncDBSession(AsyncConnectionPool().borrow())
45-
46-
/**
47-
* Provides a future world within a transaction.
48-
*
49-
* @param op operation
50-
* @param cxt execution context
51-
* @tparam A return type
52-
* @return a future value
53-
*/
54-
def localTx[A](op: (TxAsyncDBSession) => Future[A])(implicit cxt: EC = ECGlobal): Future[A] = {
55-
AsyncConnectionPool().borrow().toNonSharedConnection()
56-
.map { nonSharedConnection => TxAsyncDBSession(nonSharedConnection) }
57-
.flatMap { tx => AsyncTx.inTransaction[A](tx, op) }
58-
}
59-
60-
}
61-
21+
object AsyncDB extends NamedAsyncDB()

core/src/main/scala/scalikejdbc/async/AsyncDBSession.scala

+14-14
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ trait AsyncDBSession extends LogSupport {
3333
withListeners(statement, _parameters) {
3434
queryLogging(statement, _parameters)
3535
connection.sendPreparedStatement(statement, _parameters: _*).map { result =>
36-
result.rowsAffected.map(_ > 0).getOrElse(false)
36+
result.rowsAffected.exists(_ > 0)
3737
}
3838
}
3939
}
@@ -44,8 +44,8 @@ trait AsyncDBSession extends LogSupport {
4444
queryLogging(statement, _parameters)
4545
if (connection.isShared) {
4646
// create local transaction because postgresql-async 0.2.4 seems not to be stable with PostgreSQL without transaction
47-
connection.toNonSharedConnection().map(c => TxAsyncDBSession(c)).flatMap { tx: TxAsyncDBSession =>
48-
tx.update(statement, _parameters: _*)
47+
connection.toNonSharedConnection().flatMap { conn =>
48+
AsyncTx.inTransaction(TxAsyncDBSession(conn), (tx: TxAsyncDBSession) => tx.update(statement, _parameters: _*))
4949
}
5050
} else {
5151
connection.sendPreparedStatement(statement, _parameters: _*).map { result =>
@@ -60,11 +60,12 @@ trait AsyncDBSession extends LogSupport {
6060
withListeners(statement, _parameters) {
6161
queryLogging(statement, _parameters)
6262
connection.toNonSharedConnection().flatMap { conn =>
63-
conn.sendPreparedStatement(statement, _parameters: _*).map { result =>
64-
result.generatedKey.getOrElse {
65-
throw new IllegalArgumentException(ErrorMessage.FAILED_TO_RETRIEVE_GENERATED_KEY + " SQL: '" + statement + "'")
66-
}
67-
}
63+
AsyncTx.inTransaction(TxAsyncDBSession(conn), (tx: TxAsyncDBSession) =>
64+
tx.connection.sendPreparedStatement(statement, _parameters: _*).flatMap { result =>
65+
result.generatedKey.map(_.getOrElse {
66+
throw new IllegalArgumentException(ErrorMessage.FAILED_TO_RETRIEVE_GENERATED_KEY + " SQL: '" + statement + "'")
67+
})
68+
})
6869
}
6970
}
7071
}
@@ -88,7 +89,7 @@ trait AsyncDBSession extends LogSupport {
8889
results match {
8990
case Nil => None
9091
case one :: Nil => Option(one)
91-
case _ => throw new TooManyRowsException(1, results.size)
92+
case _ => throw TooManyRowsException(1, results.size)
9293
}
9394
}
9495
}
@@ -106,10 +107,9 @@ trait AsyncDBSession extends LogSupport {
106107

107108
def processResultSet(oneToOne: (LinkedHashMap[A, Option[B]]), rs: WrappedResultSet): LinkedHashMap[A, Option[B]] = {
108109
val o = extractOne(rs)
109-
oneToOne.keys.find(_ == o).map {
110-
case Some(found) => throw new IllegalRelationshipException(ErrorMessage.INVALID_ONE_TO_ONE_RELATION)
111-
}.getOrElse {
112-
oneToOne += (o -> extractTo(rs))
110+
oneToOne.keys.find(_ == o) match {
111+
case Some(_) => throw IllegalRelationshipException(ErrorMessage.INVALID_ONE_TO_ONE_RELATION)
112+
case _ => oneToOne += (o -> extractTo(rs))
113113
}
114114
}
115115
connection.sendPreparedStatement(statement, _parameters: _*).map { result =>
@@ -338,7 +338,7 @@ trait AsyncDBSession extends LogSupport {
338338
protected def ensureAndNormalizeParameters(parameters: Seq[Any]): Seq[Any] = {
339339
parameters.map {
340340
case withValue: ParameterBinderWithValue[_] => withValue.value
341-
case binder: ParameterBinder => throw new IllegalArgumentException("ParameterBinder is unsupported")
341+
case _: ParameterBinder => throw new IllegalArgumentException("ParameterBinder is unsupported")
342342
case rawValue => rawValue
343343
}
344344
}

core/src/main/scala/scalikejdbc/async/AsyncQueryResult.scala

+3-1
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
*/
1616
package scalikejdbc.async
1717

18+
import scala.concurrent.Future
19+
1820
/**
1921
* Query Result
2022
*/
@@ -23,7 +25,7 @@ abstract class AsyncQueryResult(
2325
val statusMessage: Option[String],
2426
val rows: Option[AsyncResultSet]) {
2527

26-
val generatedKey: Option[Long]
28+
val generatedKey: Future[Option[Long]]
2729

2830
}
2931

core/src/main/scala/scalikejdbc/async/AsyncTxQuery.scala

+1-4
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,7 @@
1616
package scalikejdbc.async
1717

1818
import scalikejdbc._
19-
import scalikejdbc.async.internal.AsyncConnectionCommonImpl
2019
import scala.concurrent._
21-
import scala.util._
2220
import scalikejdbc.async.ShortenedNames._
2321

2422
/**
@@ -36,8 +34,7 @@ class AsyncTxQuery(sqls: Seq[SQL[_, _]]) {
3634
}
3735
}
3836
session.connection.toNonSharedConnection
39-
.map(conn => TxAsyncDBSession(conn))
40-
.flatMap { tx => AsyncTx.inTransaction(tx, op) }
37+
.flatMap(conn => AsyncTx.inTransaction(TxAsyncDBSession(conn), op))
4138
}
4239

4340
}

core/src/main/scala/scalikejdbc/async/NamedAsyncDB.scala

+4-9
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,7 @@
1515
*/
1616
package scalikejdbc.async
1717

18-
import scalikejdbc.async.internal.AsyncConnectionCommonImpl
19-
2018
import scala.concurrent._
21-
import scala.util.{ Failure, Success }
2219
import scalikejdbc.async.ShortenedNames._
2320

2421
/**
@@ -29,12 +26,12 @@ case class NamedAsyncDB(name: Any = AsyncConnectionPool.DEFAULT_NAME) {
2926
/**
3027
* Provides a code block which have a connection from ConnectionPool and passes it to the operation.
3128
*
32-
* @param f operation
29+
* @param op operation
3330
* @tparam A return type
3431
* @return a Future value
3532
*/
36-
def withPool[A](f: (SharedAsyncDBSession) => Future[A]): Future[A] = {
37-
f.apply(sharedSession)
33+
def withPool[A](op: (SharedAsyncDBSession) => Future[A]): Future[A] = {
34+
op.apply(sharedSession)
3835
}
3936

4037
/**
@@ -54,9 +51,7 @@ case class NamedAsyncDB(name: Any = AsyncConnectionPool.DEFAULT_NAME) {
5451
*/
5552
def localTx[A](op: (TxAsyncDBSession) => Future[A])(implicit cxt: EC = ECGlobal): Future[A] = {
5653
AsyncConnectionPool(name).borrow().toNonSharedConnection()
57-
.map { txConn => TxAsyncDBSession(txConn) }
58-
.flatMap { tx => AsyncTx.inTransaction[A](tx, op) }
54+
.flatMap(conn => AsyncTx.inTransaction[A](TxAsyncDBSession(conn), op))
5955
}
6056

6157
}
62-

core/src/main/scala/scalikejdbc/async/internal/AsyncConnectionCommonImpl.scala

+7-1
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,12 @@ private[scalikejdbc] trait AsyncConnectionCommonImpl extends AsyncConnection {
6868
* @param cxt execution context
6969
* @return optional generated key
7070
*/
71-
protected def extractGeneratedKey(queryResult: QueryResult)(implicit cxt: EC = ECGlobal): Option[Long]
71+
protected def extractGeneratedKey(queryResult: QueryResult)(implicit cxt: EC = ECGlobal): Future[Option[Long]]
72+
73+
protected def ensureNonShared(): Unit = {
74+
if (!this.isInstanceOf[NonSharedAsyncConnection]) {
75+
throw new IllegalStateException("This asynchronous connection must be a non-shared connection.")
76+
}
77+
}
7278

7379
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
package scalikejdbc.async.internal
2+
3+
import com.github.mauricio.async.db.{Configuration, Connection}
4+
import com.github.mauricio.async.db.pool.{ConnectionPool, ObjectFactory, PoolConfiguration}
5+
import scalikejdbc.LogSupport
6+
import scalikejdbc.async.{AsyncConnectionPool, AsyncConnectionPoolSettings, NonSharedAsyncConnection}
7+
8+
abstract class AsyncConnectionPoolCommonImpl[T <: Connection](
9+
url: String,
10+
user: String,
11+
password: String,
12+
factoryF: Configuration => ObjectFactory[T],
13+
settings: AsyncConnectionPoolSettings = AsyncConnectionPoolSettings()
14+
) extends AsyncConnectionPool(settings) with MauricioConfiguration with LogSupport {
15+
16+
private[this] val factory = factoryF(configuration(url, user, password))
17+
private[internal] val pool = new ConnectionPool[T](
18+
factory = factory,
19+
configuration = PoolConfiguration(
20+
maxObjects = settings.maxPoolSize,
21+
maxIdle = settings.maxIdleMillis,
22+
maxQueueSize = settings.maxQueueSize)
23+
)
24+
25+
override def close(): Unit = pool.disconnect
26+
27+
override def giveBack(conn: NonSharedAsyncConnection): Unit = conn match {
28+
case conn: NonSharedAsyncConnectionImpl => pool.giveBack(conn.underlying.asInstanceOf[T])
29+
case _ => log.debug("You don't need to give back this connection to the pool.")
30+
}
31+
}

core/src/main/scala/scalikejdbc/async/internal/MauricioConfiguration.scala

+7-11
Original file line numberDiff line numberDiff line change
@@ -15,26 +15,22 @@
1515
*/
1616
package scalikejdbc.async.internal
1717

18+
import com.github.mauricio.async.db.Configuration
1819
import scalikejdbc.JDBCUrl
19-
import scalikejdbc.async._
2020

2121
/**
2222
* Configuration attribute
2323
*/
24-
private[scalikejdbc] trait MauricioConfiguration { self: AsyncConnection =>
24+
private[scalikejdbc] trait MauricioConfiguration {
2525

26-
val url: String
27-
val user: String
28-
val password: String
29-
30-
private[scalikejdbc] val configuration = {
26+
private[scalikejdbc] def configuration(url: String, user: String, password: String) = {
3127
val jdbcUrl = JDBCUrl(url)
32-
com.github.mauricio.async.db.Configuration(
28+
Configuration(
29+
username = user,
3330
host = jdbcUrl.host,
3431
port = jdbcUrl.port,
35-
database = Option(jdbcUrl.database),
36-
username = user,
37-
password = Option(password)
32+
password = Option(password).filterNot(_.trim.isEmpty),
33+
database = Option(jdbcUrl.database).filterNot(_.trim.isEmpty)
3834
)
3935
}
4036

core/src/main/scala/scalikejdbc/async/internal/NonSharedAsyncConnectionImpl.scala

+2-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@ abstract class NonSharedAsyncConnectionImpl(
1717
extends AsyncConnectionCommonImpl
1818
with NonSharedAsyncConnection {
1919

20-
override def toNonSharedConnection()(implicit cxt: EC = ECGlobal): Future[NonSharedAsyncConnection] = Future(this)
20+
override def toNonSharedConnection()(implicit cxt: EC = ECGlobal): Future[NonSharedAsyncConnection] =
21+
Future.successful(this)
2122

2223
override def release(): Unit = pool.map(_.giveBack(this.underlying))
2324

core/src/main/scala/scalikejdbc/async/internal/PoolableAsyncConnection.scala

+3-5
Original file line numberDiff line numberDiff line change
@@ -27,13 +27,11 @@ import scalikejdbc.async.ShortenedNames._
2727
* @param pool connection pool
2828
* @tparam T Connection sub type
2929
*/
30-
private[scalikejdbc] abstract class PoolableAsyncConnection[T <: com.github.mauricio.async.db.Connection](val pool: ConnectionPool[T])
31-
extends AsyncConnectionCommonImpl
32-
with AsyncConnection {
30+
private[scalikejdbc] abstract class PoolableAsyncConnection[T <: Connection](val pool: ConnectionPool[T])
31+
extends AsyncConnectionCommonImpl {
3332

34-
override def toNonSharedConnection()(implicit cxt: EC = ECGlobal): Future[NonSharedAsyncConnection] = {
33+
override def toNonSharedConnection()(implicit cxt: EC = ECGlobal): Future[NonSharedAsyncConnection] =
3534
Future.failed(new UnsupportedOperationException)
36-
}
3735

3836
private[scalikejdbc] val underlying: Connection = pool
3937

0 commit comments

Comments
 (0)