在Python中实现“波浪折叠函数”算法的问题

简而言之:

我在Python 2.7中执行Wave Collapse Function算法的实现存在缺陷,但是我无法确定问题所在。 我需要帮助来找出我可能会丢失或做错的事情。

什么是波浪塌陷函数算法?

它是Maxim Gumin在2016年编写的一种算法,可以从样本图像生成程序模式。 您可以在此处(2D重叠模型)和此处(3D瓷砖模型)看到它的运行情况。

实施目标:

将算法(2D重叠模型)简化为本质,并避免原始C#脚本的冗长和笨拙(令人惊讶的是,它很长且难以阅读)。 这是尝试使该算法更短,更清晰和pythonic版本。

该实现的特征:

我正在使用处理(Python模式),这是一种用于视觉设计的软件,可简化图像处理(没有PIL,没有Matplotlib等)。 主要缺点是我仅限于Python 2.7,并且无法导入numpy。

与原始版本不同,此实现:

  • 不是面向对象的(处于当前状态),因此更易于理解/更接近伪代码
  • 使用一维数组而不是二维数组
  • 使用数组切片进行矩阵操作

算法(据我了解)

1 /读取输入位图,存储每个NxN模式并计数它们的出现。   (可选:具有旋转和反射的增强图案数据。)

例如,当N = 3时:

enter image description here

2 /预计算并存储模式之间的所有可能的邻接关系。在下面的示例中,图案207、242、182和125可以与图案246的右侧重叠

enter image description here

3 /创建一个具有输出尺寸的数组(对于wave称为H[22] = 0)。 此数组的每个元素都是一个数组,其中包含每个模式的状态(True的状态为True)。

例如,假设我们在输入中计算了326个唯一模式,并且希望输出尺寸为20 x 20(400个单元)。 然后,“ Wave”数组将包含400个(20x20)数组,每个数组包含326个布尔值。

开始时,所有布尔值都设置为H[22] = 0,因为在Wave的任何位置都允许使用每种模式。

W = [[True for pattern in xrange(len(patterns))] for cell in xrange(20*20)]

4 /创建另一个具有输出尺寸的数组(称为H[22] = 0)。 此数组的每个元素都是一个浮点数,在输出中保留其对应单元格的“熵”值。

此处的熵是指香农熵,它是根据Wave中特定位置的有效模式数量来计算的。 单元格的有效模式越多(在Wave中设置为H[22] = 0),其熵越高。

例如,要计算单元格22的熵,我们查看其在波中的对应索引(H[22] = 0),并对设置为246的布尔数进行计数。现在,可以使用Shannon公式计算熵。 计算结果将以相同的索引True存储在H中

开始时,由于每个单元的所有模式都设置为246,因此所有单元的熵值相同(H[22] = 0中每个位置的浮点数相同)。

H = [entropyValue for cell in xrange(20*20)]

这4个步骤是介绍性步骤,它们是初始化算法所必需的。 现在开始算法的核心:

5 /观察:

查找具有最小非零熵的单元格的索引(请注意,在第一次迭代时,所有熵都是相等的,因此我们需要随机选择一个单元格的索引。)

然后,查看Wave中相应索引处的仍然有效的模式,并随机选择其中一个模式,并根据模式在输入图像中出现的频率进行加权(加权选择)。

例如,如果H[22] = 0中的最小值位于索引22(246),我们将查看在W[22]中设置为True的所有模式,并根据其出现在输入中的次数随机选择一个。 (请记住,在第1步中,我们已经计算出每种模式的出现次数)。 这样可以确保模式在输出中出现的分布与输入中的分布相似。

6 /收起:

现在,我们将选定模式的索引分配给具有最小熵的单元。 这意味着,除了已选择的波形外,Wave中相应位置的每个图形都设置为H[22] = 0

例如,如果已将246中的模式H[22] = 0设置为True,则将所有其他模式设置为W[22]。单元格246分配了246单元格。在输出单元22中,将填充图案246的第一种颜色(左上角)。(在此示例中为蓝色)

