
I have a method that compares two byte arrays. The code is Java-style, with a lot of if-else branches. How can I make it simpler and more idiomatic in Scala?

def assertArray(b1: Array[Byte], b2: Array[Byte]) {
  if (b1 == null && b2 == null) return;
  else if (b1 != null && b2 != null) {
    if (b1.length != b2.length) throw new AssertionError("b1.length != b2.length")
    else {
      for (i <- b1.indices) {
        if (b1(i) != b2(i)) throw new AssertionError("b1(%d) != b2(%d)".format(i, i))
      }
    }
  } else {
    throw new AssertionError("b1 is null while b2 is not, vice versa")
  }
}

I tried the following, but it doesn't simplify the code much:

(Option(b1), Option(b2)) match {
    case (Some(b1), Some(b2)) => if ( b1.length == b2.length ) {
       for (i <- b1.indices) {
        if (b1(i) != b2(i)) throw new AssertionError("b1(%d) != b2(%d)".format(i, i))
       }
    } else {
       throw new AssertionError("b1.length != b2.length")
    }
    case (None, None) => ()
    case _ => throw new AssertionError("b1 is null while b2 is not, vice versa")
}
Freewind

5 Answers

Unless you're doing this as an academic exercise, how about

java.util.Arrays.equals(b1, b2)

From the Javadoc:

Returns true if the two specified arrays of bytes are equal to one another. Two arrays are considered equal if both arrays contain the same number of elements, and all corresponding pairs of elements in the two arrays are equal. In other words, two arrays are equal if they contain the same elements in the same order. Also, two array references are considered equal if both are null.

I will admit to this being 'java style' :-)
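To make the documented semantics concrete, here is a small sketch (the object and value names are illustrative) that evaluates `Arrays.equals` on the cases the Javadoc mentions, plus a hypothetical one-line `assertArray` built on it:

```scala
import java.util.Arrays

// Illustrative values for the cases described in the Javadoc of
// Arrays.equals(byte[], byte[]).
object ArraysEqualsSemantics {
  val noArray: Array[Byte] = null
  val a = Array[Byte](1, 2, 3)
  val b = Array[Byte](1, 2, 3) // same contents, different object
  val c = Array[Byte](1, 2)    // different length

  val bothNull    = Arrays.equals(noArray, noArray) // two null references are equal
  val sameContent = Arrays.equals(a, b)             // same elements, same order
  val diffLength  = Arrays.equals(a, c)             // lengths differ
  val oneNull     = Arrays.equals(a, noArray)       // only one side is null

  // Hypothetical assertion helper built on Arrays.equals
  def assertArray(b1: Array[Byte], b2: Array[Byte]): Unit =
    if (!Arrays.equals(b1, b2)) throw new AssertionError("arrays are not equal")
}
```

Here `bothNull` and `sameContent` are true, while `diffLength` and `oneNull` are false, which covers all three branches of the original method in a single call.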

Since you're throwing AssertionErrors, you can remove all of the elses:

def assertArray(b1: Array[Byte], b2: Array[Byte]): Unit = {
  // == on arrays is reference equality: true if both are null or both are the same array
  if (b1 == b2) return

  if (b1 == null || b2 == null) throw new AssertionError("b1 is null while b2 is not, vice versa")

  if (b1.length != b2.length) throw new AssertionError("b1.length != b2.length")

  for (i <- b1.indices) {
    if (b1(i) != b2(i)) throw new AssertionError("b1(%d) != b2(%d)".format(i, i))
  }
}

If, as I suspect, you're actually using this within JUnit tests (hence the name assertArray), a trick I often use is to compare the string representations of the arrays:

import org.junit.Assert.assertEquals

def assertArray2(b1: Array[Byte], b2: Array[Byte]): Unit = {
  assertEquals(toString(b1), toString(b2))
}

// Arrays.toString gives e.g. "[1, 2, 3]", and "null" for a null array
def toString(b: Array[Byte]): String = java.util.Arrays.toString(b)

This gives you the same outcome (an AssertionError), and the failure message shows where the arrays differ.
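If JUnit isn't on the classpath, the same idea can be sketched in plain Scala. This is an illustrative variant, not a JUnit API: `show` mimics the output format of `java.util.Arrays.toString`, and `assertArrayShown` is a hypothetical name. For `Byte` elements the printed forms are equal exactly when the contents are equal, so comparing strings is safe here; for other element types the string form might not match element equality.

```scala
// JUnit-free sketch of the string-comparison trick (names are hypothetical).
object ShowDiff {
  // Render an array like java.util.Arrays.toString does: "[1, 2, 3]", or "null".
  def show(b: Array[Byte]): String =
    Option(b).fold("null")(_.mkString("[", ", ", "]"))

  // Compare the printed forms; the message shows both sides on failure.
  def assertArrayShown(expected: Array[Byte], actual: Array[Byte]): Unit = {
    val (e, a) = (show(expected), show(actual))
    if (e != a) throw new AssertionError(s"expected: $e but was: $a")
  }
}
```

For example, comparing `Array[Byte](1, 2)` with `Array[Byte](1, 3)` fails with the message `expected: [1, 2] but was: [1, 3]`.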

Matthew Farwell
  • +1: The first test can be simplified to `if (b1 == b2) return;` – Peter Lawrey Oct 28 '11 at 10:54
  • @Peter I don't think that would work, since `==` on Arrays is reference equality: see this Q&A. http://stackoverflow.com/q/2481149/770361 However from Rex's answer there it sounds like you could explicitly wrap the arrays to make it work. – Luigi Plinge Oct 28 '11 at 11:01
  • `==` works for arrays in that if b1 and b2 are both null, it will be true, and if b1 and b2 are the same object, it will be true. – Peter Lawrey Oct 28 '11 at 13:41
  • @Peter Just to be clear Arrays.equals returns true if the elements are the same as well, but the objects are different. – Matthew Farwell Oct 28 '11 at 14:25
  • @MatthewFarwell, True, which is why only `if (b1 == null || b2 == null)` can be replaced with `if (b1 == b2)` – Peter Lawrey Oct 28 '11 at 14:28
  • Sorry, I thought you meant the first test (java.util.Arrays.equals(b1, b2)), rather than the first test in the method. You're right of course :-) – Matthew Farwell Oct 28 '11 at 15:24
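The point made in these comments can be checked with a short sketch (names are illustrative): `==` on two arrays compares references, so it is a safe shortcut for the "both null" and "same object" cases, but it is not a substitute for an element-wise comparison.

```scala
// == on arrays compares references, not contents.
object ArrayEqualityDemo {
  val x = Array[Byte](1, 2, 3)
  val y = Array[Byte](1, 2, 3)   // same contents, different object
  val n1: Array[Byte] = null
  val n2: Array[Byte] = null

  val sameRef      = x == x                        // true: same object
  val sameContents = x == y                        // false: elements are not compared
  val bothNull     = n1 == n2                      // true: null == null
  val elementWise  = java.util.Arrays.equals(x, y) // true: compares elements
}
```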

The standard library provides sameElements for exactly this purpose:

scala> val a1 = Array[Byte](1, 3, 5, 7); val a2 = Array[Byte](1, 3, 5, 7); val a3 = Array[Byte](1, 3, 5, 7, 9)
a1: Array[Byte] = Array(1, 3, 5, 7)
a2: Array[Byte] = Array(1, 3, 5, 7)
a3: Array[Byte] = Array(1, 3, 5, 7, 9)

scala> a1 sameElements a2
res0: Boolean = true

scala> a1 sameElements a3
res1: Boolean = false
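`sameElements` also checks the lengths, but it is not null-safe (a null array leads to a NullPointerException), so on its own it doesn't cover the question's null cases. A sketch of an assertion built on it, assuming (as in the question) that two nulls count as equal; the object name is hypothetical:

```scala
// Null-safe assertion built on sameElements (sketch).
object SameElementsAssert {
  def assertArray(b1: Array[Byte], b2: Array[Byte]): Unit = (b1, b2) match {
    case (null, null) => ()                  // both null: treated as equal
    case (null, _) | (_, null) =>
      throw new AssertionError("b1 is null while b2 is not, or vice versa")
    case _ if !b1.sameElements(b2) =>        // checks lengths and elements
      throw new AssertionError("arrays differ")
    case _ => ()
  }
}
```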
Paul Butcher

One possible simplification:

def assertArray(b1: Array[Byte], b2: Array[Byte]): Unit = {
    (Option(b1), Option(b2)) match {
        case (None, None) => // both null: treated as equal, as in the question
        case (None, _) =>
            throw new AssertionError("b1 is null")
        case (_, None) =>
            throw new AssertionError("b2 is null")
        case (Some(Size(b1Size)), Some(Size(b2Size))) if b1Size != b2Size =>
            throw new AssertionError("b1.length != b2.length")
        case (Some(b1), Some(b2)) if b1.zip(b2).exists(c => c._1 != c._2) =>
            throw new AssertionError("Arrays do not match")
        case _ => // everything is OK
    }
}

object Size {
    def unapply[T](arr: Array[T]): Option[Int] = Some(arr.size)
}

It can probably be improved further, but at least it has no nested ifs or explicit loops.
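For comparison, here is a sketch (the object name is hypothetical) that keeps the question's `(Option(b1), Option(b2))` match, treats two nulls as equal as the question does, and reports the first differing index via `indices.find` instead of a hand-written loop:

```scala
object MatchAssert {
  def assertArray(b1: Array[Byte], b2: Array[Byte]): Unit =
    (Option(b1), Option(b2)) match {
      case (None, None) => () // both null: treated as equal
      case (Some(x), Some(y)) if x.length != y.length =>
        throw new AssertionError(s"b1.length (${x.length}) != b2.length (${y.length})")
      case (Some(x), Some(y)) =>
        // first index whose elements differ, if any
        x.indices.find(i => x(i) != y(i)).foreach { i =>
          throw new AssertionError(s"b1($i) != b2($i)")
        }
      case _ => throw new AssertionError("b1 is null while b2 is not, or vice versa")
    }
}
```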

tenshi

A minor improvement to Matthew's solution is to report all the differences, not just the first:

def assertArray (b1: Array[Byte], b2: Array[Byte]): Unit = {

  def diffs [T] (a: Array[T], b: Array[T]) = 
    (a.zip (b).filter (e => (e._1 != e._2)))

  if (b1 == null && b2 == null) 
    return;
  if (b1 == null || b2 == null) 
    throw new AssertionError ("b1 is null while b2 is not, vice versa")  
  if (b1.length != b2.length) 
    throw new AssertionError ("b1.length != b2.length")
  val delta = diffs (b1, b2)
  if (delta.nonEmpty)
    throw new AssertionError (delta.mkString)
}

Test invocation:

val ab = (List ((List (47, 99, 13, 23, 42).map (_.toByte)).toArray,
  (List (47, 99, 33, 13, 42).map (_.toByte)).toArray)).toArray

assertArray (ab(0), ab(1))
// java.lang.AssertionError: (13,33)(23,13)
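Going one step further, the index of each mismatch can be included in the message. A sketch with hypothetical names, using `indices.collect` instead of `zip` so the positions aren't lost:

```scala
object AllDiffs {
  // Every position where the arrays differ, as (index, b1 value, b2 value).
  def diffs(a: Array[Byte], b: Array[Byte]): Seq[(Int, Byte, Byte)] =
    a.indices.collect { case i if a(i) != b(i) => (i, a(i), b(i)) }

  def assertArray(b1: Array[Byte], b2: Array[Byte]): Unit = {
    if (b1 == null && b2 == null) return
    if (b1 == null || b2 == null)
      throw new AssertionError("b1 is null while b2 is not, or vice versa")
    if (b1.length != b2.length)
      throw new AssertionError("b1.length != b2.length")
    val delta = diffs(b1, b2)
    if (delta.nonEmpty)
      throw new AssertionError(delta.map { case (i, x, y) => s"b1($i)=$x b2($i)=$y" }.mkString(", "))
  }
}
```

With the two arrays from the test invocation above, the message becomes `b1(2)=13 b2(2)=33, b1(3)=23 b2(3)=13`.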
user unknown

Here is my solution using tail recursion:

@scala.annotation.tailrec
def assertArray[T](b1: Array[T], b2: Array[T])(implicit m: Manifest[T]): Unit = (b1, b2) match {
  case (null, null) =>                                    // both null: equal
  case (null, _) | (_, null) => throw new AssertionError  // only one is null
  case (Array(), Array()) =>                              // both empty: equal
  case _ =>
    // compare lengths and first elements, then recurse on the rest
    if (b1.length != b2.length || b1.head != b2.head) throw new AssertionError
    else assertArray(b1.tail, b2.tail)
}

And the test cases (the calls with mismatched arguments throw an AssertionError):

assertArray(null,null)
assertArray(Array[Byte](),null)
assertArray(null,Array[Byte]())
assertArray(Array[Byte](),Array[Byte]())
assertArray(Array[Byte](),Array[Byte](1))
assertArray(Array[Byte](1,2,3),Array[Byte](1,2,3))
assertArray(Array[Byte](1,3),Array[Byte](1))

See also this gist: https://gist.github.com/1322299
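Because `Array.tail` allocates a new array on every call, the recursion above does quadratic copying on long arrays. A sketch of the same tail-recursive idea that walks an index instead (names are illustrative; it reports the first differing index):

```scala
import scala.annotation.tailrec

object TailRecAssert {
  def assertArray[T](b1: Array[T], b2: Array[T]): Unit = {
    // Walk both arrays by index; no copies are made.
    @tailrec
    def loop(i: Int): Unit =
      if (i >= b1.length) ()
      else {
        if (b1(i) != b2(i)) throw new AssertionError(s"b1($i) != b2($i)")
        loop(i + 1)
      }

    (b1, b2) match {
      case (null, null) => () // both null: equal
      case (null, _) | (_, null) =>
        throw new AssertionError("b1 is null while b2 is not, or vice versa")
      case _ if b1.length != b2.length =>
        throw new AssertionError("b1.length != b2.length")
      case _ => loop(0) // lengths match, so b2(i) is always in range
    }
  }
}
```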

爱国者