#!/usr/bin/env python3
"""FarSSH client: secure on-demand connections into AWS VPCs.

Launches a short-lived FarSSH ECS task, connects to it via SSH with
freshly generated throw-away keys, and uses that connection for an
interactive session, a SOCKS proxy, a port-forwarding tunnel or a
tunnelled psql/mysql client session.
"""

import argparse
import base64
import os
import shutil
import subprocess
import sys
import tempfile
import time

import boto3

FARSSH_VERSION = 'v0.3-devel'
FARSSH_ID = 'default'
FARSSH_URL = 'https://github.com/apparentorder/farssh'

# ----------------------------------------------------------------------

class FarSshArgs:
    """Command-line arguments combined with the FarSSH settings from SSM Parameter Store.

    `cmd_args` holds the parsed command line as a dict. Every parameter found
    below `/farssh/{FARSSH_ID}` in Parameter Store becomes an attribute named
    after the last path component of the parameter name.
    """

    # Attributes that are used unconditionally further down and have no default,
    # so they have to come from Parameter Store.
    _REQUIRED_PARAMETERS = ('public_subnets', 'security_group', 'ssh_port')

    def __init__(self):
        """Parse the command line and load the FarSSH parameters.

        Raises:
            SystemExit: if a required FarSSH parameter is missing from Parameter Store.
        """
        self.cmd_args = self._parse_args()
        self.cmd_args['remote_port'] = self.cmd_args.get('remote_port') or self.cmd_args.get('local_port')

        # defaults, if not found in Parameter Store
        self.enable_execute_command = False
        self.force_public_ipv4 = False

        # GetParametersByPath is paginated; walk all pages so no parameter is silently dropped.
        ssm = boto3.client('ssm')
        paginator = ssm.get_paginator('get_parameters_by_path')
        for page in paginator.paginate(Path=f"/farssh/{FARSSH_ID}", Recursive=True):
            for param in page['Parameters']:
                param_name = param['Name'].split('/')[-1]
                setattr(self, param_name, param['Value'])

        if any(not hasattr(self, name) for name in self._REQUIRED_PARAMETERS):
            x = "ERROR: FarSSH parameters not found for this AWS account and region. "
            x += f"Please follow setup instructions at {FARSSH_URL} first."
            raise SystemExit(x)

        self.public_subnets = self.public_subnets.split(',')

        # Parameter Store values are strings; turn the flags into real booleans.
        # ecs:RunTask requires a boolean for enableExecuteCommand.
        # NOTE(review): assumes enable_execute_command uses the same "true" convention
        # as force_public_ipv4 -- confirm against the server-side setup.
        self.force_public_ipv4 = (self.force_public_ipv4 == "true")
        self.enable_execute_command = (self.enable_execute_command == "true")

        # assign_public_ipv4 is not set from Parameter Store, but set here dynamically. It has to be sent from
        # the client (network configuration in ecs:RunTask), but the AWS side can set force_public_ipv4 to
        # signal the client that the server side needs public IPv4, e.g. for pulling the FarSSH image without NAT.
        # Otherwise, and if we're going to use IPv6 anyway, we don't need public IPv4.
        if self.force_public_ipv4 or not self.cmd_args.get('ipv6'):
            self.assign_public_ipv4 = "ENABLED"
        else:
            self.assign_public_ipv4 = "DISABLED"

    def _parse_args(self):
        """Parse `sys.argv` and return the result as a plain dict.

        When invoked without any argument, `--help` is appended so that the
        usage text is shown.
        """
        # TODO: the old syntax from v0.1, where it's possible to omit the "tunnel" command word,
        # seems to be impossible with python's argparse. figure out a way to support this again.

        if len(sys.argv) == 1:
            sys.argv += ["--help"]

        parser = argparse.ArgumentParser(
            description="Secure on-demand connections into AWS VPCs",
        )

        parser.add_argument('-6', '--ipv6', action='store_true', help='use IPv6 (disables public IPv4 when possible)')
        parser.add_argument('-V', '--version', action='version', version=f'FarSSH {FARSSH_VERSION}')

        subparsers = parser.add_subparsers(dest='command', required=True)

        parser_ssh = subparsers.add_parser('ssh', help='start interactive SSH session')
        parser_ssh.add_argument('extra_arguments', nargs='*', help='extra arguments passed to ssh')

        subparsers.add_parser('proxy', help='start SOCKS proxy')

        parser_tunnel = subparsers.add_parser('tunnel', help='establish a port-forwarding tunnel')
        parser_tunnel.add_argument('local_port', help='local port number')
        parser_tunnel.add_argument('remote_host', help='remote host address')
        # SUPPRESS keeps the key out of the result when omitted, so __init__ can fall back to local_port.
        parser_tunnel.add_argument('remote_port', help='remote port (optional)', nargs='?', default=argparse.SUPPRESS)

        parser_psql = subparsers.add_parser('psql', help='establish tunnel to a postgres endpoint and launch the psql client')
        parser_psql.add_argument('-i', '--identifier', help='database identifier (RDS)')
        parser_psql.add_argument('-u', '--user', dest='username', help='database username')
        parser_psql.add_argument('-U', '--username', help=argparse.SUPPRESS)
        parser_psql.add_argument('database', nargs='?', help='database name (in the DB engine)')
        parser_psql.add_argument('extra_arguments', nargs='*', help='extra arguments passed to psql')

        parser_mysql = subparsers.add_parser('mysql', help='establish tunnel to a mysql/mariadb endpoint and launch the mysql client')
        parser_mysql.add_argument('-i', '--identifier', help='database identifier (RDS)')
        parser_mysql.add_argument('-p', '--password', dest='password', action='store_true', help='ask for mysql password (pass -p to the mysql client)')
        parser_mysql.add_argument('-u', '--user', dest='username', help=argparse.SUPPRESS)
        parser_mysql.add_argument('-U', '--username', help='database username')
        parser_mysql.add_argument('database', nargs='?', help='database name (in the DB engine)')
        parser_mysql.add_argument('extra_arguments', nargs='*', help='extra arguments passed to mysql')

        return vars(parser.parse_args())

