Spock Web Console

subscribe to the feed Subscribe
Prisoners And Hats Problem (via #spockwebconsole)
tweet this script Tweet

Prisoners And Hats Problem

Published 10 months ago by jp
Actions  ➤ Edit In Console Back To Console Show/Hide Line Numbers View Recent Scripts
import spock.lang.*

// Hit 'Run Script' below
/**
 * Specification of the prisoners-and-hats puzzle: every prisoner except the one
 * who answers first must name the color of his own hat.
 */
class PrisonersSpec extends Specification
{
    def "The first prisoner can guess incorrectly"()
    {
        given:
        def guesses = new PrisonerHatProblem(1).solve();

        expect: "the first answer is a valid color, even though it may be wrong"
        guesses[0] in [Hat.RED, Hat.BLUE];
    }

    def "The first prisoner answers RED if he sees an odd number of RED hats, otherwise BLUE"(int a0, int a1, int c)
    {
        given:
        def guesses = new PrisonerHatProblem([a0, a1] as int[]).solve();

        expect:
        guesses[0] == c;

        // Color codes: RED = 0, BLUE = 1.
        where:
        a0 | a1 || c
        0  | 0  || 0
        0  | 1  || 1
        1  | 0  || 0
        1  | 1  || 1
    }

    @Unroll
    def "2 prisoners: #a0,#a1"(int a0, int a1)
    {
        given:
        def guesses = new PrisonerHatProblem([a0, a1] as int[]).solve();

        expect: "everyone after the first prisoner names his own hat"
        guesses[1..-1] == [a1];

        where:
        a0 | a1
        0  | 0
        0  | 1
        1  | 0
        1  | 1
    }

    @Unroll
    def "3 prisoners: #a0,#a1,#a2"(int a0, int a1, int a2)
    {
        given:
        def guesses = new PrisonerHatProblem([a0, a1, a2] as int[]).solve();

        expect: "everyone after the first prisoner names his own hat"
        guesses[1..-1] == [a1, a2];

        where:
        a0 | a1 | a2
        0  | 0  | 0
        0  | 0  | 1
        0  | 1  | 0
        0  | 1  | 1
        1  | 0  | 0
        1  | 0  | 1
        1  | 1  | 0
        1  | 1  | 1
    }

    @Unroll
    def "4 prisoners: #a0,#a1,#a2,#a3"(int a0, int a1, int a2, int a3)
    {
        given:
        def guesses = new PrisonerHatProblem([a0, a1, a2, a3] as int[]).solve();

        expect: "everyone after the first prisoner names his own hat"
        guesses[1..-1] == [a1, a2, a3];

        where:
        a0 | a1 | a2 | a3
        0  | 0  | 0  | 0
        0  | 0  | 0  | 1
        0  | 0  | 1  | 0
        0  | 0  | 1  | 1
        0  | 1  | 0  | 0
        0  | 1  | 0  | 1
        0  | 1  | 1  | 0
        0  | 1  | 1  | 1
        1  | 0  | 0  | 0
        1  | 0  | 0  | 1
        1  | 0  | 1  | 0
        1  | 0  | 1  | 1
        1  | 1  | 0  | 0
        1  | 1  | 0  | 1
        1  | 1  | 1  | 0
        1  | 1  | 1  | 1
    }

    @Unroll
    def "Random problem with #n prisoners"(int n)
    {
        given:
        def puzzle = new PrisonerHatProblem(n);
        def guesses = puzzle.solve();
        def actualColors = puzzle.getHats().collect { it.color };

        expect: "everyone after the first prisoner names his own hat"
        guesses[1..-1] == actualColors[1..-1];

        // Single-column data table: `_` is Spock's placeholder for the required second column.
        where:
        n    | _
        17   | _
        20   | _
        47   | _
        50   | _
        99   | _
    }
}

/**
 * The hat worn by one prisoner. Colors are plain int codes given by the
 * constants declared in this class.
 */
class Hat
{
    /** Color code of this hat (RED or BLUE for the hats the solver creates). */
    public int color;

    /**
     * Creates a hat of the given color.
     *
     * @param color the color code stored in {@link #color}
     */
    public Hat(int color)
    {
        this.color = color;
    }

    /** Color code for a red hat. */
    public static final int RED = 0;

    /** Color code for a blue hat. */
    public static final int BLUE = 1;

    /** Sentinel for "no color"; the solver never produces it. */
    public static final int NONE = -1;
}

/**
 * A prisoner who follows the parity strategy. He counts the RED answers he has
 * heard plus the RED hats he can see, then answers RED when that total is odd
 * and BLUE otherwise.
 */
class Prisoner
{
    /**
     * Chooses this prisoner's answer.
     *
     * @param pastAnswers answers already given by the prisoners who spoke before him
     * @param visibleHats hats of the prisoners he can see
     * @return Hat.RED if the combined count of RED answers and RED hats is odd, else Hat.BLUE
     */
    public int getAnswer(int[] pastAnswers, Hat[] visibleHats)
    {
        int reds = 0;

        for (int spoken : pastAnswers)
        {
            if (spoken == Hat.RED)
            {
                reds++;
            }
        }

        for (Hat hat : visibleHats)
        {
            if (hat.color == Hat.RED)
            {
                reds++;
            }
        }

        if (reds % 2 == 1)
        {
            return Hat.RED;
        }
        return Hat.BLUE;
    }
}

/**
 * One instance of the prisoners-and-hats puzzle: a line of prisoners, each wearing a
 * hat whose color is one of {@link #availableColors}. Prisoner 0 answers first;
 * prisoner i hears every earlier answer and sees the hats of prisoners i+1 .. N-1.
 */
class PrisonerHatProblem
{
    /** Colors a randomly generated hat may have. */
    public static final int[] availableColors = [ Hat.RED, Hat.BLUE ] as int[];

    private Hat[] hats;

    /**
     * Creates a puzzle with N prisoners whose hat colors are drawn uniformly at random
     * from {@link #availableColors}.
     *
     * @param N      number of prisoners
     * @param random source of randomness; pass a seeded instance for a reproducible puzzle
     */
    public PrisonerHatProblem(int N, Random random = new Random())
    {
        int[] randomColors = new int[N];

        for (int i = 0; i < randomColors.length; ++i)
        {
            // Pick a uniformly random color index. The former `Math.random() % 2 == 0`
            // test was practically never true (Math.random() is a double in [0, 1)),
            // so it made every hat BLUE.
            randomColors[i] = availableColors[random.nextInt(availableColors.length)];
        }

        hats = generateHatsFromColors(randomColors);
    }

    /**
     * Creates a puzzle with the given hat colors, one per prisoner, in speaking order.
     *
     * @param predefinedColors color code of each prisoner's hat
     */
    public PrisonerHatProblem(int[] predefinedColors)
    {
        hats = generateHatsFromColors(predefinedColors);
    }

    /** Wraps each color code in a new {@link Hat}, preserving order. */
    private Hat[] generateHatsFromColors(int[] colors)
    {
        Hat[] result = new Hat[colors.length];

        for (int i = 0; i < colors.length; ++i)
        {
            result[i] = new Hat(colors[i]);
        }

        return result;
    }

    /** @return the hats of all prisoners in speaking order (the internal array, not a copy) */
    public Hat[] getHats()
    {
        return hats;
    }

    /**
     * Lets every prisoner answer in turn, from prisoner 0 to prisoner N-1.
     *
     * With the parity strategy in {@link Prisoner#getAnswer}, answers[1..N-1] equal the
     * actual hat colors; answers[0] only carries the parity signal and may be wrong.
     *
     * @return one answer per prisoner (an empty array when there are no prisoners)
     */
    public int[] solve()
    {
        int len = hats.length;
        int[] answers = new int[len];

        for (int i = 0; i < len; ++i)
        {
            // Prisoner i hears answers 0 .. i-1 and sees hats i+1 .. len-1. Both slices
            // may be empty, which covers a lone prisoner (he sees no hats) and means
            // nothing is indexed at all when len == 0.
            int[] heard = Arrays.copyOfRange(answers, 0, i);
            Hat[] visible = Arrays.copyOfRange(hats, i + 1, len);
            answers[i] = new Prisoner().getAnswer(heard, visible);
        }

        return answers;
    }
}