7 /传播:

由于邻接限制,该模式选择会对Wave中的相邻单元产生影响。 与最近折叠的单元格的顶部和上方,左侧和右侧的单元格相对应的布尔数组需要相应地更新。

例如,如果单元格H[22] = 0已折叠并分配了模式246,则必须修改True(左),W[22](右),246(向上)和False(向下),以便它们仅保留与True相邻的模式 图案246

例如,回顾步骤2的图片,我们可以看到只有样式207、242、182和125可以放在样式246的右侧。这意味着H[22] = 0(单元格246的右侧)需要保留样式207 ,242、182和125分别为True,并将数组中的所有其他模式设置为W[22]。如果这些模式不再有效(由于先前的约束已设置为246),则该算法正面临矛盾。

8 /更新熵

由于某个单元已经崩溃(选择了一个模式,设置为H[22] = 0),并且其周围的单元也进行了相应的更新(将非相邻模式设置为246),所有这些单元的熵均已更改,需要重新计算。 (请记住,单元的熵与其在Wave中保存的有效模式的数量相关。)

在此示例中,单元格22的熵现在为0(H[22] = 0,因为只有模式246W[22]处设置为True)并且其相邻单元格的熵已减小(与模式246不相邻的模式已设置为246 )。

现在,该算法到达第一次迭代的末尾,并将遍历步骤5(查找具有最小非零熵的单元)到8(更新熵),直到所有单元都折叠为止。

我的剧本

您需要安装使用Python模式处理才能运行此脚本。它包含大约80行代码(与原始脚本的〜1000行相比要短一些),这些代码已完全注释,因此可以快速理解。 您还需要下载输入图像并相应地更改第16行的路径。

from collections import Counter
from itertools import chain, izip
import math

d = 20  # dimensions of output (array of dxd cells)
N = 3 # dimensions of a pattern (NxN matrix)

Output = [120 for i in xrange(d*d)] # array holding the color value for each cell in the output (at start each cell is grey = 120)

def setup():
    size(800, 800, P2D)
    textSize(11)

    global W, H, A, freqs, patterns, directions, xs, ys, npat

    img = loadImage('Flowers.png') # path to the input image
    iw, ih = img.width, img.height # dimensions of input image
    xs, ys = width//d, height//d # dimensions of cells (squares) in output
    kernel = [[i + n*iw for i in xrange(N)] for n in xrange(N)] # NxN matrix to read every patterns contained in input image
    directions = [(-1, 0), (1, 0), (0, -1), (0, 1)] # (x, y) tuples to access the 4 neighboring cells of a collapsed cell
    all = [] # array list to store all the patterns found in input



    # Stores the different patterns found in input
    for y in xrange(ih):
        for x in xrange(iw):

            ''' The one-liner below (cmat) creates a NxN matrix with (x, y) being its top left corner.
                This matrix will wrap around the edges of the input image.
                The whole snippet reads every NxN part of the input image and store the associated colors.
                Each NxN part is called a 'pattern' (of colors). Each pattern can be rotated or flipped (not mandatory). '''


            cmat = [[img.pixels[((x+n)%iw)+(((a[0]+iw*y)/iw)%ih)*iw] for n in a] for a in kernel]

            # Storing rotated patterns (90°, 180°, 270°, 360°) 
            for r in xrange(4):
                cmat = zip(*cmat[::-1]) # +90° rotation
                all.append(cmat) 

            # Storing reflected patterns (vertical/horizontal flip)
            all.append(cmat[::-1])
            all.append([a[::-1] for a in cmat])




    # Flatten pattern matrices + count occurences 

    ''' Once every pattern has been stored,
        - we flatten them (convert to 1D) for convenience
        - count the number of occurences for each one of them (one pattern can be found multiple times in input)
        - select unique patterns only
        - store them from less common to most common (needed for weighted choice)'''

    all = [tuple(chain.from_iterable(p)) for p in all] # flattern pattern matrices (NxN --> [])
    c = Counter(all)
    freqs = sorted(c.values()) # number of occurences for each unique pattern, in sorted order
    npat = len(freqs) # number of unique patterns
    total = sum(freqs) # sum of frequencies of unique patterns
    patterns = [p[0] for p in c.most_common()[:-npat-1:-1]] # list of unique patterns sorted from less common to most common



    # Computes entropy

    ''' The entropy of a cell is correlated to the number of possible patterns that cell holds.
        The more a cell has valid patterns (set to 'True'), the higher its entropy is.
        At start, every pattern is set to 'True' for each cell. So each cell holds the same high entropy value'''

    ent = math.log(total) - sum(map(lambda x: x * math.log(x), freqs)) / total



    # Initializes the 'wave' (W), entropy (H) and adjacencies (A) array lists

    W = [[True for _ in xrange(npat)] for i in xrange(d*d)] # every pattern is set to 'True' at start, for each cell
    H = [ent for i in xrange(d*d)] # same entropy for each cell at start (every pattern is valid)
    A = [[set() for dir in xrange(len(directions))] for i in xrange(npat)] #see below for explanation




    # Compute patterns compatibilities (check if some patterns are adjacent, if so -> store them based on their location)

    ''' EXAMPLE:
    If pattern index 42 can placed to the right of pattern index 120,
    we will store this adjacency rule as follow:

                     A[120][1].add(42)

    Here '1' stands for 'right' or 'East'/'E'

    0 = left or West/W
    1 = right or East/E
    2 = up or North/N
    3 = down or South/S '''

    # Comparing patterns to each other
    for i1 in xrange(npat):
        for i2 in xrange(npat):
            for dir in (0, 2):
                if compatible(patterns[i1], patterns[i2], dir):
                    A[i1][dir].add(i2)
                    A[i2][dir+1].add(i1)