# ----------------------------------------------------------------------

class SshKeys:
    """Throw-away SSH host key and login key, generated in a temporary directory.

    The temporary directory is owned by this object (`_tempdir`), so the key
    files are available for as long as the object is referenced.
    """

    def __init__(self, farssh_args):
        """Generate both RSA key pairs with `ssh-keygen` and load them.

        Args:
            farssh_args: the `FarSshArgs` instance; its `ssh_port` is read in `write_known_hosts`.

        Raises:
            subprocess.CalledProcessError: if `ssh-keygen` fails.
        """
        self._tempdir = tempfile.TemporaryDirectory()
        self.farssh_args = farssh_args
        self.known_hosts_file = f"{self._tempdir.name}/known-hosts"

        self.host_key_file = f"{self._tempdir.name}/ssh_host_rsa_key"
        self.host_key_pub_file = f"{self._tempdir.name}/ssh_host_rsa_key.pub"
        self.login_key_file = f"{self._tempdir.name}/ssh_login_key"
        self.login_key_pub_file = f"{self._tempdir.name}/ssh_login_key.pub"

        for key_file in (self.host_key_file, self.login_key_file):
            subprocess.run(["ssh-keygen", "-q", "-N", "", "-t", "rsa", "-f", key_file], check=True)

        self.host_key = self._read_text(self.host_key_file)
        self.host_key_pub = self._read_text(self.host_key_pub_file)
        # the private login key is not read; ssh uses it directly via login_key_file
        self.login_key_pub = self._read_text(self.login_key_pub_file)

    @staticmethod
    def _read_text(path):
        """Return the content of the text file `path`, closing the file afterwards."""
        with open(path, "r") as f:
            return f.read()

    def write_known_hosts(self, ip_address):
        """Write a known-hosts file that maps `ip_address` to the generated host key."""
        # Create temporary known-hosts file so the SSH client can verify the remote host's public key that we configured it with.
        # The `host` *must* be without port number when the port is 22; this seems to be a quirk of OpenSSH's known-hosts file format.
        host = ip_address
        if self.farssh_args.ssh_port != "22":
            host = f"[{host}]:{self.farssh_args.ssh_port}"

        with open(self.known_hosts_file, "w") as f:
            # the public key file content ends with a newline already; avoid writing an empty line
            f.write(f"{host} {self.host_key_pub.strip()}\n")

