
I want a binary operator `cross` (cross product / Cartesian product) that operates on traversables in Scala:

val x = Seq(1, 2)
val y = List("hello", "world", "bye")
val z = x cross y    // I want to chain as many traversables as I like, e.g. x cross y cross w

assert(z == Seq((1, "hello"), (1, "world"), (1, "bye"), (2, "hello"), (2, "world"), (2, "bye")))

What is the best way to do this in plain Scala (i.e., without a library such as Scalaz)?

Martin Thoma
pathikrit

8 Answers


You can do this pretty straightforwardly with an implicit class and a for-comprehension in Scala 2.10:

implicit class Crossable[X](xs: Traversable[X]) {
  def cross[Y](ys: Traversable[Y]) = for { x <- xs; y <- ys } yield (x, y)
}

val xs = Seq(1, 2)
val ys = List("hello", "world", "bye")

And now:

scala> xs cross ys
res0: Traversable[(Int, String)] = List((1,hello), (1,world), ...

This is also possible before 2.10, just not quite as concisely, since you'd need to define both the class and an implicit conversion method.
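For comparison, here is a sketch of that pre-2.10 shape: an ordinary class plus a separate `implicit def`. It is written against `Iterable` rather than `Traversable` (which is deprecated in 2.13). It also includes the `scala.language.implicitConversions` import, so it compiles on current compilers without a feature warning. That import did not exist before 2.10, so drop it on an older compiler. `Pre210Cross` and `toCrossable` are illustrative names.

```scala
import scala.language.implicitConversions

object Pre210Cross {
  // Ordinary (non-implicit) wrapper class: `implicit class` syntax only arrived in 2.10
  class Crossable[X](xs: Iterable[X]) {
    def cross[Y](ys: Iterable[Y]): Iterable[(X, Y)] =
      for { x <- xs; y <- ys } yield (x, y)
  }

  // Separate implicit conversion that wraps any Iterable in a Crossable,
  // so that `xs cross ys` resolves once Pre210Cross._ is imported
  implicit def toCrossable[X](xs: Iterable[X]): Crossable[X] = new Crossable(xs)
}
```

With `import Pre210Cross._` in scope, `Seq(1, 2) cross List("hello", "world", "bye")` works just like the implicit-class version.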

You can also write this:

scala> xs cross ys cross List('a, 'b)
res2: Traversable[((Int, String), Symbol)] = List(((1,hello),'a), ...

If you want xs cross ys cross zs to return a Tuple3, however, you'll need either a lot of boilerplate or a library like Shapeless.
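In plain Scala, one way to get a `Tuple3` for the three-collection case is to flatten the nested pairs by hand. The sketch below restates the implicit class over `Iterable` and adds a hypothetical `cross3` helper. Each further arity would need its own helper, which is one form that boilerplate can take.

```scala
object FlattenCross {
  implicit class Crossable[X](xs: Iterable[X]) {
    def cross[Y](ys: Iterable[Y]): Iterable[(X, Y)] =
      for { x <- xs; y <- ys } yield (x, y)
  }

  // Hypothetical helper for exactly three collections: cross them pairwise,
  // then pattern-match each nested ((a, b), c) into a flat (a, b, c)
  def cross3[A, B, C](xs: Iterable[A], ys: Iterable[B], zs: Iterable[C]): Iterable[(A, B, C)] =
    (xs cross ys cross zs).map { case ((a, b), c) => (a, b, c) }
}
```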

Travis Brown
  • You can extend `AnyVal` in your implicit class to optimize it into a [value class](http://docs.scala-lang.org/overviews/core/value-classes.html). – twillouer Feb 06 '13 at 22:49
  • Thanks, though I wish `x cross y cross z` returned a `Tuple3` and not a `Tuple2(Tuple2, Value)`. – pathikrit Feb 06 '13 at 22:51
  • @wrick: Off the top of my head I'm pretty sure that'd be possible, and I'd be happy to take a shot at it, but not without Shapeless. – Travis Brown Feb 06 '13 at 22:52
  • Sure, please post an example with Shapeless. – pathikrit Feb 07 '13 at 05:45
  • Overloading a method to have different return types is not a good idea. Better to use HList/HArray in that case. – Jesper Nordenberg Feb 07 '13 at 07:42
  • @JesperNordenberg: You should be able to get `x cross y cross z` safely with the same kind of approach you see in `ProductLens` and the `~` method. – Travis Brown Feb 07 '13 at 12:33
  • Yes, it's doable with implicit parameters and type projections, but I don't think it's a good idea because of the combinatorial explosion of tuple types. – Jesper Nordenberg Feb 08 '13 at 08:21
  • @JesperNordenberg: Not sure what you mean—that approach definitely wouldn't involve type projections, and I don't see how it could lead to a "combinatorial explosion of tuple types". When I get a minute I'll post an implementation and we can discuss. – Travis Brown Feb 08 '13 at 12:28
  • For the sake of completeness: since I had exactly the same wish as wrick and was struggling to find a solution, I thought this problem deserved [a separate question](http://stackoverflow.com/q/16219545/1804173). – bluenote10 Apr 26 '13 at 08:18
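Following twillouer's suggestion, here is a sketch of the same implicit class as a value class. Extending `AnyVal` lets the compiler avoid allocating a `Crossable` wrapper in most calls. A value class needs exactly one `val` parameter and must be top-level or a member of a statically accessible object, and an implicit class cannot be top-level, hence the enclosing object. `CrossSyntax` is an illustrative name, and `Iterable` stands in for `Traversable`.

```scala
object CrossSyntax {
  // Value class: the wrapped collection is the single `val` field, and calls to
  // `cross` can usually be compiled without creating a Crossable instance
  implicit class Crossable[X](val xs: Iterable[X]) extends AnyVal {
    def cross[Y](ys: Iterable[Y]): Iterable[(X, Y)] =
      for { x <- xs; y <- ys } yield (x, y)
  }
}
```

Usage is unchanged: after `import CrossSyntax._`, `xs cross ys` behaves as before.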

Cross `x_list` and `y_list` with:

val cross = x_list.flatMap(x => y_list.map(y => (x, y)))
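Here is a self-contained version of that one-liner, using sample lists borrowed from the question. It also checks that the `flatMap`/`map` chain matches the for-comprehension from the implicit-class answer, since the compiler desugars that comprehension into essentially this chain. `FlatMapCross` and `viaFor` are illustrative names.

```scala
object FlatMapCross {
  val x_list = List(1, 2)
  val y_list = List("hello", "world", "bye")

  // For each x, map over y_list to build (x, y) pairs, then flatten the per-x lists
  val cross: List[(Int, String)] = x_list.flatMap(x => y_list.map(y => (x, y)))

  // The equivalent for-comprehension, which desugars to the same flatMap/map calls
  val viaFor: List[(Int, String)] = for { x <- x_list; y <- y_list } yield (x, y)
}
```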
clemens
王昕元

Here is an implementation of a recursive cross product of an arbitrary number of lists:

def crossJoin[T](list: Traversable[Traversable[T]]): Traversable[Traversable[T]] =
  list match {
    case xs :: Nil => xs map (Traversable(_))
    case x :: xs => for {
      i <- x
      j <- crossJoin(xs)
    } yield Traversable(i) ++ j
  }

crossJoin(
  List(
    List(3, "b"),
    List(1, 8),
    List(0, "f", 4.3)
  )
)

res0: Traversable[Traversable[Any]] = List(List(3, 1, 0), List(3, 1, f), List(3, 1, 4.3), List(3, 8, 0), List(3, 8, f), List(3, 8, 4.3), List(b, 1, 0), List(b, 1, f), List(b, 1, 4.3), List(b, 8, 0), List(b, 8, f), List(b, 8, 4.3))
Milad Khajavi
  • Nice code, but it might be more efficient to do the recursive call outside the `for` statement. Also, adding a `case Nil => Nil` would catch the edge case of an empty list. – Tim Jan 23 '19 at 17:06
  • @Tim How do you make the recursive call outside the `for` statement? – Milad Khajavi Jan 23 '19 at 19:53
  • I posted a modification of your answer in my answer to [another question](https://stackoverflow.com/questions/54330356/scala-create-all-possible-permutations-of-a-sentence-based-synonyms-of-each-wor). – Tim Jan 24 '19 at 09:53
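Here is a sketch of `crossJoin` with the two suggestions from Tim's comment applied. It is not the code from his linked answer. The recursive result is computed once per level rather than once per element, and an empty input returns `Nil` instead of throwing a `MatchError`. It takes `List` rather than `Traversable` so the `::` patterns always apply; the original would also throw a `MatchError` if handed, say, a `Vector`.

```scala
object CrossJoin {
  def crossJoin[T](lists: List[List[T]]): List[List[T]] =
    lists match {
      case Nil       => Nil                // empty input, as suggested in the comment
      case xs :: Nil => xs.map(List(_))    // last list: each element starts a row
      case x :: rest =>
        val tails = crossJoin(rest)        // hoisted: computed once, not per element of x
        for {
          i <- x
          j <- tails
        } yield i :: j
    }
}
```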

An alternative for Cats users:

`sequence` on a `List[List[A]]` produces the cross product:

import cats.implicits._

val xs = List(1, 2)
val ys = List("hello", "world", "bye")

List(xs, ys).sequence 
//List(List(1, hello), List(1, world), List(1, bye), List(2, hello), List(2, world), List(2, bye))
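To see what `sequence` computes here without pulling in Cats, here is a plain-Scala sketch built on `foldRight`. `SequenceLists.sequence` is a hypothetical helper, not part of Cats or the standard library. Like the Cats call above, mixing `Int` and `String` lists gives a `List[List[Any]]`.

```scala
object SequenceLists {
  // Start from a single empty row; for each list (processed right to left),
  // prepend each of its elements to every row built from the lists after it
  def sequence[A](lists: List[List[A]]): List[List[A]] =
    lists.foldRight(List(List.empty[A])) { (xs, rows) =>
      for {
        x   <- xs
        row <- rows
      } yield x :: row
    }
}
```

With no input lists the result is `List(Nil)`, a single empty combination.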
Krzysztof Atłasik

Here is something similar to Milad's response, but non-recursive.

def cartesianProduct[T](seqs: Seq[Seq[T]]): Seq[Seq[T]] = {
  seqs.foldLeft(Seq(Seq.empty[T]))((b, a) => b.flatMap(i => a.map(j => i ++ Seq(j))))
}

Based on this blog post.
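The following restatement shows how the fold behaves at the edges; it swaps in `prefix :+ elem` for `i ++ Seq(j)` but is otherwise the same. Because the fold is seeded with a single empty sequence, zero inputs yield one empty combination. If any input is empty, the result is empty. `FoldCartesian` is an illustrative name.

```scala
object FoldCartesian {
  def cartesianProduct[T](seqs: Seq[Seq[T]]): Seq[Seq[T]] =
    seqs.foldLeft(Seq(Seq.empty[T])) { (acc, next) =>
      // Extend every partial combination with every element of the next sequence
      acc.flatMap(prefix => next.map(elem => prefix :+ elem))
    }
}
```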

turtlemonvh
class CartesianProduct(product: Traversable[Traversable[_ <: Any]]) {
  override def toString(): String = {
    product.toString
  }

  def *(rhs: Traversable[_ <: Any]): CartesianProduct = {
      val p = product.flatMap { lhs =>
        rhs.map { r =>
          lhs.toList :+ r
        }
      }

      new CartesianProduct(p)
  }
}

object CartesianProduct {
  def apply(traversable: Traversable[_ <: Any]): CartesianProduct = {
    new CartesianProduct(
      traversable.map { t =>
        Traversable(t)
      }
    )
  }
}

// TODO: How can this conversion be made implicit?
val x = CartesianProduct(Set(0, 1))
val y = List("Alice", "Bob")
val z = Array(Math.E, Math.PI)

println(x * y * z) // Set(List(0, Alice, 3.141592653589793), List(0, Alice, 2.718281828459045), List(0, Bob, 3.141592653589793), List(1, Alice, 2.718281828459045), List(0, Bob, 2.718281828459045), List(1, Bob, 3.141592653589793), List(1, Alice, 3.141592653589793), List(1, Bob, 2.718281828459045))

// TODO: How can this conversion be made implicit?
val s0 = CartesianProduct(Seq(0, 0))
val s1 = Seq(0, 0)

println(s0 * s1) // List(List(0, 0), List(0, 0), List(0, 0), List(0, 0))
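One possible answer to the TODO is an implicit conversion from a plain collection to `CartesianProduct`, so a chain can start without the explicit `CartesianProduct(...)` call. The sketch below compactly restates the class over `Iterable`, since `Traversable` is a deprecated alias of `Iterable` as of Scala 2.13. It then adds a hypothetical `toCartesianProduct` conversion. A conversion this broad applies to any collection, so keep its import narrowly scoped.

```scala
import scala.language.implicitConversions

object CartesianImplicits {
  // Compact restatement of the answer's class; each row is one combination so far
  final class CartesianProduct(val rows: Iterable[List[Any]]) {
    def *(rhs: Iterable[Any]): CartesianProduct =
      new CartesianProduct(rows.flatMap(lhs => rhs.map(r => lhs :+ r)))

    override def toString: String = rows.toString
  }

  // Hypothetical implicit conversion: a collection without a `*` method is wrapped
  // into single-element rows, so `List(0, 1) * y * z` type-checks
  implicit def toCartesianProduct(xs: Iterable[Any]): CartesianProduct =
    new CartesianProduct(xs.map(x => List(x)))
}
```

After `import CartesianImplicits._`, an expression such as `List(0, 1) * List("Alice", "Bob") * List(Math.E, Math.PI)` works without the explicit wrapper. As in the original, starting from a `Set` keeps the result a `Set` of rows.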
Noel Yap
  • What happens when you do `Seq(0, 0) * Seq(0, 0)`? I would expect 4 items; this would return 1 item. Also, this is too untyped for my taste. Maybe something using HLists? – pathikrit Dec 06 '16 at 15:37
  • I don't know what you mean by 'non-deterministic'. Do you mean that `Set`s are unordered? – Noel Yap Dec 06 '16 at 15:38
  • Yes, `Set`s are unordered – pathikrit Dec 06 '16 at 15:58
  • I've changed the code to use lists. The result still depends upon whether or not `Set` is used as the underlying type for the first argument. I suppose that's due to the use of `Traversable`. – Noel Yap Dec 06 '16 at 17:13

Similar to other responses; this is just my approach.

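// acc holds the elements chosen so far, most recent first
def loop(lst: List[List[Int]], acc: List[Int]): List[List[Int]] = {
  lst match {
    case head :: Nil  => head.map(x => (x :: acc).reverse) // reverse so each row keeps the input order
    case head :: tail => head.flatMap(x => loop(tail, x :: acc))
    case Nil          => List(acc.reverse) // empty input: a single empty combination
  }
}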
val l1 = List(10,20,30,40)
val l2 = List(2,4,6)
val l3 = List(3,5,7,9,11)

val lst = List(l1,l2,l3)

loop(lst,List.empty[Int])

You could use an applicative (`mapN`):

import cats.implicits._

val xs = List(1, 2) // List, like ys: mapN needs both values in the same type constructor
val ys = List("hello", "world", "bye")

(xs,ys).mapN((x,y) => (x,y))

https://typelevel.org/cats/typeclasses/applicative.html
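For lists, `mapN` applies the function to every combination, with the first list varying slowest. That ordering matches the `sequence` output in the Cats answer above. Here is a plain-Scala sketch of the same idea for two lists. `Map2Lists.map2` is a hypothetical helper, not the Cats API.

```scala
object Map2Lists {
  // Apply f to every (x, y) combination: all ys for the first x, then the next x
  def map2[A, B, C](xs: List[A], ys: List[B])(f: (A, B) => C): List[C] =
    for { x <- xs; y <- ys } yield f(x, y)
}
```

Passing `(x, y) => (x, y)` gives the plain cross product, and any other function combines the pairs directly, e.g. `Map2Lists.map2(List(1, 2), List(10, 20))(_ + _)`.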

Rob Wills
  • Minor edits required. Add link to mapN docs. Define xs,ys or reuse x,y from question. Different approach and acceptable solution – devilpreet Feb 23 '22 at 09:25