Correctly cancel child jobs in suspending variant of binding

Closes #46
This commit is contained in:
Tristan Hamilton 2021-05-05 14:30:41 +01:00 committed by Michael Bull
parent a5c153477b
commit f2bd9aaa11
2 changed files with 84 additions and 34 deletions

View File

@ -5,27 +5,31 @@ import com.github.michaelbull.result.Ok
import com.github.michaelbull.result.Result import com.github.michaelbull.result.Result
import kotlinx.coroutines.CancellationException import kotlinx.coroutines.CancellationException
import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Job
import kotlinx.coroutines.cancel import kotlinx.coroutines.cancel
import kotlinx.coroutines.coroutineScope import kotlinx.coroutines.coroutineScope
import kotlinx.coroutines.sync.Mutex import kotlinx.coroutines.sync.Mutex
import kotlinx.coroutines.sync.withLock import kotlinx.coroutines.sync.withLock
import kotlin.contracts.InvocationKind import kotlin.contracts.InvocationKind
import kotlin.contracts.contract import kotlin.contracts.contract
import kotlin.coroutines.CoroutineContext
/** /**
* Suspending variant of [binding][com.github.michaelbull.result.binding]. * Suspending variant of [binding][com.github.michaelbull.result.binding].
* Wraps the suspendable block in a new coroutine scope. * The suspendable [block] runs in a new [CoroutineScope], inheriting the parent [CoroutineContext].
* This scope is cancelled once a failing bind is encountered, allowing deferred child jobs to be eagerly cancelled. * This new scope is [cancelled][CoroutineScope.cancel] once a failing bind is encountered, eagerly cancelling all
* child [jobs][Job].
*/ */
public suspend inline fun <V, E> binding(crossinline block: suspend SuspendableResultBinding<E>.() -> V): Result<V, E> { public suspend inline fun <V, E> binding(crossinline block: suspend SuspendableResultBinding<E>.() -> V): Result<V, E> {
contract { contract {
callsInPlace(block, InvocationKind.EXACTLY_ONCE) callsInPlace(block, InvocationKind.EXACTLY_ONCE)
} }
val receiver = SuspendableResultBindingImpl<E>()
lateinit var receiver: SuspendableResultBindingImpl<E>
return try { return try {
coroutineScope { coroutineScope {
receiver.coroutineScope = this@coroutineScope receiver = SuspendableResultBindingImpl(this.coroutineContext)
with(receiver) { Ok(block()) } with(receiver) { Ok(block()) }
} }
} catch (ex: BindCancellationException) { } catch (ex: BindCancellationException) {
@ -35,27 +39,27 @@ public suspend inline fun <V, E> binding(crossinline block: suspend SuspendableR
internal object BindCancellationException : CancellationException(null) internal object BindCancellationException : CancellationException(null)
public interface SuspendableResultBinding<E> { public interface SuspendableResultBinding<E> : CoroutineScope {
public suspend fun <V> Result<V, E>.bind(): V public suspend fun <V> Result<V, E>.bind(): V
} }
@PublishedApi @PublishedApi
internal class SuspendableResultBindingImpl<E> : SuspendableResultBinding<E> { internal class SuspendableResultBindingImpl<E>(
override val coroutineContext: CoroutineContext
) : SuspendableResultBinding<E> {
private val mutex = Mutex() private val mutex = Mutex()
lateinit var internalError: Err<E> lateinit var internalError: Err<E>
var coroutineScope: CoroutineScope? = null
override suspend fun <V> Result<V, E>.bind(): V { override suspend fun <V> Result<V, E>.bind(): V {
return when (this) { return when (this) {
is Ok -> value is Ok -> value
is Err -> { is Err -> mutex.withLock {
mutex.withLock { if (::internalError.isInitialized.not()) {
if (::internalError.isInitialized.not()) { internalError = this
internalError = this this@SuspendableResultBindingImpl.cancel(BindCancellationException)
}
} }
coroutineScope?.cancel(BindCancellationException)
throw BindCancellationException throw BindCancellationException
} }
} }

View File

@ -3,9 +3,14 @@ package com.github.michaelbull.result.coroutines.binding
import com.github.michaelbull.result.Err import com.github.michaelbull.result.Err
import com.github.michaelbull.result.Ok import com.github.michaelbull.result.Ok
import com.github.michaelbull.result.Result import com.github.michaelbull.result.Result
import kotlinx.coroutines.CoroutineName
import kotlinx.coroutines.asCoroutineDispatcher
import kotlinx.coroutines.async import kotlinx.coroutines.async
import kotlinx.coroutines.delay import kotlinx.coroutines.delay
import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking import kotlinx.coroutines.runBlocking
import java.util.concurrent.Executors
import kotlin.coroutines.CoroutineContext
import kotlin.test.Test import kotlin.test.Test
import kotlin.test.assertEquals import kotlin.test.assertEquals
import kotlin.test.assertFalse import kotlin.test.assertFalse
@ -53,12 +58,12 @@ class AsyncSuspendableBindingTest {
} }
suspend fun provideY(): Result<Int, BindingError.BindingErrorA> { suspend fun provideY(): Result<Int, BindingError.BindingErrorA> {
delay(2) delay(3)
return Err(BindingError.BindingErrorA) return Err(BindingError.BindingErrorA)
} }
suspend fun provideZ(): Result<Int, BindingError.BindingErrorB> { suspend fun provideZ(): Result<Int, BindingError.BindingErrorB> {
delay(1) delay(2)
return Err(BindingError.BindingErrorB) return Err(BindingError.BindingErrorB)
} }
@ -82,41 +87,82 @@ class AsyncSuspendableBindingTest {
fun returnsStateChangedForOnlyTheFirstAsyncBindFailWhenEagerlyCancellingBinding() { fun returnsStateChangedForOnlyTheFirstAsyncBindFailWhenEagerlyCancellingBinding() {
var xStateChange = false var xStateChange = false
var yStateChange = false var yStateChange = false
var zStateChange = false
suspend fun provideX(): Result<Int, BindingError> {
delay(20)
xStateChange = true
return Ok(1)
}
suspend fun provideY(): Result<Int, BindingError.BindingErrorA> { suspend fun provideX(): Result<Int, BindingError> {
delay(10) delay(2)
yStateChange = true xStateChange = true
return Err(BindingError.BindingErrorA) return Err(BindingError.BindingErrorA)
} }
suspend fun provideZ(): Result<Int, BindingError.BindingErrorB> { suspend fun provideY(): Result<Int, BindingError.BindingErrorB> {
delay(1) // as this test uses a new thread for each coroutine, we want to set this delay to a high enough number that
zStateChange = true // there isn't any chance of a jvm run actually completing this suspending function in this thread first
// otherwise the assertions might fail.
delay(100)
yStateChange = true
return Err(BindingError.BindingErrorB) return Err(BindingError.BindingErrorB)
} }
runBlocking { runBlocking {
val result = binding<Int, BindingError> { val result = binding<Int, BindingError> {
val x = async { provideX().bind() } val x = async(newThread("ThreadA")) { provideX().bind() }
val y = async { provideY().bind() } val y = async(newThread("ThreadB")) { provideY().bind() }
val z = async { provideZ().bind() } x.await() + y.await()
x.await() + y.await() + z.await()
} }
assertTrue(result is Err) assertTrue(result is Err)
assertEquals( assertEquals(
expected = BindingError.BindingErrorB, expected = BindingError.BindingErrorA,
actual = result.error actual = result.error
) )
assertFalse(xStateChange) assertTrue(xStateChange)
assertFalse(yStateChange) assertFalse(yStateChange)
assertTrue(zStateChange)
} }
} }
@Test
fun returnsStateChangedForOnlyTheFirstLaunchBindFailWhenEagerlyCancellingBinding() {
var xStateChange = false
var yStateChange = false
var zStateChange = false
suspend fun provideX(): Result<Int, BindingError> {
delay(1)
xStateChange = true
return Ok(1)
}
suspend fun provideY(): Result<Int, BindingError.BindingErrorA> {
delay(20)
yStateChange = true
return Err(BindingError.BindingErrorA)
}
suspend fun provideZ(): Result<Int, BindingError.BindingErrorB> {
delay(100)
zStateChange = true
return Err(BindingError.BindingErrorB)
}
runBlocking {
val result = binding<Unit, BindingError> {
launch(newThread("Thread A")) { provideX().bind() }
launch(newThread("Thread B")) { provideY().bind() }
launch(newThread("Thread C")) { provideZ().bind() }
}
assertTrue(result is Err)
assertEquals(
expected = BindingError.BindingErrorA,
actual = result.error
)
assertTrue(xStateChange)
assertTrue(yStateChange)
assertFalse(zStateChange)
}
}
private fun newThread(name: String): CoroutineContext {
return Executors.newSingleThreadExecutor().asCoroutineDispatcher() + CoroutineName(name)
}
} }