def select_ip_address(args, task):
    """Return the address to use for reaching the FarSSH ECS task.

    With `--ipv6` the task's IPv6 address is required. Otherwise the public
    IPv4 address is preferred, with the IPv6 address as fallback.

    Args:
        args: the `FarSshArgs` instance.
        task: a task description as returned by ecs:DescribeTasks.

    Raises:
        SystemExit: if no suitable address is available.
    """
    ipv6_address = task['containers'][0]['networkInterfaces'][0].get('ipv6Address')

    if args.cmd_args.get('ipv6'):
        if not ipv6_address:
            raise SystemExit("ERROR: IPv6 requested, but no IPv6 address on the FarSSH ECS task. Check selected VPC subnets.")

        return ipv6_address

    # The IPv6 is available in `task`, but the public IPv4 address is not, so we need to query this separately.
    ec2 = boto3.client('ec2')
    eni_id = [detail['value'] for detail in task['attachments'][0]['details'] if detail['name'] == "networkInterfaceId"][0]
    dni = ec2.describe_network_interfaces(NetworkInterfaceIds=[eni_id])
    ipv4_address = dni['NetworkInterfaces'][0].get('Association', {}).get('PublicIp')

    if ipv4_address:
        return ipv4_address

    if ipv6_address:
        print("WARNING: FarSSH ECS task has no public IPv4 address; attempting with IPv6.")
        return ipv6_address

    raise SystemExit("ERROR: FarSSH ECS task has neither IPv6 nor public IPv4 address. Check selected VPC subnets.")

def run_ecs_task(args, ssh_keys):
    """Launch the FarSSH ECS task, wait until it is RUNNING and return its IP address.

    The login public key and the (base64-encoded) host private key are passed
    to the container as environment overrides.

    Args:
        args: the `FarSshArgs` instance.
        ssh_keys: the `SshKeys` instance.

    Raises:
        SystemExit: if the task cannot be launched, stops before reaching
            RUNNING, or has no usable IP address.
    """
    override_env = [
        {
            "name": "FARSSH_SSH_AUTHORIZED_KEYS",
            "value": ssh_keys.login_key_pub,
        },
        {
            "name": "FARSSH_SSH_HOST_RSA_KEY_BASE64",
            "value": base64.b64encode(bytes(ssh_keys.host_key, "utf-8")).decode("utf-8"),
        },
    ]

    overrides = {
        "containerOverrides": [
            {
                "name": "farssh", # container name from task definition
                "environment": override_env,
            },
        ],
    }

    network_configuration = {
        "awsvpcConfiguration": {
            "subnets": args.public_subnets,
            "securityGroups": [args.security_group],
            "assignPublicIp": args.assign_public_ipv4,
        }
    }

    ecs = boto3.client('ecs')
    tasks = ecs.run_task(
        cluster="farssh",
        capacityProviderStrategy=[
            {"capacityProvider": "FARGATE", "weight": 1, "base": 1},
            # { "capacityProvider": "FARGATE_SPOT", "weight": 0, "base": 0 }, # not supported for Arm images
        ],
        taskDefinition=f"farssh-{FARSSH_ID}",
        enableExecuteCommand=args.enable_execute_command,
        overrides=overrides,
        networkConfiguration=network_configuration,
    )

    # RunTask reports placement problems in `failures` instead of raising.
    if not tasks.get('tasks'):
        reasons = ", ".join(failure.get('reason', 'unknown reason') for failure in tasks.get('failures', []))
        raise SystemExit(f"ERROR: Could not launch FarSSH ECS task: {reasons or 'unknown reason'}")

    task = tasks['tasks'][0]
    task_arn = task['taskArn']
    task_id = task_arn.split('/')[-1]
    print(f"Launched FarSSH ECS task: {task_id}")
    print(f"Status: {task['lastStatus']}")

    # Poll once per second, printing every status change, until the task is RUNNING.
    while True:
        time.sleep(1)
        task_old = task
        task = ecs.describe_tasks(cluster="farssh", tasks=[task_arn])['tasks'][0]

        if task['lastStatus'] == "STOPPED":
            raise SystemExit("ERROR: ECS task is in status STOPPED; see ECS Console for details.")

        if task['lastStatus'] != task_old['lastStatus']:
            print(f"Status: {task['lastStatus']}")

        if task['lastStatus'] == "RUNNING":
            break

    ip_address = select_ip_address(args, task)
    print(f"FarSSH task IP address: {ip_address}")
    print()

    return ip_address

