#!/usr/bin/env python3
# systemd integration test: Handling of storage devices
# (C) 2015 Canonical Ltd.
# Author: Martin Pitt <martin.pitt@ubuntu.com>

import os
import random
import subprocess
import sys
import time
import unittest

from glob import glob
from threading import Thread


TIMEOUT_SERVICE_START = 10
TIMEOUT_PASSWORD_AGENT_STOP = 10
TIMEOUT_PLAINTEXT_DEV = 30
TIMEOUT_SCSI_DEBUG_ADD_HOST = 5

SCSI_DEBUG_DIR = '/sys/bus/pseudo/drivers/scsi_debug'

class FakeDriveTestBase(unittest.TestCase):
    '''Base class that gives every test its own scsi_debug fake drive.

    setUp() adds a scsi_debug host and stores its sysfs adapter directory in
    self.adapter and the matching block device node in self.device;
    tearDown() removes that host again.
    '''

    @classmethod
    def setUpClass(cls):
        '''Load the scsi_debug module unless its sysfs directory already exists.

        Raises subprocess.CalledProcessError if modprobe fails and
        AssertionError if the driver directory is still missing afterwards.
        '''
        if os.path.isdir(SCSI_DEBUG_DIR):
            return

        # Consider missing scsi_debug module a test failure
        subprocess.check_call(['modprobe', 'scsi_debug', 'dev_size_mb=32'])
        # Raise explicitly instead of using "assert", which is dropped under "python -O"
        if not os.path.isdir(SCSI_DEBUG_DIR):
            raise AssertionError('%s does not exist after loading scsi_debug' % SCSI_DEBUG_DIR)

    @staticmethod
    def _list_adapters():
        '''Return the set of scsi_debug adapter directories that exist right now.'''
        return set(glob(os.path.join(SCSI_DEBUG_DIR, 'adapter*')))

    @staticmethod
    def _write_add_host(value):
        '''Write value to the scsi_debug "add_host" attribute.'''
        with open(os.path.join(SCSI_DEBUG_DIR, 'add_host'), 'w') as f:
            f.write(value)

    def setUp(self):
        '''Add one scsi_debug host and wait for its block device to show up.

        On success self.adapter and self.device are set. If anything goes
        wrong after the host was added, the host is removed again before the
        failure is propagated.
        '''
        self.adapter = None
        self.device = None

        existing_adapters = self._list_adapters()
        self._write_add_host('1')
        try:
            new_adapters = self._list_adapters() - existing_adapters
            self.assertEqual(len(new_adapters), 1)
            adapter = new_adapters.pop()
            # the block device may appear with some delay; poll once per second
            for _ in range(TIMEOUT_SCSI_DEBUG_ADD_HOST):
                devices = set(glob(os.path.join(adapter, 'host*/target*/*:*/block/*')))
                if devices:
                    break
                time.sleep(1)
            else:
                self.fail('Timed out waiting for scsi_debug block device name')
            self.assertEqual(len(devices), 1)
        except BaseException:
            # unittest does not run tearDown() when setUp() raises, so the
            # host added above would stay around forever; drop it here
            try:
                self._write_add_host('-1')
            except OSError:
                # best effort only: report the original failure, not this one
                pass
            raise

        self.adapter = adapter
        self.device = os.path.join('/dev/', os.path.basename(devices.pop()))

    def tearDown(self):
        '''Remove the last scsi_debug host and verify that it was ours.'''
        existing_adapters = self._list_adapters()
        self._write_add_host('-1')
        removed_adapters = existing_adapters - self._list_adapters()

        # forget the drive before asserting, so that a failed check below
        # does not leave stale paths on the instance
        expected_adapter = self.adapter
        self.adapter = None
        self.device = None

        self.assertEqual(len(removed_adapters), 1)
        self.assertEqual(expected_adapter, removed_adapters.pop())


