Skip to main content
 首页 » 编程设计

performance之Scala 列表递归性能

2025年05月04日93java哥

这个问题是关于Scala对列表进行模式匹配和递归的方式及其性能。

如果我有一个在列表上递归的函数,并且我在缺点上进行匹配,例如:

def myFunction(xs) = xs match {  
  case Nil => Nil 
  case x :: xs => «something» myFunction(xs) 
} 

在 haskell :
myFunction [] = [] 
myFunction (x:xs) = «something» : myFunction xs 

我使用的语义与我使用的语义相同,例如 Haskell。我认为 Haskell 的实现不会有任何问题,因为这只是您处理列表的方式。对于一个很长的列表(我将在一个包含几千个节点的列表上进行操作),Haskell 不会闪烁(我想虽然;我从未尝试过)。

但是根据我对 Scala 的理解,match 语句将调用 unapply 提取器方法来围绕缺点拆分列表,并将示例扩展到一个对列表没有任何作用的函数:
def myFunction(xs) = xs match {  
  case Nil => Nil 
  case x :: xs => x :: myFunction(xs) 
} 

在 haskell :
myFunction [] = [] 
myFunction (x:xs) = x : myFunction xs 

它会调用apply extractor方法将它重新组合在一起。对于一长串的 list ,我想这将是非常昂贵的。

为了说明这一点,在我的特定情况下,我想对字符列表进行递归并累积各种内容,其中输入字符串最多可达几十 KB。

如果我想递归一个长列表,我真的会为递归的每一步调用构造函数和提取器吗?还是有优化?或者更好的方法来做到这一点?在这种情况下,我需要几个累加器变量,显然我不会只是递归列表什么都不做......

(请原谅我的 Haskell,我已经两年没有写过一行了。)

(是的,我要进行尾递归。)

请您参考如下方法:

首先,Haskell 是非严格的,所以这些尾部的函数调用可能永远不会被评估。另一方面,Scala 将在返回之前计算所有列表。更接近 Haskell 中发生的事情的实现是这样的:

def myFunction[T](l: List[T]): Stream[T] = l match {    
  case Nil => Stream.empty   
  case x :: xs => x #:: myFunction(xs) 
} 

收到 List ,这是严格的,并返回 Stream这是非严格的。

现在,如果你想避免模式匹配和提取器(尽管在这种特殊情况下没有调用 - 见下文),你可以这样做:
def myFunction[T](xs: List[T]): Stream[T] = 
  if (xs.isEmpty) Stream.empty else xs.head #:: myFunction(xs.tail) 

我刚刚意识到您打算进行尾递归。你写的不是尾递归的,因为你在前面加上 x到递归的结果。处理列表时,如果向后计算结果然后反转,则会得到尾递归:
def myFunction[T](xs: List[T]): List[T] = { 
  def recursion(input: List[T], output: List[T]): List[T] = input match { 
    case x :: xs => recursion(xs, x :: output) 
    case Nil => output 
  } 
  recursion(xs, Nil).reverse 
} 

最后,让我们反编译一个例子,看看它是如何工作的:
class ListExample { 
  def test(o: Any): Any = o match { 
    case Nil => Nil 
    case x :: xs => xs 
    case _ => null 
  } 
} 

产生:
public class ListExample extends java.lang.Object implements scala.ScalaObject{ 
public ListExample(); 
  Code: 
   0:   aload_0 
   1:   invokespecial   #10; //Method java/lang/Object."<init>":()V 
   4:   return 
 
public java.lang.Object test(java.lang.Object); 
  Code: 
   0:   aload_1 
   1:   astore_2 
   2:   getstatic       #18; //Field scala/Nil$.MODULE$:Lscala/Nil$; 
   5:   aload_2 
   6:   invokestatic    #24; //Method scala/runtime/BoxesRunTime.equals:(Ljava/lang/Object;Ljava/lang/Object;)Z 
   9:   ifeq    18 
   12:  getstatic       #18; //Field scala/Nil$.MODULE$:Lscala/Nil$; 
   15:  goto    38 
   18:  aload_2 
   19:  instanceof      #26; //class scala/$colon$colon 
   22:  ifeq    35 
   25:  aload_2 
   26:  checkcast       #26; //class scala/$colon$colon 
   29:  invokevirtual   #30; //Method scala/$colon$colon.tl$1:()Lscala/List; 
   32:  goto    38 
   35:  aconst_null 
   36:  pop 
   37:  aconst_null 
   38:  areturn 
 
public int $tag()   throws java.rmi.RemoteException; 
  Code: 
   0:   aload_0 
   1:   invokestatic    #42; //Method scala/ScalaObject$class.$tag:(Lscala/ScalaObject;)I 
   4:   ireturn 
 
} 

解码,它首先调用方法 equals在传递的参数和对象上 Nil .如果为真,则返回后者。否则,它调用 instanceOf[::]在对象上。如果为真,它将对象强制转换为该对象,并调用方法 tl在上面。如果所有这些都失败了,则加载辅 null并返回它。

所以,你看, x :: xs没有调用任何提取器。

至于累积,您可能需要考虑另一种模式:
val list = List.fill(100)(scala.util.Random.nextInt) 
case class Accumulator(negative: Int = 0, zero: Int = 0, positive: Int = 0) 
val accumulator = list.foldLeft(Accumulator())( (acc, n) =>  
  n match { 
    case neg if neg < 0 => acc.copy(negative = acc.negative + 1) 
    case zero if zero == 0 => acc.copy(zero = acc.zero + 1) 
    case pos if pos > 0 => acc.copy(positive = acc.positive + 1) 
  }) 

默认参数和复制方法是我使用的 Scala 2.8 特性,只是为了使示例更简单,但重点是使用 foldLeft当您想在列表上累积事物时的方法。