Correctly cancel child jobs in suspending variant of binding
Closes #46
This commit is contained in:
parent
a5c153477b
commit
f2bd9aaa11
@ -5,27 +5,31 @@ import com.github.michaelbull.result.Ok
|
||||
import com.github.michaelbull.result.Result
|
||||
import kotlinx.coroutines.CancellationException
|
||||
import kotlinx.coroutines.CoroutineScope
|
||||
import kotlinx.coroutines.Job
|
||||
import kotlinx.coroutines.cancel
|
||||
import kotlinx.coroutines.coroutineScope
|
||||
import kotlinx.coroutines.sync.Mutex
|
||||
import kotlinx.coroutines.sync.withLock
|
||||
import kotlin.contracts.InvocationKind
|
||||
import kotlin.contracts.contract
|
||||
import kotlin.coroutines.CoroutineContext
|
||||
|
||||
/**
|
||||
* Suspending variant of [binding][com.github.michaelbull.result.binding].
|
||||
* Wraps the suspendable block in a new coroutine scope.
|
||||
* This scope is cancelled once a failing bind is encountered, allowing deferred child jobs to be eagerly cancelled.
|
||||
* The suspendable [block] runs in a new [CoroutineScope], inheriting the parent [CoroutineContext].
|
||||
* This new scope is [cancelled][CoroutineScope.cancel] once a failing bind is encountered, eagerly cancelling all
|
||||
* child [jobs][Job].
|
||||
*/
|
||||
public suspend inline fun <V, E> binding(crossinline block: suspend SuspendableResultBinding<E>.() -> V): Result<V, E> {
|
||||
contract {
|
||||
callsInPlace(block, InvocationKind.EXACTLY_ONCE)
|
||||
}
|
||||
val receiver = SuspendableResultBindingImpl<E>()
|
||||
|
||||
lateinit var receiver: SuspendableResultBindingImpl<E>
|
||||
|
||||
return try {
|
||||
coroutineScope {
|
||||
receiver.coroutineScope = this@coroutineScope
|
||||
receiver = SuspendableResultBindingImpl(this.coroutineContext)
|
||||
with(receiver) { Ok(block()) }
|
||||
}
|
||||
} catch (ex: BindCancellationException) {
|
||||
@ -35,27 +39,27 @@ public suspend inline fun <V, E> binding(crossinline block: suspend SuspendableR
|
||||
|
||||
internal object BindCancellationException : CancellationException(null)
|
||||
|
||||
public interface SuspendableResultBinding<E> {
|
||||
public interface SuspendableResultBinding<E> : CoroutineScope {
|
||||
public suspend fun <V> Result<V, E>.bind(): V
|
||||
}
|
||||
|
||||
@PublishedApi
|
||||
internal class SuspendableResultBindingImpl<E> : SuspendableResultBinding<E> {
|
||||
internal class SuspendableResultBindingImpl<E>(
|
||||
override val coroutineContext: CoroutineContext
|
||||
) : SuspendableResultBinding<E> {
|
||||
|
||||
private val mutex = Mutex()
|
||||
lateinit var internalError: Err<E>
|
||||
var coroutineScope: CoroutineScope? = null
|
||||
|
||||
override suspend fun <V> Result<V, E>.bind(): V {
|
||||
return when (this) {
|
||||
is Ok -> value
|
||||
is Err -> {
|
||||
mutex.withLock {
|
||||
is Err -> mutex.withLock {
|
||||
if (::internalError.isInitialized.not()) {
|
||||
internalError = this
|
||||
this@SuspendableResultBindingImpl.cancel(BindCancellationException)
|
||||
}
|
||||
}
|
||||
coroutineScope?.cancel(BindCancellationException)
|
||||
|
||||
throw BindCancellationException
|
||||
}
|
||||
}
|
||||
|
@ -3,9 +3,14 @@ package com.github.michaelbull.result.coroutines.binding
|
||||
import com.github.michaelbull.result.Err
|
||||
import com.github.michaelbull.result.Ok
|
||||
import com.github.michaelbull.result.Result
|
||||
import kotlinx.coroutines.CoroutineName
|
||||
import kotlinx.coroutines.asCoroutineDispatcher
|
||||
import kotlinx.coroutines.async
|
||||
import kotlinx.coroutines.delay
|
||||
import kotlinx.coroutines.launch
|
||||
import kotlinx.coroutines.runBlocking
|
||||
import java.util.concurrent.Executors
|
||||
import kotlin.coroutines.CoroutineContext
|
||||
import kotlin.test.Test
|
||||
import kotlin.test.assertEquals
|
||||
import kotlin.test.assertFalse
|
||||
@ -53,12 +58,12 @@ class AsyncSuspendableBindingTest {
|
||||
}
|
||||
|
||||
suspend fun provideY(): Result<Int, BindingError.BindingErrorA> {
|
||||
delay(2)
|
||||
delay(3)
|
||||
return Err(BindingError.BindingErrorA)
|
||||
}
|
||||
|
||||
suspend fun provideZ(): Result<Int, BindingError.BindingErrorB> {
|
||||
delay(1)
|
||||
delay(2)
|
||||
return Err(BindingError.BindingErrorB)
|
||||
}
|
||||
|
||||
@ -82,41 +87,82 @@ class AsyncSuspendableBindingTest {
|
||||
fun returnsStateChangedForOnlyTheFirstAsyncBindFailWhenEagerlyCancellingBinding() {
|
||||
var xStateChange = false
|
||||
var yStateChange = false
|
||||
var zStateChange = false
|
||||
suspend fun provideX(): Result<Int, BindingError> {
|
||||
delay(20)
|
||||
xStateChange = true
|
||||
return Ok(1)
|
||||
}
|
||||
|
||||
suspend fun provideY(): Result<Int, BindingError.BindingErrorA> {
|
||||
delay(10)
|
||||
yStateChange = true
|
||||
suspend fun provideX(): Result<Int, BindingError> {
|
||||
delay(2)
|
||||
xStateChange = true
|
||||
return Err(BindingError.BindingErrorA)
|
||||
}
|
||||
|
||||
suspend fun provideZ(): Result<Int, BindingError.BindingErrorB> {
|
||||
delay(1)
|
||||
zStateChange = true
|
||||
suspend fun provideY(): Result<Int, BindingError.BindingErrorB> {
|
||||
// as this test uses a new thread for each coroutine, we want to set this delay to a high enough number that
|
||||
// there isn't any chance of a jvm run actually completing this suspending function in this thread first
|
||||
// otherwise the assertions might fail.
|
||||
delay(100)
|
||||
yStateChange = true
|
||||
return Err(BindingError.BindingErrorB)
|
||||
}
|
||||
|
||||
runBlocking {
|
||||
val result = binding<Int, BindingError> {
|
||||
val x = async { provideX().bind() }
|
||||
val y = async { provideY().bind() }
|
||||
val z = async { provideZ().bind() }
|
||||
x.await() + y.await() + z.await()
|
||||
val x = async(newThread("ThreadA")) { provideX().bind() }
|
||||
val y = async(newThread("ThreadB")) { provideY().bind() }
|
||||
x.await() + y.await()
|
||||
}
|
||||
|
||||
assertTrue(result is Err)
|
||||
assertEquals(
|
||||
expected = BindingError.BindingErrorB,
|
||||
expected = BindingError.BindingErrorA,
|
||||
actual = result.error
|
||||
)
|
||||
assertFalse(xStateChange)
|
||||
assertTrue(xStateChange)
|
||||
assertFalse(yStateChange)
|
||||
assertTrue(zStateChange)
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
fun returnsStateChangedForOnlyTheFirstLaunchBindFailWhenEagerlyCancellingBinding() {
|
||||
var xStateChange = false
|
||||
var yStateChange = false
|
||||
var zStateChange = false
|
||||
|
||||
suspend fun provideX(): Result<Int, BindingError> {
|
||||
delay(1)
|
||||
xStateChange = true
|
||||
return Ok(1)
|
||||
}
|
||||
|
||||
suspend fun provideY(): Result<Int, BindingError.BindingErrorA> {
|
||||
delay(20)
|
||||
yStateChange = true
|
||||
return Err(BindingError.BindingErrorA)
|
||||
}
|
||||
|
||||
suspend fun provideZ(): Result<Int, BindingError.BindingErrorB> {
|
||||
delay(100)
|
||||
zStateChange = true
|
||||
return Err(BindingError.BindingErrorB)
|
||||
}
|
||||
|
||||
runBlocking {
|
||||
val result = binding<Unit, BindingError> {
|
||||
launch(newThread("Thread A")) { provideX().bind() }
|
||||
launch(newThread("Thread B")) { provideY().bind() }
|
||||
launch(newThread("Thread C")) { provideZ().bind() }
|
||||
}
|
||||
|
||||
assertTrue(result is Err)
|
||||
assertEquals(
|
||||
expected = BindingError.BindingErrorA,
|
||||
actual = result.error
|
||||
)
|
||||
assertTrue(xStateChange)
|
||||
assertTrue(yStateChange)
|
||||
assertFalse(zStateChange)
|
||||
}
|
||||
}
|
||||
|
||||
private fun newThread(name: String): CoroutineContext {
|
||||
return Executors.newSingleThreadExecutor().asCoroutineDispatcher() + CoroutineName(name)
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user