def compatible(p1, p2, dir):

    '''NOTE: 
    what is refered as 'columns' and 'rows' here below is not really columns and rows 
    since we are dealing with 1D patterns. Remember here N = 3'''

    # If the first two columns of pattern 1 == the last two columns of pattern 2 
    # --> pattern 2 can be placed to the left (0) of pattern 1
    if dir == 0:
        return [n for i, n in enumerate(p1) if i%N!=2] == [n for i, n in enumerate(p2) if i%N!=0]

    # If the first two rows of pattern 1 == the last two rows of pattern 2
    # --> pattern 2 can be placed on top (2) of pattern 1
    if dir == 2:
        return p1[:6] == p2[-6:]



def draw():    # Equivalent of a 'while' loop in Processing (all the code below will be looped over and over until all cells are collapsed)
    global H, W, grid

    ### OBSERVATION
    # Find cell with minimum non-zero entropy (not collapsed yet)

    '''Randomly select 1 cell at the first iteration (when all entropies are equal), 
       otherwise select cell with minimum non-zero entropy'''

    emin = int(random(d*d)) if frameCount <= 1 else H.index(min(H)) 



    # Stoping mechanism

    ''' When 'H' array is full of 'collapsed' cells --> stop iteration '''

    if H[emin] == 'CONT' or H[emin] == 'collapsed': 
        print 'stopped'
        noLoop()
        return



    ### COLLAPSE
    # Weighted choice of a pattern

    ''' Among the patterns available in the selected cell (the one with min entropy), 
        select one pattern randomly, weighted by the frequency that pattern appears in the input image.
        With Python 2.7 no possibility to use random.choice(x, weight) so we have to hard code the weighted choice '''

    lfreqs = [b * freqs[i] for i, b in enumerate(W[emin])] # frequencies of the patterns available in the selected cell
    weights = [float(f) / sum(lfreqs) for f in lfreqs] # normalizing these frequencies
    cumsum = [sum(weights[:i]) for i in xrange(1, len(weights)+1)] # cumulative sums of normalized frequencies
    r = random(1)
    idP = sum([cs < r for cs in cumsum])  # index of selected pattern 

    # Set all patterns to False except for the one that has been chosen   
    W[emin] = [0 if i != idP else 1 for i, b in enumerate(W[emin])]

    # Marking selected cell as 'collapsed' in H (array of entropies)
    H[emin] = 'collapsed' 

    # Storing first color (top left corner) of the selected pattern at the location of the collapsed cell
    Output[emin] = patterns[idP][0]



    ### PROPAGATION
    # For each neighbor (left, right, up, down) of the recently collapsed cell
    for dir, t in enumerate(directions):
        x = (emin%d + t[0])%d
        y = (emin/d + t[1])%d
        idN = x + y * d #index of neighbor

        # If that neighbor hasn't been collapsed yet
        if H[idN] != 'collapsed': 

            # Check indices of all available patterns in that neighboring cell
            available = [i for i, b in enumerate(W[idN]) if b]

            # Among these indices, select indices of patterns that can be adjacent to the collapsed cell at this location
            intersection = A[idP][dir] & set(available) 

            # If the neighboring cell contains indices of patterns that can be adjacent to the collapsed cell
            if intersection:

                # Remove indices of all other patterns that cannot be adjacent to the collapsed cell
                W[idN] = [True if i in list(intersection) else False for i in xrange(npat)]


                ### Update entropy of that neighboring cell accordingly (less patterns = lower entropy)

                # If only 1 pattern available left, no need to compute entropy because entropy is necessarily 0
                if len(intersection) == 1: 
                    H[idN] = '0' # Putting a str at this location in 'H' (array of entropies) so that it doesn't return 0 (float) when looking for minimum entropy (min(H)) at next iteration


                # If more than 1 pattern available left --> compute/update entropy + add noise (to prevent cells to share the same minimum entropy value)
                else:
                    lfreqs = [b * f for b, f in izip(W[idN], freqs) if b] 
                    ent = math.log(sum(lfreqs)) - sum(map(lambda x: x * math.log(x), lfreqs)) / sum(lfreqs)
                    H[idN] = ent + random(.001)


            # If no index of adjacent pattern in the list of pattern indices of the neighboring cell
            # --> mark cell as a 'contradiction'
            else:
                H[idN] = 'CONT'



    # Draw output

    ''' dxd grid of cells (squares) filled with their corresponding color.      
        That color is the first (top-left) color of the pattern assigned to that cell '''

    for i, c in enumerate(Output):
        x, y = i%d, i/d
        fill(c)
        rect(x * xs, y * ys, xs, ys)

        # Displaying corresponding entropy value
        fill(0)
        text(H[i], x * xs + xs/2 - 12, y * ys + ys/2)

