Although this question is quite old, I would like to add another possible approach to Roman's answer, using `CopyableThreadContextElement`. Maybe it will be helpful for somebody else.
// Snippet taken from a comment in the kotlinx.coroutines source code (the KDoc of CopyableThreadContextElement)
/**
 * Context element that installs [traceData] into `traceThreadLocal` while the coroutine runs and
 * builds the element for a child coroutine from a snapshot of the thread-local's current value.
 *
 * NOTE(review): kotlinx.coroutines also calls `copyForChild` on an element passed explicitly to a
 * coroutine builder (e.g. `runBlocking(TraceContextElement(x))`), on the calling thread. Because the
 * snapshot is read from `traceThreadLocal` rather than from [traceData], the explicitly supplied value
 * may be replaced by whatever that thread holds — verify against your kotlinx.coroutines version.
 */
class TraceContextElement(private val traceData: TraceData?) : CopyableThreadContextElement<TraceData?> {
    override val key: CoroutineContext.Key<TraceContextElement> = Key

    /** Swaps [traceData] into the thread-local and hands back the value it displaced. */
    override fun updateThreadContext(context: CoroutineContext): TraceData? =
        traceThreadLocal.get().also { traceThreadLocal.set(traceData) }

    /** Reinstates the value that [updateThreadContext] displaced. */
    override fun restoreThreadContext(context: CoroutineContext, oldState: TraceData?) {
        traceThreadLocal.set(oldState)
    }

    // The thread-local is the source of truth: snapshotting it when the child is launched makes
    // writes done by the parent after its latest resumption visible to the child.
    override fun copyForChild(): TraceContextElement = snapshotCurrentThread()

    // Invoked when the child's builder was also given an element with this key explicitly.
    // No special merge is needed here: overwritingElement is ignored and a snapshot is used instead.
    override fun mergeForChild(overwritingElement: CoroutineContext.Element): CoroutineContext =
        snapshotCurrentThread()

    /** New element carrying a copy of the value the thread-local currently holds. */
    private fun snapshotCurrentThread(): TraceContextElement =
        TraceContextElement(traceThreadLocal.get()?.copy())

    companion object Key : CoroutineContext.Key<TraceContextElement>
}
Note that the `copyForChild` method lets you propagate thread-local data taken from the parent coroutine's last resumption phase into the context of the child coroutine (as the "Copyable" in `CopyableThreadContextElement` implies). This works because `copyForChild` is invoked on the parent coroutine's thread, during the resumption phase in which the child coroutine is created.
Simply adding the `TraceContextElement` to the root coroutine's context is enough for it to be propagated to all child coroutines as a context element:
runBlocking(Dispatchers.IO + TraceContextElement(someTraceDataInstance)){...}
With the `ContinuationInterceptor` approach, on the other hand, additional wrapping may be necessary in child coroutine builders if you redefine the dispatcher for child coroutines:
fun main() {
    // Everything directly inside this block runs under the WrappedDispatcher, so every resumption
    // of this coroutine is bracketed by the wrapper's START/END markers.
    runBlocking(WrappedDispatcher(Dispatchers.IO)) {
        for (message in listOf("It is wrapped!", "It is also wrapped!")) {
            delay(100)
            println(message)
        }
        // Deliberately NOT re-wrapped: passing Dispatchers.Default replaces the
        // ContinuationInterceptor in the context, so the WrappedDispatcher is not in effect here.
        withContext(Dispatchers.Default) {
            repeat(2) {
                delay(100)
                println("It is nested coroutine, and it isn't wrapped!")
            }
        }
        delay(100)
        println("It is also wrapped!")
    }
}
with a wrapper implementing the `ContinuationInterceptor` interface:
/**
 * [ContinuationInterceptor] that delegates to [dispatcher] but wraps every intercepted
 * continuation, so that each resumption is bracketed by "WRAPPED START" / "WRAPPED END" lines.
 *
 * A child coroutine keeps this interceptor only as long as it does not install a different
 * [ContinuationInterceptor] (e.g. `withContext(Dispatchers.Default)`), which replaces it.
 *
 * @param dispatcher the interceptor that actually schedules the (wrapped) continuations.
 */
