#!/usr/bin/python
#
# Copyright 2019 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

from errno import *  # pylint: disable=wildcard-import
from socket import *  # pylint: disable=wildcard-import
import ctypes
import fcntl
import os
import random
import select
import termios
import threading
import time
from scapy import all as scapy

import multinetwork_base
import net_test
import packets

# Socket constants taken from net_test so the tests below can use them
# unqualified.
SOL_TCP = net_test.SOL_TCP
SHUT_RD = net_test.SHUT_RD
SHUT_WR = net_test.SHUT_WR
SHUT_RDWR = net_test.SHUT_RDWR
# ioctl requests the tests use to read how many bytes sit in a socket's
# receive queue (SIOCINQ) and send queue (SIOCOUTQ). The values are borrowed
# from termios' FIONREAD / TIOCOUTQ.
SIOCINQ = termios.FIONREAD
SIOCOUTQ = termios.TIOCOUTQ

# Remote TCP port that every test connection is made to.
TEST_PORT = 5555

# The following constants are SOL_TCP level options and their arguments.
# They are defined in the Linux kernel: include/uapi/linux/tcp.h

# SOL_TCP level options.
TCP_REPAIR = 19
TCP_REPAIR_QUEUE = 20
TCP_QUEUE_SEQ = 21

# TCP_REPAIR_{OFF, ON} are the arguments to TCP_REPAIR.
TCP_REPAIR_OFF = 0
TCP_REPAIR_ON = 1

# TCP_{NO, RECV, SEND}_QUEUE are the arguments to TCP_REPAIR_QUEUE.
TCP_NO_QUEUE = 0
TCP_RECV_QUEUE = 1
TCP_SEND_QUEUE = 2