问题

尽管我竭尽全力将上面描述的所有步骤都仔细地编写到代码中,但是此实现返回的结果非常奇怪和令人失望:

20x20输出的示例

enter image description here

模式分布和邻接约束似乎都得到了尊重(与输入中相同数量的蓝色,绿色,黄色和棕色和相同类型的模式:水平地面,绿色茎)。

但是这些模式:

  • 经常断开
  • 通常不完整(缺少由4个黄色花瓣组成的“头”)
  • 遇到太多矛盾的状态(标记为“ CONT”的灰色单元格)

关于最后一点,我应该澄清矛盾的状态是正常的,但很少发生(如本文第6页中段和本文中所述)

数小时的调试使我确信入门步骤(1至5)是正确的(计数和存储模式,邻接和熵计算,数组初始化)。 这使我认为算法的核心部分必须有所失误(步骤6至8)。 我可能没有正确执行这些步骤之一,或者我错过了逻辑的关键要素。

因此,在此问题上的任何帮助将不胜感激!

同样,任何基于提供的脚本的答案(无论是否使用处理)都受到欢迎。

有用的其他资源:

本文来自Stephen Sherratt,来自Karth&Smith。另外,为了进行比较,我建议检查其他Python实现(包含非强制性的回溯机制)。

注意:我已尽力使这个问题尽可能清晰(带有GIF和插图的全面解释,带有有用链接和资源的带有完整注释的代码),但是如果出于某些原因您决定拒绝它,请留下简短的评论以进行解释 为什么这样做。