class WrappedDispatcher(
    private val dispatcher: ContinuationInterceptor
) : AbstractCoroutineContextElement(ContinuationInterceptor), ContinuationInterceptor {
    override fun <T> interceptContinuation(continuation: Continuation<T>): Continuation<T> =
        dispatcher.interceptContinuation(ContinuationWrapper(continuation))

    /**
     * The continuation being released was produced by `dispatcher.interceptContinuation`, so the
     * delegate is the one that must clean it up. The default implementation inherited from
     * [ContinuationInterceptor] does nothing, which would silently skip the delegate's cleanup
     * (kotlinx.coroutines' CoroutineDispatcher, for instance, releases its DispatchedContinuation here).
     */
    override fun releaseInterceptedContinuation(continuation: Continuation<*>) {
        dispatcher.releaseInterceptedContinuation(continuation)
    }

    /** Prints markers around every resumption of [base]; everything else is delegated to it. */
    private class ContinuationWrapper<T>(val base: Continuation<T>) : Continuation<T> by base {
        override fun resumeWith(result: Result<T>) {
            println("------WRAPPED START-----")
            try {
                base.resumeWith(result)
            } finally {
                // Emit the closing marker even when the resumed code throws.
                println("------WRAPPED END-------")
            }
        }
    }
}
output:
------WRAPPED START-----
------WRAPPED END-------
------WRAPPED START-----
It is wrapped!
------WRAPPED END-------
------WRAPPED START-----
It is also wrapped!
------WRAPPED END-------
It is nested coroutine, and it isn't wrapped!
It is nested coroutine, and it isn't wrapped!
------WRAPPED START-----
------WRAPPED END-------
------WRAPPED START-----
It is also wrapped!
------WRAPPED END-------
As you can see, our wrapper wasn't applied to the child (nested) coroutine, because we replaced the `ContinuationInterceptor` by supplying another dispatcher as a parameter. This can lead to problems, since it is easy to forget to wrap a child coroutine's dispatcher.
As a side note, if you decide to use the `ContinuationInterceptor` approach, consider adding an extension like this one
/** Wraps this interceptor in the project's [WrappedDispatcher]; the single place to add further wrappers. */
fun ContinuationInterceptor.withMyProjectWrappers(): WrappedDispatcher {
    return WrappedDispatcher(this)
}
which wraps your dispatcher with all the wrappers your project needs. It can easily be extended, for example by taking specific beans (wrappers) from an IoC container such as Spring.
Here is one more example of `CopyableThreadContextElement`, in which thread-local changes are preserved across all resumption phases.
`Executors.newFixedThreadPool(..).asCoroutineDispatcher()` is used to better illustrate that different threads can run the coroutine between resumption phases.
/** Per-thread counter; each thread sees 1 until something on that thread sets another value. */
val counterThreadLocal: ThreadLocal<Int> = ThreadLocal.withInitial<Int> { 1 }
/** Prints a separator line, then the current thread's name and its [counterThreadLocal] value. */
fun showCounter() {
    val threadName = Thread.currentThread().name
    val counter = counterThreadLocal.get()
    println("-------------------------------------------------")
    println("Thread: $threadName\n Counter value: $counter")
}
/**
 * Demonstrates [CounterPropagator]: the counter survives resumptions on different pool threads and
 * is copied into a child coroutine that runs on a separate pool.
 *
 * Both executor-backed dispatchers are closed via `use`. The executors' worker threads are
 * non-daemon, so leaving the pools open would keep the JVM running after `main` returns.
 */
fun main() {
    Executors.newFixedThreadPool(10).asCoroutineDispatcher().use { parentDispatcher ->
        Executors.newFixedThreadPool(10).asCoroutineDispatcher().use { childDispatcher ->
            runBlocking(parentDispatcher + CounterPropagator(1)) {
                showCounter()
                delay(100)
                showCounter()
                counterThreadLocal.set(2)
                delay(100)
                showCounter()
                counterThreadLocal.set(3)
                // The child is created while this coroutine's counter is 3, so it starts with 3.
                val nested = async(childDispatcher) {
                    println("-----------NESTED START---------")
                    showCounter()
                    delay(100)
                    counterThreadLocal.set(4)
                    showCounter()
                    println("------------NESTED END-----------")
                }
                nested.await()
                // The child's change to 4 does not flow back: the parent still sees 3.
                showCounter()
                println("---------------END------------")
            }
        }
    }
}
/**
 * Copyable context element that keeps [counterThreadLocal] in sync with a coroutine:
 * - the coroutine's counter is installed on whichever thread resumes it;
 * - the value is saved again when the coroutine suspends;
 * - child coroutines get a snapshot of the counter at the moment they are created.
 *
 * NOTE(review): kotlinx.coroutines also calls [copyForChild] on an element passed explicitly to a
 * builder (e.g. `runBlocking(... + CounterPropagator(5))`), on the calling thread. The explicit value
 * is then replaced by that thread's counter — verify against your kotlinx.coroutines version.
 *
 * @param counterFromParenCoroutine counter installed on the first resumption. Afterwards it holds
 *   the value the coroutine left in [counterThreadLocal] when it last suspended.
 */