def select_database(args):
    """Pick the RDS instance or cluster to connect to for the psql/mysql commands.

    Candidates are all DB instances and DB clusters whose engine matches the
    command and, if given, whose identifier matches `--identifier`
    (case-insensitive). The returned description is extended with the keys
    `database`, `identifier`, `hostname`, `username` and `port`.

    Returns:
        The selected database description, or None for commands other than
        psql/mysql.

    Raises:
        SystemExit: if no database matches, or (with exit status 0, after
            listing the candidates) if more than one matches.
    """
    command = args.cmd_args.get('command')
    if command not in ["psql", "mysql"]:
        return None

    wanted_identifier = args.cmd_args.get('identifier')

    rds = boto3.client('rds')
    available = []
    available += rds.describe_db_instances()['DBInstances']
    available += rds.describe_db_clusters()['DBClusters']

    candidates = []
    for db in available:
        # instances report `DBInstanceStatus`, clusters report `Status`
        if db.get('DBInstanceStatus') == "creating":
            continue
        if db.get('Status') == "creating":
            continue

        # Without an endpoint there is nothing to tunnel to (and no hostname/port to read).
        endpoint = db.get('Endpoint')
        if not endpoint:
            continue

        if command == "psql" and "postgres" not in db['Engine']:
            continue
        if command == "mysql" and "mysql" not in db['Engine'] and "maria" not in db['Engine']:
            continue

        db['database'] = db.get('DatabaseName') or db.get('DBName')
        db['identifier'] = db.get('DBClusterIdentifier') or db.get('DBInstanceIdentifier')
        db['username'] = db.get('MasterUsername')
        # the endpoint is either a plain hostname string or a dict with Address/Port
        if isinstance(endpoint, str):
            db['hostname'] = endpoint
            db['port'] = db.get('Port')
        else:
            db['hostname'] = endpoint.get('Address')
            db['port'] = db.get('Port') or endpoint.get('Port')

        if wanted_identifier and db['identifier'].lower() != wanted_identifier.lower():
            continue

        candidates.append(db)

    if not candidates:
        raise SystemExit("ERROR: No matching database found")

    if len(candidates) == 1:
        db = candidates[0]
        if not wanted_identifier:
            print(f"Selected database: {db['identifier']} ({db['hostname']})")

        return db

    print("Multiple databases available; use --identifier/-i to select one:")
    for db in candidates:
        print(f"- {db['identifier']}")

    sys.exit(0)