2个解决方案
15 votes

@mbrig和@Leon提出的假设是,传播步骤在整个细胞堆栈上进行迭代(而不是局限于4个直接邻居的集合)是正确的。 以下是在回答我自己的问题时尝试提供更多详细信息的尝试。

该问题在传播时发生在步骤7中。 原始算法确实更新了特定小区BUT的4个直接邻居:

  • 该特定小区的索引又被先前更新的邻居的索引替换。
  • 每次单元崩溃时都会触发此级联过程
  • 并且只要特定单元格的相邻模式在其相邻单元格之一中可用就持续

换句话说,正如评论中所提到的,这是递归类型的传播,它不仅更新折叠单元的邻居,而且还更新邻居的邻居……等等,只要邻接是可能的。

详细算法

单元折叠后,其索引将放入堆栈中。 该堆栈意味着稍后可以临时存储相邻单元的索引

stack = set([emin]) #emin = index of cell with minimum entropy that has been collapsed

只要该堆栈充满索引,传播就会持续:

while stack:

我们要做的第一件事是while堆栈中包含的最后一个索引(目前唯一的索引),并获取其4个相邻单元(E,W,N,S)的索引。 我们必须使它们保持边界,并确保它们环绕。

while stack:
    idC = stack.pop() # index of current cell
    for dir, t in enumerate(mat):
        x = (idC%w + t[0])%w
        y = (idC/w + t[1])%h
        idN = x + y * w  # index of neighboring cell

在继续进行操作之前,我们确保相邻单元尚未折叠(我们不希望更新只有一个可用模式的单元):

        if H[idN] != 'c': 

然后,我们检查可以放置在该位置的所有模式。 例如:如果相邻的单元格在当前单元格的左侧(东侧),我们将查看可放置在当前单元格中每个图案左侧的所有图案。

            possible = set([n for idP in W[idC] for n in A[idP][dir]])

我们还将查看相邻单元中可用的模式:

            available = W[idN]

现在,我们确保确实必须更新相邻单元。 如果所有可用模式都已在所有可能模式的列表中->则无需对其进行更新(算法会跳过该邻居并继续执行下一个操作):

            if not available.issubset(possible):

但是,如果它不是while列表的子集—>我们看一下这两个集合的交集(所有可以放在该位置并且“幸运的”都可以在同一位置使用的模式):

                intersection = possible & available

如果它们不相交(本可以放置在此处但不可用的图案),则意味着我们遇到了“矛盾”。 我们必须停止整个WFC算法。

                if not intersection:
                    print 'contradiction'
                    noLoop()

相反,如果它们确实相交->我们将使用该模式索引的精炼列表更新相邻单元格:

                W[idN] = intersection

因为该相邻小区已被更新,所以它的熵也必须被更新:

                lfreqs = [freqs[i] for i in W[idN]]
                H[idN] = (log(sum(lfreqs)) - sum(map(lambda x: x * log(x), lfreqs)) / sum(lfreqs)) - random(.001)

最后,最重要的是,我们将该相邻单元的索引添加到堆栈中,从而使其依次成为下一个当前单元(其邻居将在下一个while循环中更新的单元):

                stack.add(idN)

完整更新的脚本

from collections import Counter
from itertools import chain
from random import choice

w, h = 40, 25
N = 3