class CryptsetupTest(FakeDriveTestBase):
    '''Check systemd's cryptsetup units against a LUKS formatted test drive.

    Each test writes its own /etc/crypttab (an existing one is moved aside in
    setUp() and put back in tearDown()) and answers the resulting passphrase
    request from a background thread.
    '''

    def setUp(self):
        '''Derive per-test names, attach the test drive, move /etc/crypttab aside.'''
        testname = self.id().split('.')[-1]
        self.plaintext_name = 'testcrypt_%s' % testname
        self.plaintext_dev = '/dev/mapper/' + self.plaintext_name
        self.service_name = 'systemd-cryptsetup@%s.service' % self.plaintext_name
        if os.path.exists(self.plaintext_dev):
            self.fail('%s exists already' % self.plaintext_dev)

        super().setUp()

        if os.path.exists('/etc/crypttab'):
            os.rename('/etc/crypttab', '/etc/crypttab.systemdtest')
        self.password = 'pwd%i' % random.randint(1000, 10000)
        self.password_agent = None
        self.password_agent_stop = False

    def tearDown(self):
        '''Stop agent and service, restore /etc/crypttab, detach the drive.

        All checks are evaluated only after the cleanup has run, so that a
        stuck password agent or a service that never settles cannot leave the
        test crypttab or the fake drive behind.
        '''
        agent_still_running = False
        if self.password_agent:
            self.password_agent_stop = True
            self.password_agent.join(timeout=TIMEOUT_PASSWORD_AGENT_STOP)
            agent_still_running = self.password_agent.is_alive()
            self.password_agent = None

        # give the service a chance to reach a final state before cleaning up
        state = None
        service_settled = False
        for _ in range(TIMEOUT_SERVICE_START):
            state = self._service_state()
            if state in ('active', 'failed'):
                service_settled = True
                break
            time.sleep(1)

        try:
            subprocess.call(['umount', self.plaintext_dev], stderr=subprocess.DEVNULL)
            if state == 'active':
                subprocess.call(['systemctl', 'stop', self.service_name], stderr=subprocess.STDOUT)
            if os.path.exists('/etc/crypttab'):
                os.unlink('/etc/crypttab')
            if os.path.exists('/etc/crypttab.systemdtest'):
                os.rename('/etc/crypttab.systemdtest', '/etc/crypttab')
            if os.path.exists(self.plaintext_dev):
                subprocess.call(['dmsetup', 'remove', self.plaintext_dev],
                                stderr=subprocess.STDOUT)
            subprocess.check_call(['systemctl', 'daemon-reload'])
        finally:
            super().tearDown()

        # report problems now that the system is back in its previous state
        self.assertFalse(agent_still_running)
        if not service_settled:
            self.fail('Timed out waiting for %s to start (or fail)' % self.service_name)

    def _service_state(self):
        '''Return the ActiveState value systemctl reports for the test service.'''
        out = subprocess.run(['systemctl', 'show', '--no-pager', self.service_name, '--property', 'ActiveState'],
                             stdout=subprocess.PIPE, universal_newlines=True).stdout
        return out.strip().replace('ActiveState=', '', 1)

    def _write_crypttab(self, source, options):
        '''Write a one-line /etc/crypttab mapping source to the test plaintext name.'''
        with open('/etc/crypttab', 'w') as f:
            f.write('%s %s none %s\n' % (self.plaintext_name, source, options))

    def _assert_not_mounted(self):
        '''Fail if the plaintext device name shows up in /proc/mounts.'''
        with open('/proc/mounts') as f:
            self.assertNotIn(self.plaintext_name, f.read())

    def _assert_no_signature(self):
        '''Fail unless blkid prints nothing and exits non-zero for the plaintext device.'''
        result = subprocess.run(['blkid', self.plaintext_dev], stdout=subprocess.PIPE)
        self.assertEqual(result.stdout, b'')
        self.assertNotEqual(result.returncode, 0)

    def format_luks(self):
        '''Format test device with LUKS'''

        # "-" as key file makes cryptsetup read the passphrase from stdin
        result = subprocess.run(['cryptsetup', '--batch-mode', 'luksFormat', self.device, '-'],
                                input=self.password.encode())
        self.assertEqual(result.returncode, 0)
        os.sync()
        subprocess.check_call(['udevadm', 'settle'])

    def start_password_agent(self):
        '''Run password agent to answer passphrase request for crypt device

        This is the body of the thread started by apply(). It polls the
        ask-password directory until a request mentioning the test's
        plaintext name appears, or returns without answering once
        self.password_agent_stop is set.
        '''

        # wait for incoming request
        contents = None
        while contents is None:
            for ask in glob('/run/systemd/ask-password/ask.*'):
                try:
                    with open(ask) as f:
                        candidate = f.read()
                except FileNotFoundError:
                    # the request went away between glob() and open()
                    continue
                if self.plaintext_name in candidate:
                    contents = candidate
                    break
            else:
                # nothing for us yet
                if self.password_agent_stop:
                    return
                time.sleep(0.5)

        # parse Socket=
        for line in contents.splitlines():
            if line.startswith('Socket='):
                socket_path = line.split('=', 1)[1]
                break
        else:
            self.fail('Could not find socket')

        # send reply
        result = subprocess.run(['/lib/systemd/systemd-reply-password', '1', socket_path],
                                input=self.password.encode())
        self.assertEqual(result.returncode, 0)

    def apply(self, target):
        '''Tell systemd to generate and run the cryptsetup units

        Reloads systemd, starts the password agent thread, restarts the given
        target and waits up to TIMEOUT_PLAINTEXT_DEV seconds for the
        plaintext device to appear.
        '''

        subprocess.check_call(['systemctl', 'daemon-reload'])

        self.password_agent = Thread(target=self.start_password_agent)
        self.password_agent.start()
        subprocess.check_call(['systemctl', '--no-ask-password', 'restart', target])
        for _ in range(TIMEOUT_PLAINTEXT_DEV):
            if os.path.exists(self.plaintext_dev):
                break
            time.sleep(1)
        else:
            self.fail('Timed out waiting for %s to appear' % self.plaintext_dev)

    def test_luks_by_devname(self):
        '''LUKS device by plain device name, empty'''

        self.format_luks()
        self._write_crypttab(self.device, 'luks')
        self.apply('cryptsetup.target')

        # should not be mounted
        self._assert_not_mounted()

        # device should not have anything on it
        self._assert_no_signature()

    def test_luks_by_uuid(self):
        '''LUKS device by UUID, empty'''

        self.format_luks()
        uuid = subprocess.check_output(['blkid', '-ovalue', '-sUUID', self.device],
                                       universal_newlines=True).strip()
        self._write_crypttab('UUID=%s' % uuid, 'luks')
        self.apply('cryptsetup.target')

        # should not be mounted
        self._assert_not_mounted()

        # device should not have anything on it
        self._assert_no_signature()

    def test_luks_swap(self):
        '''LUKS device with "swap" option'''

        self.format_luks()
        self._write_crypttab(self.device, 'luks,swap')
        self.apply('cryptsetup.target')

        # should not be mounted
        self._assert_not_mounted()

        # device should be formatted with swap
        out = subprocess.check_output(['blkid', '-ovalue', '-sTYPE', self.plaintext_dev])
        self.assertEqual(out, b'swap\n')

    def test_luks_tmp(self):
        '''LUKS device with "tmp" option'''

        self.format_luks()
        self._write_crypttab(self.device, 'luks,tmp')
        self.apply('cryptsetup.target')

        # should not be mounted
        self._assert_not_mounted()

        # device should be formatted with ext2
        out = subprocess.check_output(['blkid', '-ovalue', '-sTYPE', self.plaintext_dev])
        self.assertEqual(out, b'ext2\n')

    def test_luks_fstab(self):
        '''LUKS device in /etc/fstab'''

        self.format_luks()
        self._write_crypttab(self.device, 'luks,tmp')

        mountpoint = '/run/crypt1.systemdtest'
        os.mkdir(mountpoint)
        self.addCleanup(os.rmdir, mountpoint)
        # work on a copy of fstab with our entry appended; cleanups run in
        # reverse order, so the original is moved back before the rmdir
        os.rename('/etc/fstab', '/etc/fstab.systemdtest')
        self.addCleanup(os.rename, '/etc/fstab.systemdtest', '/etc/fstab')
        with open('/etc/fstab', 'a') as f:
            with open('/etc/fstab.systemdtest') as forig:
                f.write(forig.read())
            f.write('%s %s ext2 defaults 0 0\n' % (self.plaintext_dev, mountpoint))

        # this should now be a requirement of local-fs.target
        self.apply('local-fs.target')

        # should be mounted
        with open('/proc/mounts') as f:
            for line in f:
                fields = line.split()
                if fields and fields[0] == self.plaintext_dev:
                    self.assertEqual(fields[1], mountpoint)
                    self.assertEqual(fields[2], 'ext2')
                    break
            else:
                self.fail('%s is not mounted' % self.plaintext_dev)


if __name__ == '__main__':
    # Send the verbose per-test report to stdout rather than the runner's
    # default stream.
    runner = unittest.TextTestRunner(stream=sys.stdout, verbosity=2)
    unittest.main(testRunner=runner)
