I'd like to create a Scala macro & that, applied to a field, returns a getter/setter pair and, applied to a method, returns a partially applied function. Something like the following:

trait ValRef[T] {
  def get(): T
}

trait VarRef[T] extends ValRef[T] {
  def set(x: T): Unit
}

// Example (desired usage):

case class Foo() {
  val v = "whatever"
  var u = 100
  def m(x: Int): Unit = println(x)
}

val x = new Foo
val cp: ValRef[String] = &x.v
val p: VarRef[Int] = &x.u
p.set(300)
val f: Int => Unit = &x.m
f(p.get())

I have no experience with Scala macros, but I assume this should be fairly straightforward for those who do.

I recently started reading about Scala's macros and found your question to be an interesting exercise. I only implemented pointers to val and var values, but since you are asking for only a sketch of the code, I thought I'd share what I found so far.

Edit: Regarding pointers to methods: if I'm not misunderstanding something, this is already a feature of Scala (eta-expansion). Given class Foo { def m(i: Int): Unit = println(i) }, you get a function with val f: Int => Unit = new Foo().m _ (when the expected type is a function type, as here, the trailing _ can even be omitted).
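A small sketch of the eta-expansion mentioned above, showing the explicit and implicit forms as well as partial application with a placeholder. The object EtaDemo and the method add are hypothetical names used only for illustration.

```scala
class Foo {
  def m(i: Int): Unit = println(i)
  def add(a: Int, b: Int): Int = a + b
}

object EtaDemo {
  val foo = new Foo

  // Explicit eta-expansion with a trailing underscore
  val f: Int => Unit = foo.m _

  // With an expected function type, the method is eta-expanded automatically
  val g: Int => Unit = foo.m

  // Partial application: fix the first argument, leave the second open
  val inc: Int => Int = foo.add(1, _)
}
```

Calling EtaDemo.f(3) prints 3, and EtaDemo.inc(41) evaluates foo.add(1, 41).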

  • If you are only interested in the code, scroll to the bottom of the answer.

Note that macros are executed at compile time and must therefore be compiled before the code that uses them. If you are using IntelliJ (or Eclipse), I'd suggest putting all macro-related code in a separate project.
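With sbt, the equivalent is a separate subproject for the macros that the rest of the code depends on. The following build.sbt is a minimal sketch, assuming sbt 1.x and Scala 2.13 (Scala 3 cannot expand Scala 2 macros like this one); the subproject names macros and core are hypothetical.

```scala
// build.sbt (sketch): macro implementations live in their own subproject
ThisBuild / scalaVersion := "2.13.12"

lazy val macros = (project in file("macros"))
  .settings(
    // The macro implementation uses the compiler's reflection API
    libraryDependencies += "org.scala-lang" % "scala-reflect" % scalaVersion.value
  )

lazy val core = (project in file("core"))
  .dependsOn(macros) // macros is compiled first, so core can expand &
```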

As you mentioned, we have the two pointer traits:

trait ValRef[T] {
  def get(): T
}

trait VarRef[T] extends ValRef[T] {
  def set(x: T): Unit
}

Now we want to implement a method & that, for a given reference, i.e., name or qualifier.name, returns a ValRef. If the reference refers to a mutable value, then the result should be a VarRef.

def &[T](value: T): ValRef[T]

So far, this is just regular Scala code. The method & accepts any expression and returns a ValRef whose type parameter is the type of the argument.
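To see what the macro saves us, here is a macro-free sketch of the same idea: the getter is passed as a by-name parameter and the setter as an explicit closure. The helpers valRef and varRef are hypothetical; note that the reference x.u has to be written twice, which is exactly the boilerplate a macro could generate from a single reference.

```scala
trait ValRef[T] {
  def get(): T
}

trait VarRef[T] extends ValRef[T] {
  def set(x: T): Unit
}

// Hypothetical hand-written helpers (no macros involved)
object ManualRefs {
  // read is by-name, so every get() re-evaluates the reference
  def valRef[T](read: => T): ValRef[T] = new ValRef[T] {
    def get(): T = read
  }

  def varRef[T](read: => T)(write: T => Unit): VarRef[T] = new VarRef[T] {
    def get(): T = read
    def set(x: T): Unit = write(x)
  }
}

case class Foo() {
  val v = "whatever"
  var u = 100
}

object ManualRefsDemo {
  import ManualRefs._

  // Returns the value seen through the val pointer and the field after set
  def run(): (String, Int) = {
    val x = new Foo
    val cp: ValRef[String] = valRef(x.v)
    val p: VarRef[Int] = varRef(x.u)(newValue => x.u = newValue)
    p.set(300)
    (cp.get(), x.u)
  }
}
```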

Let us define a macro that implements the logic of &:

def pointer[T: c.WeakTypeTag](c: scala.reflect.macros.blackbox.Context)(value: c.Expr[T]) = ???

The signature should be mostly standard:

  • c, the Context, gives the macro access to what the compiler knows at the call site (the argument trees, their types and symbols, and the reflection universe c.universe).
  • T is the type of the expression that is passed to &.
  • value corresponds to the argument of &, but since the implementation operates on Scala's AST, it is of type Expr[T] and not of the original type T.

Slightly unusual is the WeakTypeTag context bound, which I don't fully understand myself. The documentation states:

Type parameters in an implementation may come with WeakTypeTag context bounds. In that case the corresponding type tags describing the actual type arguments instantiated at the application site will be passed along when the macro is expanded.
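One way to get a feel for the "weak" part is runtime reflection, where the same tag types exist in scala.reflect.runtime.universe. The sketch below assumes Scala 2 with scala-reflect on the classpath, and TagDemo is a hypothetical name. It shows that a WeakTypeTag can be obtained even for a type parameter that is still abstract, in which case it simply describes that parameter; a TypeTag could not be materialized there without a context bound. Presumably this is why macro implementations, whose T may still mention abstract types at the call site, use WeakTypeTag.

```scala
// Sketch for Scala 2; assumes scala-reflect is on the classpath
import scala.reflect.runtime.universe._

object TagDemo {
  // Renders the type described by the (weak) tag
  def describe[T: WeakTypeTag]: String = weakTypeOf[T].toString

  // U is abstract here: no TypeTag[U] is available, but a WeakTypeTag[U] is,
  // and it refers to the type parameter U itself rather than to Int
  def insideGeneric[U]: String = describe[U]
}
```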

The interesting part is the implementation of pointer. Since the compiler replaces every call to & with the result of pointer, pointer must return an AST. Hence, we want to generate a tree. The question is, what should that tree look like?

Fortunately, since Scala 2.11 there are quasiquotes. A quasiquote such as q"..." lets us write the tree we want as ordinary Scala code, while $name splices in existing trees, types or names.
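Quasiquotes are not limited to macros; they also work with the runtime reflection universe, which makes them easy to try out. The following sketch assumes Scala 2 with scala-reflect on the classpath, and QuasiquoteDemo is a hypothetical name. It builds a tree for x.u (x does not have to exist, since this is only syntax), splices it into an assignment, and takes it apart again with a quasiquote pattern.

```scala
// Sketch for Scala 2; assumes scala-reflect is on the classpath
import scala.reflect.runtime.universe._

object QuasiquoteDemo {
  // A tree for the reference x.u (untyped: no symbols are resolved)
  val ref: Tree = q"x.u"

  // Splicing: the existing tree ref becomes the left-hand side
  val assign: Tree = q"$ref = 300"

  // Quasiquotes also work as patterns: qual is the tree x, name is u
  val q"$qual.$name" = ref
}
```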

Let's simplify the problem first: instead of distinguishing between val and var references, we always return a VarRef. For a VarRef generated from x.y:

  • get() should return x.y
  • set(newValue) should execute x.y = newValue

So we want to generate a tree that represents the instantiation of an anonymous subclass of VarRef[T]. Because we cannot refer to the implementation's type parameter T inside the quasiquote, we first need the actual type of the argument, which we can get with val tpe = value.tree.tpe. Such a Type can be spliced into a quasiquote directly.

Now, our Quasiquote looks as follows:

q"""
  new VarRef[$tpe] {
    def get(): $tpe = $value

    def set(newValue: $tpe): Unit = {
      $value = newValue
    }
  }
"""
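For a concrete picture, the hand-written code below is roughly what this quasiquote should expand to for &(x.u), with the setter parameter called newValue. It is a sketch of the generated shape, not actual compiler output; ExpansionDemo is a hypothetical name.

```scala
trait ValRef[T] {
  def get(): T
}

trait VarRef[T] extends ValRef[T] {
  def set(x: T): Unit
}

case class Foo() {
  val v = "whatever"
  var u = 100
}

object ExpansionDemo {
  // Returns (value through the pointer, value of the field) after set(300)
  def run(): (Int, Int) = {
    val x = new Foo

    // Roughly the expansion of &(x.u): both methods re-evaluate x.u
    val p: VarRef[Int] = new VarRef[Int] {
      def get(): Int = x.u

      def set(newValue: Int): Unit = {
        x.u = newValue
      }
    }

    p.set(300)
    (p.get(), x.u)
  }
}
```

The same shape for x.v would not compile, because x.v = newValue assigns to a val.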

This implementation should work as long as we only create pointers to var references. However, as soon as we create a pointer to a val reference, compilation fails with "reassignment to val". Hence, our macro needs to distinguish between the two.

Symbols provide this kind of information. Since pointers should only be created for references (name or qualifier.name), the argument's tree should carry a TermSymbol:

val symbol: TermSymbol = value.tree.symbol.asTerm

The TermSymbol API provides the methods isVal and isVar, but they only seem to work for local variables. I'm not sure what the "right way" to find out whether a reference is a var or a val is, but the following condition seems to work:

symbol.isVar || symbol.setter != NoSymbol

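The trick is that, as far as I can tell, the symbol of a member access such as x.u is the getter method that the compiler generates for the field rather than the field itself, which is why isVar is false for it. That getter has a corresponding setter symbol (u_=) if and only if the member is a var; otherwise, setter returns NoSymbol.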
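The same observation can be checked with runtime reflection outside a macro. This sketch assumes Scala 2 with scala-reflect on the classpath; AccessorDemo and isMutable are hypothetical names, and isMutable simply reuses the condition from above. Looking up a member by its plain name yields the getter method (the underlying field is stored under a name with a trailing space, such as "u "), and only the getter of a var has a setter.

```scala
// Sketch for Scala 2; assumes scala-reflect is on the classpath
import scala.reflect.runtime.universe._

case class Foo() {
  val v = "whatever"
  var u = 100
}

object AccessorDemo {
  // The plain names resolve to the generated getter methods
  val uGetter: TermSymbol = typeOf[Foo].decl(TermName("u")).asTerm
  val vGetter: TermSymbol = typeOf[Foo].decl(TermName("v")).asTerm

  // Same test as in the macro: local var, or member with a setter
  def isMutable(sym: TermSymbol): Boolean =
    sym.isVar || sym.setter != NoSymbol
}
```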


So the macro code looks as follows:

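trait ValRef[T] {
  def get(): T
}

trait VarRef[T] extends ValRef[T] {
  def set(x: T): Unit
}

object PointerMacro {

  import scala.language.experimental.macros

  def pointer[T: c.WeakTypeTag](c: scala.reflect.macros.blackbox.Context)(value: c.Expr[T]): c.Tree = {
    import c.universe._

    val tree = value.tree
    // Trees that are not references (e.g. literals) carry no term symbol
    if (tree.symbol == null || !tree.symbol.isTerm)
      c.abort(tree.pos, "& expects a reference such as name or qualifier.name")

    val symbol: TermSymbol = tree.symbol.asTerm
    val tpe = tree.tpe
    // A fresh parameter name cannot clash with names used inside value
    val newValue = TermName(c.freshName("newValue"))

    // Local vars report isVar; member vars are accessed through a getter that has a setter
    if (symbol.isVar || symbol.setter != NoSymbol) {
      q"""
        new VarRef[$tpe] {
          def get(): $tpe = $value

          def set($newValue: $tpe): Unit = {
            $value = $newValue
          }
        }
      """
    } else {
      q"""
        new ValRef[$tpe] {
          def get(): $tpe = $value
        }
      """
    }
  }

  def &[T](value: T): ValRef[T] = macro pointer[T]
}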

If you compile this code and add it to your project's classpath, then you should be able to create pointers like this:

case class Foo() {
  val v = "whatever"
  var u = 100
}

object Example{
  import PointerMacro.&

  def main(args: Array[String]): Unit = {
    val x = new Foo
    val mainInt = 90
    var mainString = "this is main"

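    // & is a blackbox macro, so its static result type is always ValRef[T];
    // a cast is needed to obtain a VarRef whose set method can be called
    val localValPointer: ValRef[Int] = &(mainInt)
    val localVarPointer: VarRef[String] = &(mainString).asInstanceOf[VarRef[String]]
    val memberValPointer: ValRef[String] = &(x.v)
    val memberVarPointer: VarRef[Int] = &(x.u).asInstanceOf[VarRef[Int]]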

    println(localValPointer.get())
    println(localVarPointer.get())
    println(memberValPointer.get())
    println(memberVarPointer.get())

    localVarPointer.set("Hello World")
    println(localVarPointer.get())

    memberVarPointer.set(62)
    println(memberVarPointer.get())

  }
}

which, when run, should print

90
this is main
whatever
100
Hello World
62
Answered by Kulu Limpa