#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
# 
#   http://www.apache.org/licenses/LICENSE-2.0
# 
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
#
from __future__ import absolute_import
from qpid.client import Client, Closed
from qpid.queue import Empty
from qpid.content import Content
from qpid.testlib import TestBase

class TxTests(TestBase):
    """
    Tests for 'methods' on the amqp tx 'class'
    """

    def test_commit(self):
        """
        Test that committed publishes are delivered and committed acks are not re-delivered
        """
        channel = self.channel
        queue_a, queue_b, queue_c = self.perform_txn_work(channel, "tx-commit-a", "tx-commit-b", "tx-commit-c")
        channel.tx_commit()

        #check results: the transactional publishes are now routed to their queues
        for i in range(1, 5):
            msg = queue_c.get(timeout=self.recv_timeout())
            self.assertEqual("TxMessage %d" % i, msg.content.body)

        msg = queue_b.get(timeout=self.recv_timeout())
        self.assertEqual("TxMessage 6", msg.content.body)

        msg = queue_a.get(timeout=self.recv_timeout())
        self.assertEqual("TxMessage 7", msg.content.body)

        # The committed acks must not lead to redelivery of the original messages.
        self._assert_no_more_messages(queue_a, queue_b, queue_c)

        #cleanup
        channel.basic_ack(delivery_tag=0, multiple=True)
        channel.tx_commit()

    def test_auto_rollback(self):
        """
        Test that a channel closed with an open transaction is effectively rolled back

        NOTE(review): as written this test never closes the channel; it rolls back
        explicitly with tx_rollback(), so it exercises the same path as test_rollback.
        TODO: close the channel instead and verify the outcome from a fresh channel.
        """
        self._verify_rollback("tx-autorollback-a", "tx-autorollback-b", "tx-autorollback-c")

    def test_rollback(self):
        """
        Test that rolled back publishes are not delivered and rolled back acks are re-delivered
        """
        self._verify_rollback("tx-rollback-a", "tx-rollback-b", "tx-rollback-c")

    def _verify_rollback(self, name_a, name_b, name_c):
        """
        Run perform_txn_work on the given queue names, roll the transaction back and
        check that the original messages are redelivered while none of the
        transactional publishes arrive. Commits the cleanup acks at the end.
        """
        channel = self.channel
        queue_a, queue_b, queue_c = self.perform_txn_work(channel, name_a, name_b, name_c)

        # Nothing published inside the still-open transaction may be visible yet.
        self._assert_no_more_messages(queue_a, queue_b, queue_c)

        channel.tx_rollback()

        #check results: the rolled-back acks cause the original messages to be re-delivered
        for i in range(1, 5):
            msg = queue_a.get(timeout=self.recv_timeout())
            self.assertEqual("Message %d" % i, msg.content.body)

        msg = queue_b.get(timeout=self.recv_timeout())
        self.assertEqual("Message 6", msg.content.body)

        msg = queue_c.get(timeout=self.recv_timeout())
        self.assertEqual("Message 7", msg.content.body)

        # ...and the rolled-back publishes are never delivered.
        self._assert_no_more_messages(queue_a, queue_b, queue_c)

        #cleanup
        channel.basic_ack(delivery_tag=0, multiple=True)
        channel.tx_commit()

    def _assert_no_more_messages(self, *queues):
        """
        Fail the test if any of the given client-side consumer queues yields a
        message within the negative receive timeout.
        """
        for q in queues:
            try:
                extra = q.get(timeout=self.recv_timeout_negative())
            except Empty:
                continue
            self.fail("Got unexpected message: " + extra.content.body)

    def perform_txn_work(self, channel, name_a, name_b, name_c):
        """
        Utility method that does some setup and some work under a transaction. Used for testing both
        commit and rollback

        Before tx_select, publishes "Message 1".."Message 4" with routing key name_a and no
        exchange argument, "Message 6" to amq.direct (bound to name_b) and "Message 7" to
        amq.topic (bound to name_c). Note the numbering deliberately skips 5.

        Inside the transaction, it consumes and acks all of those, then publishes
        "TxMessage 1".."TxMessage 4" to amq.topic (-> name_c), "TxMessage 6" to amq.direct
        (-> name_b) and "TxMessage 7" with routing key name_a.

        Returns the client-side consumer queues (queue_a, queue_b, queue_c). The
        transaction is left open for the caller to commit or roll back.
        """
        #setup:
        channel.queue_declare(queue=name_a, exclusive=True)
        channel.queue_declare(queue=name_b, exclusive=True)
        channel.queue_declare(queue=name_c, exclusive=True)

        key = "my_key_" + name_b
        topic = "my_topic_" + name_c

        channel.queue_bind(queue=name_b, exchange="amq.direct", routing_key=key)
        channel.queue_bind(queue=name_c, exchange="amq.topic", routing_key=topic)

        for i in range(1, 5):
            channel.basic_publish(routing_key=name_a, content=Content("Message %d" % i))

        channel.basic_publish(routing_key=key, exchange="amq.direct", content=Content("Message 6"))
        channel.basic_publish(routing_key=topic, exchange="amq.topic", content=Content("Message 7"))

        channel.tx_select()

        #consume and ack messages
        sub_a = channel.basic_consume(queue=name_a, no_ack=False)
        queue_a = self.client.queue(sub_a.consumer_tag)
        for i in range(1, 5):
            msg = queue_a.get(timeout=self.recv_timeout())
            self.assertEqual("Message %d" % i, msg.content.body)
        # One cumulative ack covers all four messages on queue_a.
        channel.basic_ack(delivery_tag=msg.delivery_tag, multiple=True)

        sub_b = channel.basic_consume(queue=name_b, no_ack=False)
        queue_b = self.client.queue(sub_b.consumer_tag)
        msg = queue_b.get(timeout=self.recv_timeout())
        self.assertEqual("Message 6", msg.content.body)
        channel.basic_ack(delivery_tag=msg.delivery_tag)

        sub_c = channel.basic_consume(queue=name_c, no_ack=False)
        queue_c = self.client.queue(sub_c.consumer_tag)
        msg = queue_c.get(timeout=self.recv_timeout())
        self.assertEqual("Message 7", msg.content.body)
        channel.basic_ack(delivery_tag=msg.delivery_tag)

        #publish messages (routed to the queues in a different order than above)
        for i in range(1, 5):
            channel.basic_publish(routing_key=topic, exchange="amq.topic", content=Content("TxMessage %d" % i))

        channel.basic_publish(routing_key=key, exchange="amq.direct", content=Content("TxMessage 6"))
        channel.basic_publish(routing_key=name_a, content=Content("TxMessage 7"))

        return queue_a, queue_b, queue_c

    def test_commit_overlapping_acks(self):
        """
        Test that logically 'overlapping' acks do not cause errors on commit

        Nine messages are consumed under a transaction. They are acked cumulatively
        (multiple=True) at deliveries 3, 6 and 9, so each ack's "up to and including"
        range also covers tags already acked by the previous one. Every delivered
        message is acked before the commit.
        """
        channel = self.channel
        channel.queue_declare(queue="commit-overlapping", exclusive=True)
        for i in range(1, 10):
            channel.basic_publish(routing_key="commit-overlapping", content=Content("Message %d" % i))

        channel.tx_select()

        sub = channel.basic_consume(queue="commit-overlapping", no_ack=False)
        queue = self.client.queue(sub.consumer_tag)
        for i in range(1, 10):
            msg = queue.get(timeout=self.recv_timeout())
            self.assertEqual("Message %d" % i, msg.content.body)
            # 9 is the last delivery, so the final ack covers every message.
            if i in (3, 6, 9):
                channel.basic_ack(delivery_tag=msg.delivery_tag, multiple=True)

        channel.tx_commit()

        #check nothing further is delivered after the commit:
        self._assert_no_more_messages(queue)
