From f2bd9aaa11915f2b20a41adc2543b98722b80027 Mon Sep 17 00:00:00 2001 From: Tristan Hamilton Date: Wed, 5 May 2021 14:30:41 +0100 Subject: [PATCH] Correctly cancel child jobs in suspending variant of binding Closes #46 --- .../coroutines/binding/SuspendableBinding.kt | 30 ++++--- .../binding/AsyncSuspendableBindingTest.kt | 88 ++++++++++++++----- 2 files changed, 84 insertions(+), 34 deletions(-) diff --git a/kotlin-result-coroutines/src/commonMain/kotlin/com/github/michaelbull/result/coroutines/binding/SuspendableBinding.kt b/kotlin-result-coroutines/src/commonMain/kotlin/com/github/michaelbull/result/coroutines/binding/SuspendableBinding.kt index c0466af..38cb63e 100644 --- a/kotlin-result-coroutines/src/commonMain/kotlin/com/github/michaelbull/result/coroutines/binding/SuspendableBinding.kt +++ b/kotlin-result-coroutines/src/commonMain/kotlin/com/github/michaelbull/result/coroutines/binding/SuspendableBinding.kt @@ -5,27 +5,31 @@ import com.github.michaelbull.result.Ok import com.github.michaelbull.result.Result import kotlinx.coroutines.CancellationException import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.Job import kotlinx.coroutines.cancel import kotlinx.coroutines.coroutineScope import kotlinx.coroutines.sync.Mutex import kotlinx.coroutines.sync.withLock import kotlin.contracts.InvocationKind import kotlin.contracts.contract +import kotlin.coroutines.CoroutineContext /** * Suspending variant of [binding][com.github.michaelbull.result.binding]. - * Wraps the suspendable block in a new coroutine scope. - * This scope is cancelled once a failing bind is encountered, allowing deferred child jobs to be eagerly cancelled. + * The suspendable [block] runs in a new [CoroutineScope], inheriting the parent [CoroutineContext]. + * This new scope is [cancelled][CoroutineScope.cancel] once a failing bind is encountered, eagerly cancelling all + * child [jobs][Job]. */ public suspend inline fun binding(crossinline block: suspend SuspendableResultBinding.() -> V): Result { contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } - val receiver = SuspendableResultBindingImpl() + + lateinit var receiver: SuspendableResultBindingImpl return try { coroutineScope { - receiver.coroutineScope = this@coroutineScope + receiver = SuspendableResultBindingImpl(this.coroutineContext) with(receiver) { Ok(block()) } } } catch (ex: BindCancellationException) { @@ -35,27 +39,27 @@ public suspend inline fun binding(crossinline block: suspend SuspendableR internal object BindCancellationException : CancellationException(null) -public interface SuspendableResultBinding { +public interface SuspendableResultBinding : CoroutineScope { public suspend fun Result.bind(): V } @PublishedApi -internal class SuspendableResultBindingImpl : SuspendableResultBinding { +internal class SuspendableResultBindingImpl( + override val coroutineContext: CoroutineContext +) : SuspendableResultBinding { private val mutex = Mutex() lateinit var internalError: Err - var coroutineScope: CoroutineScope? = null override suspend fun Result.bind(): V { return when (this) { is Ok -> value - is Err -> { - mutex.withLock { - if (::internalError.isInitialized.not()) { - internalError = this - } + is Err -> mutex.withLock { + if (::internalError.isInitialized.not()) { + internalError = this + this@SuspendableResultBindingImpl.cancel(BindCancellationException) } - coroutineScope?.cancel(BindCancellationException) + throw BindCancellationException } } diff --git a/kotlin-result-coroutines/src/jvmTest/kotlin/com/github/michaelbull/result/coroutines/binding/AsyncSuspendableBindingTest.kt b/kotlin-result-coroutines/src/jvmTest/kotlin/com/github/michaelbull/result/coroutines/binding/AsyncSuspendableBindingTest.kt index 7d3cb11..3a30494 100644 --- a/kotlin-result-coroutines/src/jvmTest/kotlin/com/github/michaelbull/result/coroutines/binding/AsyncSuspendableBindingTest.kt +++ b/kotlin-result-coroutines/src/jvmTest/kotlin/com/github/michaelbull/result/coroutines/binding/AsyncSuspendableBindingTest.kt @@ -3,9 +3,14 @@ package com.github.michaelbull.result.coroutines.binding import com.github.michaelbull.result.Err import com.github.michaelbull.result.Ok import com.github.michaelbull.result.Result +import kotlinx.coroutines.CoroutineName +import kotlinx.coroutines.asCoroutineDispatcher import kotlinx.coroutines.async import kotlinx.coroutines.delay +import kotlinx.coroutines.launch import kotlinx.coroutines.runBlocking +import java.util.concurrent.Executors +import kotlin.coroutines.CoroutineContext import kotlin.test.Test import kotlin.test.assertEquals import kotlin.test.assertFalse @@ -53,12 +58,12 @@ class AsyncSuspendableBindingTest { } suspend fun provideY(): Result { - delay(2) + delay(3) return Err(BindingError.BindingErrorA) } suspend fun provideZ(): Result { - delay(1) + delay(2) return Err(BindingError.BindingErrorB) } @@ -82,41 +87,82 @@ class AsyncSuspendableBindingTest { fun returnsStateChangedForOnlyTheFirstAsyncBindFailWhenEagerlyCancellingBinding() { var xStateChange = false var yStateChange = false - var zStateChange = false - suspend fun provideX(): Result { - delay(20) - xStateChange = true - return Ok(1) - } - suspend fun provideY(): Result { - delay(10) - yStateChange = true + suspend fun provideX(): Result { + delay(2) + xStateChange = true return Err(BindingError.BindingErrorA) } - suspend fun provideZ(): Result { - delay(1) - zStateChange = true + suspend fun provideY(): Result { + // as this test uses a new thread for each coroutine, we want to set this delay to a high enough number that + // there isn't any chance of a jvm run actually completing this suspending function in this thread first + // otherwise the assertions might fail. + delay(100) + yStateChange = true return Err(BindingError.BindingErrorB) } runBlocking { val result = binding { - val x = async { provideX().bind() } - val y = async { provideY().bind() } - val z = async { provideZ().bind() } - x.await() + y.await() + z.await() + val x = async(newThread("ThreadA")) { provideX().bind() } + val y = async(newThread("ThreadB")) { provideY().bind() } + x.await() + y.await() } assertTrue(result is Err) assertEquals( - expected = BindingError.BindingErrorB, + expected = BindingError.BindingErrorA, actual = result.error ) - assertFalse(xStateChange) + assertTrue(xStateChange) assertFalse(yStateChange) - assertTrue(zStateChange) } } + + @Test + fun returnsStateChangedForOnlyTheFirstLaunchBindFailWhenEagerlyCancellingBinding() { + var xStateChange = false + var yStateChange = false + var zStateChange = false + + suspend fun provideX(): Result { + delay(1) + xStateChange = true + return Ok(1) + } + + suspend fun provideY(): Result { + delay(20) + yStateChange = true + return Err(BindingError.BindingErrorA) + } + + suspend fun provideZ(): Result { + delay(100) + zStateChange = true + return Err(BindingError.BindingErrorB) + } + + runBlocking { + val result = binding { + launch(newThread("Thread A")) { provideX().bind() } + launch(newThread("Thread B")) { provideY().bind() } + launch(newThread("Thread C")) { provideZ().bind() } + } + + assertTrue(result is Err) + assertEquals( + expected = BindingError.BindingErrorA, + actual = result.error + ) + assertTrue(xStateChange) + assertTrue(yStateChange) + assertFalse(zStateChange) + } + } + + private fun newThread(name: String): CoroutineContext { + return Executors.newSingleThreadExecutor().asCoroutineDispatcher() + CoroutineName(name) + } }