def setup():
    size(w*20, h*20, P2D)
    background('#FFFFFF')
    frameRate(1000)
    noStroke()

    global W, A, H, patterns, freqs, npat, mat, xs, ys

    img = loadImage('Flowers.png') 
    iw, ih = img.width, img.height
    xs, ys = width//w, height//h
    kernel = [[i + n*iw for i in xrange(N)] for n in xrange(N)]
    mat = ((-1, 0), (1, 0), (0, -1), (0, 1))
    all = []

    for y in xrange(ih):
        for x in xrange(iw):
            cmat = [[img.pixels[((x+n)%iw)+(((a[0]+iw*y)/iw)%ih)*iw] for n in a] for a in kernel]
            for r in xrange(4):
                cmat = zip(*cmat[::-1])
                all.append(cmat)
                all.append(cmat[::-1])
                all.append([a[::-1] for a in cmat])

    all = [tuple(chain.from_iterable(p)) for p in all] 
    c = Counter(all)
    patterns = c.keys()
    freqs = c.values()
    npat = len(freqs) 

    W = [set(range(npat)) for i in xrange(w*h)] 
    A = [[set() for dir in xrange(len(mat))] for i in xrange(npat)]
    H = [100 for i in xrange(w*h)] 

    for i1 in xrange(npat):
        for i2 in xrange(npat):
            if [n for i, n in enumerate(patterns[i1]) if i%N!=(N-1)] == [n for i, n in enumerate(patterns[i2]) if i%N!=0]:
                A[i1][0].add(i2)
                A[i2][1].add(i1)
            if patterns[i1][:(N*N)-N] == patterns[i2][N:]:
                A[i1][2].add(i2)
                A[i2][3].add(i1)


def draw():    
    global H, W

    emin = int(random(w*h)) if frameCount <= 1 else H.index(min(H)) 

    if H[emin] == 'c': 
        print 'finished'
        noLoop()

    id = choice([idP for idP in W[emin] for i in xrange(freqs[idP])])
    W[emin] = [id]
    H[emin] = 'c' 

    stack = set([emin])
    while stack:
        idC = stack.pop() 
        for dir, t in enumerate(mat):
            x = (idC%w + t[0])%w
            y = (idC/w + t[1])%h
            idN = x + y * w 
            if H[idN] != 'c': 
                possible = set([n for idP in W[idC] for n in A[idP][dir]])
                if not W[idN].issubset(possible):
                    intersection = possible & W[idN] 
                    if not intersection:
                        print 'contradiction'
                        noLoop()
                        return

                    W[idN] = intersection
                    lfreqs = [freqs[i] for i in W[idN]]
                    H[idN] = (log(sum(lfreqs)) - sum(map(lambda x: x * log(x), lfreqs)) / sum(lfreqs)) - random(.001)
                    stack.add(idN)

    fill(patterns[id][0])
    rect((emin%w) * xs, (emin/w) * ys, xs, ys)

enter image description here

整体改善

除了这些修复程序外,我还做了一些次要的代码优化,以加快观察和传播步骤,并缩短了加权选择的计算时间。

  • 现在,“ Wave”由Python索引集组成,这些索引的大小随着单元格“折叠”而减小(替换固定大小的布尔值大列表)。

  • 熵存储在defaultdict中,其密钥将逐渐删除。

  • 起始熵值由一个随机整数代替(不需要第一熵计算,因为在开始时相当高的不确定性水平)

  • 单元格只显示一次(避免将它们存储在数组中并在每帧重绘)

  • 加权选择现在是单线(避免使用列表理解的几行)

solub answered 2020-01-23T22:02:00Z
5 votes

在查看其中一个示例中链接的实时演示时,并基于对原始算法代码的快速回顾,我相信您的错误在于“传播”步骤。

传播不仅将相邻的4个单元更新为折叠的单元。 您还必须递归更新所有这些单元的邻居,然后更新这些单元的邻居,等等。 好吧,具体地说,一旦您更新了一个相邻的小区,就可以更新它的邻居(在到达第一个小区的其他邻居之前),即深度优先更新,而不是广度优先更新。 至少,这就是我从现场演示中收集的信息。

原始算法的实际C#代码实现非常复杂,我还没有完全理解它,但是关键点似乎是此处创建“传播器”对象以及此处的传播函数本身。

mbrig answered 2020-01-23T22:02:30Z
translate from https://stackoverflow.com:/questions/57049191/issues-implementing-the-wave-collapse-function-algorithm-in-python