Correctly cancel child jobs in suspending variant of binding

Closes #46
This commit is contained in:
Tristan Hamilton 2021-05-05 14:30:41 +01:00 committed by Michael Bull
parent a5c153477b
commit f2bd9aaa11
2 changed files with 84 additions and 34 deletions

View File

@ -5,27 +5,31 @@ import com.github.michaelbull.result.Ok
import com.github.michaelbull.result.Result
import kotlinx.coroutines.CancellationException
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Job
import kotlinx.coroutines.cancel
import kotlinx.coroutines.coroutineScope
import kotlinx.coroutines.sync.Mutex
import kotlinx.coroutines.sync.withLock
import kotlin.contracts.InvocationKind
import kotlin.contracts.contract
import kotlin.coroutines.CoroutineContext
/**
* Suspending variant of [binding][com.github.michaelbull.result.binding].
* Wraps the suspendable block in a new coroutine scope.
* This scope is cancelled once a failing bind is encountered, allowing deferred child jobs to be eagerly cancelled.
* The suspendable [block] runs in a new [CoroutineScope], inheriting the parent [CoroutineContext].
* This new scope is [cancelled][CoroutineScope.cancel] once a failing bind is encountered, eagerly cancelling all
* child [jobs][Job].
*/
public suspend inline fun <V, E> binding(crossinline block: suspend SuspendableResultBinding<E>.() -> V): Result<V, E> {
contract {
callsInPlace(block, InvocationKind.EXACTLY_ONCE)
}
val receiver = SuspendableResultBindingImpl<E>()
lateinit var receiver: SuspendableResultBindingImpl<E>
return try {
coroutineScope {
receiver.coroutineScope = this@coroutineScope
receiver = SuspendableResultBindingImpl(this.coroutineContext)
with(receiver) { Ok(block()) }
}
} catch (ex: BindCancellationException) {
@ -35,27 +39,27 @@ public suspend inline fun <V, E> binding(crossinline block: suspend SuspendableR
internal object BindCancellationException : CancellationException(null)
public interface SuspendableResultBinding<E> {
public interface SuspendableResultBinding<E> : CoroutineScope {
public suspend fun <V> Result<V, E>.bind(): V
}
@PublishedApi
internal class SuspendableResultBindingImpl<E> : SuspendableResultBinding<E> {
internal class SuspendableResultBindingImpl<E>(
override val coroutineContext: CoroutineContext
) : SuspendableResultBinding<E> {
private val mutex = Mutex()
lateinit var internalError: Err<E>
var coroutineScope: CoroutineScope? = null
override suspend fun <V> Result<V, E>.bind(): V {
return when (this) {
is Ok -> value
is Err -> {
mutex.withLock {
if (::internalError.isInitialized.not()) {
internalError = this
}
is Err -> mutex.withLock {
if (::internalError.isInitialized.not()) {
internalError = this
this@SuspendableResultBindingImpl.cancel(BindCancellationException)
}
coroutineScope?.cancel(BindCancellationException)
throw BindCancellationException
}
}

View File

@ -3,9 +3,14 @@ package com.github.michaelbull.result.coroutines.binding
import com.github.michaelbull.result.Err
import com.github.michaelbull.result.Ok
import com.github.michaelbull.result.Result
import kotlinx.coroutines.CoroutineName
import kotlinx.coroutines.asCoroutineDispatcher
import kotlinx.coroutines.async
import kotlinx.coroutines.delay
import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking
import java.util.concurrent.Executors
import kotlin.coroutines.CoroutineContext
import kotlin.test.Test
import kotlin.test.assertEquals
import kotlin.test.assertFalse
@ -53,12 +58,12 @@ class AsyncSuspendableBindingTest {
}
suspend fun provideY(): Result<Int, BindingError.BindingErrorA> {
delay(2)
delay(3)
return Err(BindingError.BindingErrorA)
}
suspend fun provideZ(): Result<Int, BindingError.BindingErrorB> {
delay(1)
delay(2)
return Err(BindingError.BindingErrorB)
}
@ -82,41 +87,82 @@ class AsyncSuspendableBindingTest {
fun returnsStateChangedForOnlyTheFirstAsyncBindFailWhenEagerlyCancellingBinding() {
var xStateChange = false
var yStateChange = false
var zStateChange = false
suspend fun provideX(): Result<Int, BindingError> {
delay(20)
xStateChange = true
return Ok(1)
}
suspend fun provideY(): Result<Int, BindingError.BindingErrorA> {
delay(10)
yStateChange = true
suspend fun provideX(): Result<Int, BindingError> {
delay(2)
xStateChange = true
return Err(BindingError.BindingErrorA)
}
suspend fun provideZ(): Result<Int, BindingError.BindingErrorB> {
delay(1)
zStateChange = true
suspend fun provideY(): Result<Int, BindingError.BindingErrorB> {
// as this test uses a new thread for each coroutine, we want to set this delay to a high enough number that
// there isn't any chance of a jvm run actually completing this suspending function in this thread first
// otherwise the assertions might fail.
delay(100)
yStateChange = true
return Err(BindingError.BindingErrorB)
}
runBlocking {
val result = binding<Int, BindingError> {
val x = async { provideX().bind() }
val y = async { provideY().bind() }
val z = async { provideZ().bind() }
x.await() + y.await() + z.await()
val x = async(newThread("ThreadA")) { provideX().bind() }
val y = async(newThread("ThreadB")) { provideY().bind() }
x.await() + y.await()
}
assertTrue(result is Err)
assertEquals(
expected = BindingError.BindingErrorB,
expected = BindingError.BindingErrorA,
actual = result.error
)
assertFalse(xStateChange)
assertTrue(xStateChange)
assertFalse(yStateChange)
assertTrue(zStateChange)
}
}
@Test
fun returnsStateChangedForOnlyTheFirstLaunchBindFailWhenEagerlyCancellingBinding() {
var xStateChange = false
var yStateChange = false
var zStateChange = false
suspend fun provideX(): Result<Int, BindingError> {
delay(1)
xStateChange = true
return Ok(1)
}
suspend fun provideY(): Result<Int, BindingError.BindingErrorA> {
delay(20)
yStateChange = true
return Err(BindingError.BindingErrorA)
}
suspend fun provideZ(): Result<Int, BindingError.BindingErrorB> {
delay(100)
zStateChange = true
return Err(BindingError.BindingErrorB)
}
runBlocking {
val result = binding<Unit, BindingError> {
launch(newThread("Thread A")) { provideX().bind() }
launch(newThread("Thread B")) { provideY().bind() }
launch(newThread("Thread C")) { provideZ().bind() }
}
assertTrue(result is Err)
assertEquals(
expected = BindingError.BindingErrorA,
actual = result.error
)
assertTrue(xStateChange)
assertTrue(yStateChange)
assertFalse(zStateChange)
}
}
private fun newThread(name: String): CoroutineContext {
return Executors.newSingleThreadExecutor().asCoroutineDispatcher() + CoroutineName(name)
}
}