# These tests ensure that TCP keepalive offload works correctly when it
# fetches TCP state from the kernel via TCP repair mode.
class TcpRepairTest(multinetwork_base.MultiNetworkBaseTest):
  """Tests for TCP sockets that are switched into repair mode (TCP_REPAIR).

  Connections are driven packet by packet through the ExpectPacketOn /
  ReceivePacketOn helpers of the base class. self.last_sent and
  self.last_received hold the most recent packet of each direction so that
  the next packet can be built as a reply to it.

  Every test loops over versions 4, 5 and 6; the value is handed straight to
  the net_test / multinetwork_base helpers.
  NOTE(review): 5 is presumably the framework's IPv4-mapped-IPv6 mode --
  confirm against net_test.
  """

  # TCP sequence numbers are 32 bits wide. getsockopt() hands them back as a
  # signed C int, so every comparison is done modulo 2**32.
  _SEQ_MASK = 0xffffffff

  def _Seq32(self, value):
    """Returns value reduced to an unsigned 32-bit sequence number."""
    return value & self._SEQ_MASK

  def _QueuedBytes(self, sock, request):
    """Returns the int reported for sock by ioctl `request` (SIOCINQ/SIOCOUTQ)."""
    buf = ctypes.c_int()
    fcntl.ioctl(sock, request, buf)
    return buf.value

  def _GetQueueSequenceNumber(self, sock, queue):
    """Returns TCP_QUEUE_SEQ of `queue` (TCP_SEND_QUEUE or TCP_RECV_QUEUE).

    The socket is put into repair mode for the query. The queue selection
    and repair mode are reset afterwards even if the query raises, so a
    failure does not leave the socket in repair mode.
    """
    sock.setsockopt(SOL_TCP, TCP_REPAIR, TCP_REPAIR_ON)
    try:
      sock.setsockopt(SOL_TCP, TCP_REPAIR_QUEUE, queue)
      return sock.getsockopt(SOL_TCP, TCP_QUEUE_SEQ)
    finally:
      sock.setsockopt(SOL_TCP, TCP_REPAIR_QUEUE, TCP_NO_QUEUE)
      sock.setsockopt(SOL_TCP, TCP_REPAIR, TCP_REPAIR_OFF)

  def assertSocketNotConnected(self, sock):
    """Asserts that getpeername() on sock fails with ENOTCONN."""
    self.assertRaisesErrno(ENOTCONN, sock.getpeername)

  def assertSocketConnected(self, sock):
    """Asserts that sock is connected, i.e. getpeername() does not raise."""
    sock.getpeername()  # No errors? Socket is alive and connected.

  def createConnectedSocket(self, version, netid):
    """Creates a TCP socket on netid and completes the handshake by hand.

    The connect() is expected to fail with EINPROGRESS. The SYN is then
    matched, a SYN+ACK with a random sequence number and a window of 14400
    is injected, and the final ACK is matched. On return self.last_sent is
    that ACK and self.last_received is the SYN+ACK.

    Args:
      version: address version passed to the net_test helpers.
      netid: network to bind the socket to (via "mark").

    Returns:
      The connected socket.
    """
    s = net_test.TCPSocket(net_test.GetAddressFamily(version))
    net_test.DisableFinWait(s)
    self.SelectInterface(s, netid, "mark")

    remotesockaddr = self.GetRemoteSocketAddress(version)
    remoteaddr = self.GetRemoteAddress(version)
    self.assertRaisesErrno(EINPROGRESS, s.connect, (remotesockaddr, TEST_PORT))
    self.assertSocketNotConnected(s)

    myaddr = self.MyAddress(version, netid)
    port = s.getsockname()[1]
    self.assertNotEqual(0, port)

    # Outgoing SYN.
    desc, expect_syn = packets.SYN(TEST_PORT, version, myaddr, remoteaddr, port, seq=None)
    msg = "socket connect: expected %s" % desc
    syn = self.ExpectPacketOn(netid, msg, expect_syn)

    # Injected SYN+ACK with a random initial sequence number.
    _, synack = packets.SYNACK(version, remoteaddr, myaddr, syn)
    synack.getlayer("TCP").seq = random.getrandbits(32)
    synack.getlayer("TCP").window = 14400
    self.ReceivePacketOn(netid, synack)

    # Outgoing ACK that completes the handshake.
    desc, ack = packets.ACK(version, myaddr, remoteaddr, synack)
    msg = "socket connect: got SYN+ACK, expected %s" % desc
    ack = self.ExpectPacketOn(netid, msg, ack)

    self.last_sent = ack
    self.last_received = synack
    return s

  def receiveFin(self, netid, version, sock):
    """Injects a FIN from the remote side, built as a reply to last_sent."""
    self.assertSocketConnected(sock)
    remoteaddr = self.GetRemoteAddress(version)
    myaddr = self.MyAddress(version, netid)
    _, fin = packets.FIN(version, remoteaddr, myaddr, self.last_sent)
    self.ReceivePacketOn(netid, fin)
    self.last_received = fin

  def sendData(self, netid, version, sock, payload):
    """Writes payload to sock and records the matching packet as last_sent.

    The packet is only constructed for bookkeeping; it is not matched against
    what actually appears on the network.
    """
    sock.send(payload)

    remoteaddr = self.GetRemoteAddress(version)
    myaddr = self.MyAddress(version, netid)
    _, send = packets.ACK(version, myaddr, remoteaddr,
                          self.last_received, payload)
    self.last_sent = send

  def receiveData(self, netid, version, payload):
    """Injects payload from the remote side and expects the resulting ACK."""
    remoteaddr = self.GetRemoteAddress(version)
    myaddr = self.MyAddress(version, netid)

    _, received = packets.ACK(version, remoteaddr, myaddr,
                              self.last_sent, payload)
    ack_desc, ack = packets.ACK(version, myaddr, remoteaddr, received)
    self.ReceivePacketOn(netid, received)
    # Brief pause before looking for the ACK of the injected data.
    time.sleep(0.1)
    self.ExpectPacketOn(netid, "expecting %s" % ack_desc, ack)
    self.last_sent = ack
    self.last_received = received

  # Test the behavior of NO_QUEUE. Expect incoming data will be stored into
  # the queue, but socket cannot be read/written in NO_QUEUE.
  def testTcpRepairInNoQueue(self):
    for version in [4, 5, 6]:
      self.tcpRepairInNoQueueTest(version)

  def tcpRepairInNoQueueTest(self, version):
    """Checks send/recv errors in repair mode and delivery after leaving it."""
    netid = self.RandomNetid()
    sock = self.createConnectedSocket(version, netid)
    try:
      sock.setsockopt(SOL_TCP, TCP_REPAIR, TCP_REPAIR_ON)

      # In repair mode with NO_QUEUE, writes fail...
      self.assertRaisesErrno(EINVAL, sock.send, "write test")

      # Remote data arrives while the socket is in repair mode.
      TEST_RECEIVED = net_test.UDP_PAYLOAD
      self.receiveData(netid, version, TEST_RECEIVED)

      # In repair mode with NO_QUEUE, reads fail...
      self.assertRaisesErrno(EPERM, sock.recv, 4096)

      # ...but the data is readable once repair mode is switched off.
      sock.setsockopt(SOL_TCP, TCP_REPAIR, TCP_REPAIR_OFF)
      readData = sock.recv(4096)
      self.assertEqual(readData, TEST_RECEIVED)
    finally:
      sock.close()

  # Test whether tcp read/write sequence number can be fetched correctly
  # by TCP_QUEUE_SEQ.
  def testGetSequenceNumber(self):
    for version in [4, 5, 6]:
      self.GetSequenceNumberTest(version)

  def GetSequenceNumberTest(self, version):
    """Checks TCP_QUEUE_SEQ of both queues before and after moving data.

    All comparisons are made modulo 2**32 so that a sequence number which
    wraps (or crosses the signed-int boundary of getsockopt()) does not cause
    a spurious failure.
    """
    netid = self.RandomNetid()
    sock = self.createConnectedSocket(version, netid)
    try:
      # Write queue sequence number.
      sequence_before = self.GetWriteSequenceNumber(version, sock)
      expect_sequence = self.last_sent.getlayer("TCP").seq
      self.assertEqual(self._Seq32(sequence_before),
                       self._Seq32(expect_sequence))
      TEST_SEND = net_test.UDP_PAYLOAD
      self.sendData(netid, version, sock, TEST_SEND)
      sequence_after = self.GetWriteSequenceNumber(version, sock)
      self.assertEqual(self._Seq32(sequence_before + len(TEST_SEND)),
                       self._Seq32(sequence_after))

      # Read queue sequence number. The SYN+ACK consumed one sequence number.
      sequence_before = self.GetReadSequenceNumber(version, sock)
      expect_sequence = self.last_received.getlayer("TCP").seq + 1
      self.assertEqual(self._Seq32(sequence_before),
                       self._Seq32(expect_sequence))
      TEST_READ = net_test.UDP_PAYLOAD
      self.receiveData(netid, version, TEST_READ)
      sequence_after = self.GetReadSequenceNumber(version, sock)
      self.assertEqual(self._Seq32(sequence_before + len(TEST_READ)),
                       self._Seq32(sequence_after))
    finally:
      sock.close()

  def GetWriteSequenceNumber(self, version, sock):
    """Returns TCP_QUEUE_SEQ of sock's send queue (version is unused)."""
    return self._GetQueueSequenceNumber(sock, TCP_SEND_QUEUE)

  def GetReadSequenceNumber(self, version, sock):
    """Returns TCP_QUEUE_SEQ of sock's receive queue (version is unused)."""
    return self._GetQueueSequenceNumber(sock, TCP_RECV_QUEUE)

  # Test whether tcp repair socket can be poll()'ed correctly
  # in multiple threads at the same time.
  def testMultiThreadedPoll(self):
    for version in [4, 5, 6]:
      self.PollWhenShutdownTest(version)
      self.PollWhenReceiveFinTest(version)

  def PollRepairSocketInMultipleThreads(self, netid, version, expected):
    """Starts two threads that poll a fresh repair-mode socket.

    Args:
      netid: network the socket is created on.
      version: address version passed to the net_test helpers.
      expected: event mask handed to fdSelect() by each thread.

    Returns:
      A (socket, list of started threads) tuple.
    """
    sock = self.createConnectedSocket(version, netid)
    sock.setsockopt(SOL_TCP, TCP_REPAIR, TCP_REPAIR_ON)

    multiThreads = []
    for _ in range(2):
      # The thread passes its own socket reference to the operation.
      thread = SocketExceptionThread(sock, lambda sk: self.fdSelect(sk, expected))
      thread.start()
      self.assertTrue(thread.is_alive())
      multiThreads.append(thread)

    return sock, multiThreads

  def assertThreadsStopped(self, multiThreads, msg):
    """Asserts that every thread finished and that none recorded a failure.

    Each thread is given up to one second to finish. A thread that is still
    running is stop()ped and the assertion fails with msg. A thread that
    finished but stored an exception (e.g. an assertion from fdSelect) also
    fails the assertion, so that errors raised off the main thread are not
    silently dropped.
    """
    for thread in multiThreads:
      if thread.is_alive():
        thread.join(1)
      if thread.is_alive():
        thread.stop()
        raise AssertionError(msg)
      failure = getattr(thread, "exception", None)
      if failure is not None:
        raise AssertionError("%s: %r" % (msg, failure))

  def _CheckPollOnShutdown(self, netid, version, how, expected, msg):
    """Polls a repair-mode socket in two threads while shutdown(how) runs."""
    sock, multiThreads = self.PollRepairSocketInMultipleThreads(netid, version, expected)
    try:
      sock.shutdown(how)
      self.assertThreadsStopped(multiThreads, msg)
    finally:
      sock.close()

  def PollWhenShutdownTest(self, version):
    """Checks the poll events seen for SHUT_RD, SHUT_WR and SHUT_RDWR."""
    netid = self.RandomNetid()
    # Test shutdown RD.
    self._CheckPollOnShutdown(netid, version, SHUT_RD, select.POLLIN,
                              "poll fail during SHUT_RD")
    # Test shutdown WR. expected is None: fdSelect() makes no check when
    # poll times out, and any reported event fails its comparison.
    self._CheckPollOnShutdown(netid, version, SHUT_WR, None,
                              "poll fail during SHUT_WR")
    # Test shutdown RDWR.
    self._CheckPollOnShutdown(netid, version, SHUT_RDWR,
                              select.POLLIN | select.POLLHUP,
                              "poll fail during SHUT_RDWR")

  def PollWhenReceiveFinTest(self, version):
    """Checks that pollers see POLLIN when the remote side sends a FIN."""
    netid = self.RandomNetid()
    expected = select.POLLIN
    sock, multiThreads = self.PollRepairSocketInMultipleThreads(netid, version, expected)
    try:
      self.receiveFin(netid, version, sock)
      self.assertThreadsStopped(multiThreads, "poll fail during FIN")
    finally:
      sock.close()

  # Test whether socket idle can be detected by SIOCINQ and SIOCOUTQ.
  def testSocketIdle(self):
    for version in [4, 5, 6]:
      self.readQueueIdleTest(version)
      self.writeQueueIdleTest(version)

  def readQueueIdleTest(self, version):
    """Checks SIOCINQ before and after data is received."""
    netid = self.RandomNetid()
    sock = self.createConnectedSocket(version, netid)
    try:
      self.assertEqual(self._QueuedBytes(sock, SIOCINQ), 0)

      TEST_RECV_PAYLOAD = net_test.UDP_PAYLOAD
      self.receiveData(netid, version, TEST_RECV_PAYLOAD)
      self.assertEqual(self._QueuedBytes(sock, SIOCINQ), len(TEST_RECV_PAYLOAD))
    finally:
      sock.close()

  def writeQueueIdleTest(self, version):
    """Checks SIOCOUTQ in repair mode and around an ACK in normal mode."""
    TEST_SEND_PAYLOAD = net_test.UDP_PAYLOAD

    netid = self.RandomNetid()
    # Setup a connected socket, write queue is empty.
    sock = self.createConnectedSocket(version, netid)
    try:
      self.assertEqual(self._QueuedBytes(sock, SIOCOUTQ), 0)
      # Change to repair mode with SEND_QUEUE, writing some data to the queue.
      sock.setsockopt(SOL_TCP, TCP_REPAIR, TCP_REPAIR_ON)
      sock.setsockopt(SOL_TCP, TCP_REPAIR_QUEUE, TCP_SEND_QUEUE)
      self.sendData(netid, version, sock, TEST_SEND_PAYLOAD)
      self.assertEqual(self._QueuedBytes(sock, SIOCOUTQ), len(TEST_SEND_PAYLOAD))
    finally:
      sock.close()

    # Setup a connected socket again.
    netid = self.RandomNetid()
    sock = self.createConnectedSocket(version, netid)
    try:
      # Send out some data and don't receive ACK yet.
      self.sendData(netid, version, sock, TEST_SEND_PAYLOAD)
      self.assertEqual(self._QueuedBytes(sock, SIOCOUTQ), len(TEST_SEND_PAYLOAD))
      # Receive response ACK.
      remoteaddr = self.GetRemoteAddress(version)
      myaddr = self.MyAddress(version, netid)
      _, ack = packets.ACK(version, remoteaddr, myaddr, self.last_sent)
      self.ReceivePacketOn(netid, ack)
      self.assertEqual(self._QueuedBytes(sock, SIOCOUTQ), 0)
    finally:
      sock.close()

  def fdSelect(self, sock, expected):
    """Polls sock for up to 500 ms and checks any reported event mask.

    If poll() times out without events, nothing is checked. Otherwise the
    event mask reported for sock must equal expected; an event for any other
    file descriptor raises AssertionError.
    """
    READ_ONLY = select.POLLIN | select.POLLPRI | select.POLLHUP | select.POLLERR | select.POLLNVAL
    p = select.poll()
    p.register(sock, READ_ONLY)
    events = p.poll(500)
    for fd, event in events:
      if fd == sock.fileno():
        self.assertEqual(event, expected)
      else:
        raise AssertionError("unexpected poll fd")