class CounterPropagator(private var counterFromParenCoroutine: Int) : CopyableThreadContextElement<Int> {
    companion object Key : CoroutineContext.Key<CounterPropagator>

    override val key: CoroutineContext.Key<CounterPropagator> = Key

    /**
     * Installs this coroutine's counter on the resuming thread.
     *
     * @return the counter the thread held before, so that [restoreThreadContext] can put it back.
     */
    override fun updateThreadContext(context: CoroutineContext): Int {
        val threadValue = counterThreadLocal.get()
        counterThreadLocal.set(counterFromParenCoroutine)
        return threadValue
    }

    /**
     * Remembers the (possibly modified) counter for the coroutine's next resumption, then restores
     * the thread's previous value.
     *
     * Without the restore, the coroutine's counter would stay on the thread and leak into whatever
     * runs there next (e.g. another task of the same pool).
     */
    override fun restoreThreadContext(context: CoroutineContext, oldState: Int) {
        counterFromParenCoroutine = counterThreadLocal.get()
        counterThreadLocal.set(oldState)
    }

    /** Propagates the parent's current thread-local counter to a newly created child. */
    override fun copyForChild(): CounterPropagator {
        return CounterPropagator(counterThreadLocal.get())
    }

    /**
     * Called when the child's builder also received a CounterPropagator explicitly.
     * [overwritingElement] is ignored; the child gets a snapshot of the current thread's counter.
     */
    override fun mergeForChild(overwritingElement: CoroutineContext.Element): CoroutineContext {
        return CounterPropagator(counterThreadLocal.get())
    }
}
output:
-------------------------------------------------
Thread: pool-1-thread-1
Counter value: 1
-------------------------------------------------
Thread: pool-1-thread-2
Counter value: 1
-------------------------------------------------
Thread: pool-1-thread-3
Counter value: 2
-----------NESTED START---------
-------------------------------------------------
Thread: pool-2-thread-1
Counter value: 3
-------------------------------------------------
Thread: pool-2-thread-2
Counter value: 4
------------NESTED END-----------
-------------------------------------------------
Thread: pool-1-thread-4
Counter value: 3
---------------END------------
You can achieve similar behavior with a `ContinuationInterceptor` (but don't forget to re-wrap the dispatchers of child (nested) coroutines in their coroutine builders, as mentioned above):
/** Per-thread counter with no initial value: `get()` returns null until the current thread sets it. */
val counterThreadLocal = ThreadLocal<Int>()
/**
 * [ContinuationInterceptor] that delegates scheduling to [dispatcher] and carries the value of
 * [counterThreadLocal] across resumptions.
 *
 * Around each resumption it:
 * - installs [savedCounter] on the resuming thread;
 * - saves the value the coroutine left behind;
 * - restores the thread's previous value, so the counter does not leak onto the thread.
 *
 * NOTE(review): [savedCounter] lives on the interceptor instance. Every coroutine sharing this
 * instance (including children that inherit it) therefore reads and writes the same counter,
 * without synchronization.
 *
 * @param dispatcher the interceptor that actually schedules the (wrapped) continuations.
 * @param savedCounter initial counter; defaults to the creating thread's value, or 0 if it has none.
 */
class WrappedDispatcher(
    private val dispatcher: ContinuationInterceptor,
    private var savedCounter: Int = counterThreadLocal.get() ?: 0
) : AbstractCoroutineContextElement(ContinuationInterceptor), ContinuationInterceptor {
    override fun <T> interceptContinuation(continuation: Continuation<T>): Continuation<T> =
        dispatcher.interceptContinuation(ContinuationWrapper(continuation))

    /**
     * The released continuation was produced by `dispatcher.interceptContinuation`, so the delegate
     * must be given the chance to clean it up. The default inherited implementation does nothing.
     */
    override fun releaseInterceptedContinuation(continuation: Continuation<*>) {
        dispatcher.releaseInterceptedContinuation(continuation)
    }

    private inner class ContinuationWrapper<T>(val base: Continuation<T>) : Continuation<T> by base {
        override fun resumeWith(result: Result<T>) {
            // Whatever the resuming thread held before; put back afterwards so it is not clobbered.
            val previous: Int? = counterThreadLocal.get()
            counterThreadLocal.set(savedCounter)
            try {
                base.resumeWith(result)
            } finally {
                // `?: 0` (same default as the constructor) avoids an NPE here, which would mask an
                // exception from resumeWith, if the coroutine removed the thread-local value.
                savedCounter = counterThreadLocal.get() ?: 0
                if (previous == null) counterThreadLocal.remove() else counterThreadLocal.set(previous)
            }
        }
    }
}