def build_commands(args, ip_address, database, ssh_keys):
    """Build the ssh command line and, for psql/mysql, the client command line.

    Args:
        args: the `FarSshArgs` instance.
        ip_address: address of the FarSSH ECS task.
        database: result of `select_database` (only read for psql/mysql).
        ssh_keys: the `SshKeys` instance providing the identity and known-hosts files.

    Returns:
        A tuple `(ssh_command, main_command)` of argument lists; `main_command`
        is None unless the command is psql or mysql.

    Raises:
        SystemExit: if `ssh` or the database client cannot be found in PATH.
    """
    command = args.cmd_args.get('command')

    ssh_path = shutil.which("ssh")
    if not ssh_path:
        raise SystemExit("ERROR: Command ssh not found")

    ssh_command = [ssh_path]
    ssh_command += ["-p", args.ssh_port]
    ssh_command += ["-o", f"IdentityFile {ssh_keys.login_key_file}"]
    ssh_command += ["-o", "IdentitiesOnly yes"]
    ssh_command += ["-o", f"UserKnownHostsFile {ssh_keys.known_hosts_file}"]
    ssh_command += ["-o", "StrictHostKeyChecking yes"]
    ssh_command += ["-o", "ExitOnForwardFailure yes"]
    ssh_command += ["-l", "root"]

    main_command = None

    if command == "ssh":
        ssh_command += [ip_address]
        ssh_command += args.cmd_args.get('extra_arguments') or []
        print("------------------------------------------------------------------------")
        print()

    elif command == "proxy":
        ssh_command += ["-D", "1080"]
        ssh_command += [ip_address]
        ssh_command += ["echo SOCKS proxy available on port 1080. Hit Ctrl-C to terminate.; sleep infinity"]

    elif command in ["psql", "mysql"]:
        port = str(database['port'])
        l_args = [port, database['hostname'], port]
        ssh_command += ["-L", ":".join(l_args)]
        ssh_command += [ip_address]
        extra_arguments = args.cmd_args.get('extra_arguments') or []

        if command == "psql":
            username = args.cmd_args.get('username') or os.environ.get('PGUSER') or database.get('username')
            db_name = args.cmd_args.get('database') or os.environ.get('PGDATABASE') or database.get('database') or 'template1'
            main_command = ["psql"]
            main_command += ["--host", "localhost"]
            main_command += ["--port", port]
            main_command += extra_arguments
            main_command += [db_name]
            # the user name is psql's optional second positional argument; leave it
            # out when unknown instead of passing None to the client
            if username:
                main_command += [username]
        else:
            username = args.cmd_args.get('username') or database.get('username')
            db_name = args.cmd_args.get('database') or database.get('database') or 'mysql'
            main_command = ["mysql"]
            main_command += ["--protocol", "tcp"] # defaults to local socket otherwise
            main_command += ["--host", "localhost"]
            main_command += ["--port", port]
            # leave it out when unknown instead of passing None to the client
            if username:
                main_command += ["--user", username]
            if args.cmd_args.get('password'):
                # we intentionally only support -p without an argument; it would
                # create ambiguity for argparse otherwise.
                main_command += ["-p"]
            main_command += extra_arguments
            main_command += [db_name]

    elif command == "tunnel":
        l_args = [
            args.cmd_args.get('local_port'),
            args.cmd_args.get('remote_host'),
            args.cmd_args.get('remote_port'),
        ]
        ssh_command += ["-L", ":".join(l_args)]
        ssh_command += [ip_address]
        ssh_command += ["echo Port forwarding tunnel established. Hit Ctrl-C to terminate.; sleep infinity"]

    if main_command:
        mc_path = shutil.which(main_command[0])
        if not mc_path:
            raise SystemExit(f"ERROR: Command {main_command[0]} not found")
        main_command[0] = mc_path

    return (ssh_command, main_command)

def _wait_for_tunnel(ssh):
    """Block until the ssh process `ssh` prints the "connected" marker on stdout.

    Raises:
        SystemExit: if stdout reaches end-of-file first, i.e. the connection
            was not established.
    """
    while True:
        line = ssh.stdout.readline()
        if not line:
            raise SystemExit("ERROR: SSH connection to the FarSSH ECS task could not be established.")
        if line.strip() == "connected":
            return

# ----------------------------------------------------------------------

def main():
    """Run the FarSSH client and return the process exit status."""
    args = FarSshArgs()
    # select the database before generating keys and launching the task, so that
    # a failing selection costs nothing
    database = select_database(args)
    ssh_keys = SshKeys(args)
    ip_address = run_ecs_task(args, ssh_keys)
    ssh_keys.write_known_hosts(ip_address)
    (ssh_command, main_command) = build_commands(args, ip_address, database, ssh_keys)

    if not main_command:
        # ssh, proxy and tunnel: ssh itself is the foreground process
        try:
            subprocess.run(ssh_command)
        except KeyboardInterrupt:
            pass

        return 0

    # psql and mysql: ssh only provides the tunnel in the background
    ssh_command.insert(1, "-n") # StdinNull
    ssh_command += ["echo connected; sleep infinity"]

    with subprocess.Popen(ssh_command, stdout=subprocess.PIPE, stdin=subprocess.DEVNULL, text=True) as ssh:
        try:
            _wait_for_tunnel(ssh)
            ssh.stdout.close()
            print("Tunnel connection established.")
            print(f"Running {' '.join(main_command)}")
            print("------------------------------------------------------------------------")
            subprocess.run(main_command)
        finally:
            ssh.terminate()

    return 0

if __name__ == "__main__":
    sys.exit(main())