class SocketExceptionThread(threading.Thread):
  """Daemon thread that runs operation(sock) and records expected failures.

  An IOError or AssertionError raised by the operation is stored in
  self.exception instead of propagating. self.exception stays None when the
  operation completes normally. Any other exception type is not caught here.
  """

  def __init__(self, sock, operation):
    """Stores the socket and the one-argument callable to run on it."""
    # Set before Thread.__init__ so the attribute always exists.
    self.exception = None
    super(SocketExceptionThread, self).__init__()
    # Daemon threads do not keep the interpreter alive if they get stuck.
    self.daemon = True
    self.sock = sock
    self.operation = operation

  def stop(self):
    """Marks the thread as stopped where the interpreter supports it.

    Thread.__stop is a name-mangled private method that only exists on
    Python 2, where it flags the Thread object as stopped without
    interrupting the running code. On interpreters that lack it this is a
    no-op rather than an AttributeError, which would otherwise hide the
    caller's own failure report. The thread is a daemon either way.
    """
    legacy_stop = getattr(self, "_Thread__stop", None)
    if legacy_stop is not None:
      legacy_stop()

  def run(self):
    """Runs the operation on the socket, capturing IOError/AssertionError."""
    try:
      self.operation(self.sock)
    except (IOError, AssertionError) as e:
      self.exception = e

# Run every test in this module when the file is executed as a script.
if __name__ == "__main__":
  unittest.main()