#!/usr/bin/env python3
# PYTHON_ARGCOMPLETE_OK
from __future__ import unicode_literals, print_function

import re
import json
import sys
import argparse
import os
import time
from typing import Dict, List, Tuple, Optional
from datetime import date, datetime, timedelta, timezone
import hashlib
import math
import threading
from concurrent.futures import ThreadPoolExecutor
import requests
import getpass
import subprocess
from time import sleep
from subprocess import PIPE
import urllib3
import atexit
from contextlib import redirect_stdout, redirect_stderr
from io import StringIO
from typing import Optional
import shutil
import logging
import textwrap
from pathlib import Path
import warnings
import importlib.metadata
from copy import deepcopy

PYPI_BASE_PATH = "https://pypi.org"

# INFO - Change to False if you don't want to check for update each run.
should_check_for_update = False

ARGS = None

TABCOMPLETE = False
try:
    import argcomplete
    TABCOMPLETE = True
except:
    # No tab-completion for you
    pass

try:
    import curlify
except ImportError:
    pass

try:
    from urllib import quote_plus  # Python 2.X
except ImportError:
    from urllib.parse import quote_plus  # Python 3+

try:
    JSONDecodeError = json.JSONDecodeError
except AttributeError:
    JSONDecodeError = ValueError

try:
    input = raw_input
except NameError:
    pass

#server_url_default = "https://vast.ai"
server_url_default = os.getenv("VAST_URL") or "https://console.vast.ai"
#server_url_default = "http://localhost:5002"
#server_url_default = "host.docker.internal"
#server_url_default = "http://localhost:5002"
#server_url_default = "https://vast.ai/api/v0"

logging.basicConfig(
    level=os.getenv("LOGLEVEL") or logging.WARN,
    format="%(levelname)s - %(message)s"
)

def parse_version(version: str) -> tuple[int, ...]:
    parts = version.split(".")
    if len(parts) < 3:
        print(f"Invalid version format: {version}", file=sys.stderr)
    return tuple(int(part) for part in parts)

def get_git_version():
    try:
        result = subprocess.run(
            ["git", "describe", "--tags", "--abbrev=0"],
            capture_output=True,
            text=True,
            check=True,
        )
        tag = result.stdout.strip()
        return tag[1:] if tag.startswith("v") else tag
    except Exception:
        return "0.0.0"

def get_pip_version():
    try:
        return importlib.metadata.version("vastai")
    except Exception:
        return "0.0.0"

def is_pip_package():
    try:
        return importlib.metadata.metadata("vastai") is not None
    except Exception:
        return False

def get_update_command(stable_version: str) -> str:
    if is_pip_package():
        if "test.pypi.org" in PYPI_BASE_PATH:
            return f"{sys.executable} -m pip install --force-reinstall --no-cache-dir -i {PYPI_BASE_PATH} vastai=={stable_version}"
        else:
            return f"{sys.executable} -m pip install --force-reinstall --no-cache-dir vastai=={stable_version}"
    else:
        return f"git fetch --all --tags --prune && git checkout tags/v{stable_version}"

def get_local_version():
    if is_pip_package():
        return get_pip_version()
    return get_git_version()

def get_project_data(project_name: str) -> dict[str, dict[str, str]]:
    url = PYPI_BASE_PATH + f"/pypi/{project_name}/json"
    response = requests.get(url, headers={"Accept": "application/json"})
    # this will raise for HTTP status 4xx and 5xx
    response.raise_for_status()
    # this will raise for HTTP status >200,<=399
    if response.status_code != 200:
        raise Exception(
            f"Could not get PyPi Project: {project_name}. Response: {response.status_code}"
        )
    response_data: dict[str, dict[str, str]] = response.json()
    return response_data

def get_pypi_version(project_data: dict[str, dict[str, str]]) -> str:
    info_data = project_data.get("info")
    if not info_data:
        raise Exception("Could not get PyPi Project")
    version_data: str = str(info_data.get("version"))
    return str(version_data)

def check_for_update():
    pypi_data = get_project_data("vastai")
    pypi_version = get_pypi_version(pypi_data)
    local_version = get_local_version()
    local_tuple = parse_version(local_version)
    pypi_tuple = parse_version(pypi_version)

    if local_tuple >= pypi_tuple:
        return

    user_wants_update = input(
        f"Update available from {local_version} to {pypi_version}. Would you like to update [Y/n]: "
    ).lower()

    if user_wants_update not in ["y", ""]:
        print("You selected no. If you don't want to check for updates each time, update should_check_for_update in vast.py")
        return

    update_command = get_update_command(pypi_version)
    print("Updating...")
    _ = subprocess.run(
        update_command,
        shell=True,
        check=True,
        text=True,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )
    print("Update completed successfully!\nAttempt to run your command again!")
    sys.exit(0)

APP_NAME = "vastai"
VERSION = get_local_version()

try:
    # Although xdg-base-dirs is the newer name, there's
    # python compatibility issues with dependencies that
    # can be unresolvable using things like python 3.9
    # So we actually use the older name, thus older
    # version for now. This is as of now (2024/11/15)
    # the safer option. -cjm
    import xdg
    DIRS = {
        'config': xdg.xdg_config_home(),
        'temp': xdg.xdg_cache_home()
    }
except:
    # Reasonable defaults.
    DIRS = {
        'config': os.path.join(os.getenv('HOME'), '.config'),
        'temp': os.path.join(os.getenv('HOME'), '.cache'),
    }

for key in DIRS.keys():
    DIRS[key] = path = os.path.join(DIRS[key], APP_NAME)
    if not os.path.exists(path):
        os.makedirs(path)

CACHE_FILE = os.path.join(DIRS['temp'], "gpu_names_cache.json")
CACHE_DURATION = timedelta(hours=24)

APIKEY_FILE = os.path.join(DIRS['config'], "vast_api_key")
APIKEY_FILE_HOME = os.path.expanduser("~/.vast_api_key")  # Legacy

if not os.path.exists(APIKEY_FILE) and os.path.exists(APIKEY_FILE_HOME):
    #print(f'copying key from {APIKEY_FILE_HOME} -> {APIKEY_FILE}')
    shutil.copyfile(APIKEY_FILE_HOME, APIKEY_FILE)

api_key_guard = object()

headers = {}

class Object(object):
    pass

def validate_seconds(value):
    """Validate that the input value is a valid number for seconds between yesterday and Jan 1, 2100."""
    try:
        val = int(value)
        # Calculate min_seconds as the start of yesterday in seconds
        yesterday = datetime.now() - timedelta(days=1)
        min_seconds = int(yesterday.timestamp())
        # Calculate max_seconds for Jan 1st, 2100 in seconds
        max_date = datetime(2100, 1, 1, 0, 0, 0)
        max_seconds = int(max_date.timestamp())
        if not (min_seconds <= val <= max_seconds):
            raise argparse.ArgumentTypeError(f"{value} is not a valid second timestamp.")
        return val
    except ValueError:
        raise argparse.ArgumentTypeError(f"{value} is not a valid integer.")

def strip_strings(value):
    if isinstance(value, str):
        return value.strip()
    elif isinstance(value, dict):
        return {k: strip_strings(v) for k, v in value.items()}
    elif isinstance(value, list):
        return [strip_strings(item) for item in value]
    return value  # Return as is if not a string, list, or dict

def string_to_unix_epoch(date_string):
    if date_string is None:
        return None
    try:
        # Check if the input is a float or integer representing Unix time
        return float(date_string)
    except ValueError:
        # If not, parse it as a date string
        date_object = 
datetime.strptime(date_string, "%m/%d/%Y") return time.mktime(date_object.timetuple()) def unix_to_readable(ts): # ts: integer or float, Unix timestamp return datetime.fromtimestamp(ts).strftime('%H:%M:%S|%h-%d-%Y') def fix_date_fields(query: Dict[str, Dict], date_fields: List[str]): """Takes in a query and date fields to correct and returns query with appropriate epoch dates""" new_query: Dict[str, Dict] = {} for field, sub_query in query.items(): # fix date values for given date fields if field in date_fields: new_sub_query = {k: string_to_unix_epoch(v) for k, v in sub_query.items()} new_query[field] = new_sub_query # else, use the original else: new_query[field] = sub_query return new_query class argument(object): def __init__(self, *args, mutex_group=None, **kwargs): self.args = args self.kwargs = kwargs self.mutex_group = mutex_group # Name of the mutually exclusive group this arg belongs to class hidden_aliases(object): # just a bit of a hack def __init__(self, l): self.l = l def __iter__(self): return iter(self.l) def __bool__(self): return False def __nonzero__(self): return False def append(self, x): self.l.append(x) def http_request(verb, args, req_url, headers: dict[str, str] | None = None, json = None): t = 0.15 for i in range(0, args.retry): req = requests.Request(method=verb, url=req_url, headers=headers, json=json) session = requests.Session() prep = session.prepare_request(req) if ARGS.curl: as_curl = curlify.to_curl(prep) simple = re.sub(r" -H '[^']*'", '', as_curl) parts = re.split(r'(?=\s+-\S+)', simple) pp = parts[-1].split("'") pp[-3] += "\n " parts = [*parts[:-1], *[x.rstrip() for x in "'".join(pp).split("\n")]] print("\n" + ' \\\n '.join(parts).strip() + "\n") sys.exit(0) else: r = session.send(prep) if (r.status_code == 429): time.sleep(t) t *= 1.5 else: break return r def http_get(args, req_url, headers = None, json = None): return http_request('GET', args, req_url, headers, json) def http_put(args, req_url, headers = None, json = {}): return http_request('PUT', args, req_url, headers, json) def http_post(args, req_url, headers = None, json={}): return http_request('POST', args, req_url, headers, json) def http_del(args, req_url, headers = None, json={}): return http_request('DELETE', args, req_url, headers, json) def load_permissions_from_file(file_path): with open(file_path, 'r') as file: return json.load(file) def complete_instance_machine(prefix=None, action=None, parser=None, parsed_args=None): return show__instances(ARGS, {'internal': True, 'field': 'machine_id'}) def complete_instance(prefix=None, action=None, parser=None, parsed_args=None): return show__instances(ARGS, {'internal': True, 'field': 'id'}) def complete_sshkeys(prefix=None, action=None, parser=None, parsed_args=None): return [str(m) for m in Path.home().joinpath('.ssh').glob('*.pub')] class apwrap(object): def __init__(self, *args, **kwargs): if "formatter_class" not in kwargs: kwargs["formatter_class"] = MyWideHelpFormatter self.parser = argparse.ArgumentParser(*args, **kwargs) self.parser.set_defaults(func=self.fail_with_help) self.subparsers_ = None self.subparser_objs = [] self.added_help_cmd = False self.post_setup = [] self.verbs = set() self.objs = set() def fail_with_help(self, *a, **kw): self.parser.print_help(sys.stderr) raise SystemExit def add_argument(self, *a, **kw): if not kw.get("parent_only"): for x in self.subparser_objs: try: # Create a global options group for better visual separation if not hasattr(x, '_global_options_group'): x._global_options_group = 
x.add_argument_group('Global options (available for all commands)') x._global_options_group.add_argument(*a, **kw) except argparse.ArgumentError: # duplicate - or maybe other things, hopefully not pass return self.parser.add_argument(*a, **kw) def subparsers(self, *a, **kw): if self.subparsers_ is None: kw["metavar"] = "command" kw["help"] = "command to run. one of:" self.subparsers_ = self.parser.add_subparsers(*a, **kw) return self.subparsers_ def get_name(self, verb, obj): if obj: self.verbs.add(verb) self.objs.add(obj) name = verb + ' ' + obj else: self.objs.add(verb) name = verb return name def command(self, *arguments, aliases=(), help=None, **kwargs): help_ = help if not self.added_help_cmd: self.added_help_cmd = True @self.command(argument("subcommand", default=None, nargs="?"), help="print this help message") def help(*a, **kw): self.fail_with_help() def inner(func): dashed_name = func.__name__.replace("_", "-") verb, _, obj = dashed_name.partition("--") name = self.get_name(verb, obj) aliases_transformed = [] if aliases else hidden_aliases([]) for x in aliases: verb, _, obj = x.partition(" ") aliases_transformed.append(self.get_name(verb, obj)) if "formatter_class" not in kwargs: kwargs["formatter_class"] = MyWideHelpFormatter sp = self.subparsers().add_parser(name, aliases=aliases_transformed, help=help_, **kwargs) # TODO: Sometimes the parser.command has a help parameter. Ideally # I'd extract this during the sdk phase but for the life of me # I can't find it. setattr(func, "mysignature", sp) setattr(func, "mysignature_help", help_) self.subparser_objs.append(sp) self._process_arguments_with_groups(sp, arguments) sp.set_defaults(func=func) return func if len(arguments) == 1 and type(arguments[0]) != argument: func = arguments[0] arguments = [] return inner(func) return inner def parse_args(self, argv=None, *a, **kw): if argv is None: argv = sys.argv[1:] argv_ = [] for x in argv: if argv_ and argv_[-1] in self.verbs: argv_[-1] += " " + x else: argv_.append(x) args = self.parser.parse_args(argv_, *a, **kw) for func in self.post_setup: func(args) return args def _process_arguments_with_groups(self, parser_obj, arguments): """Process arguments and handle mutually exclusive groups""" mutex_groups_to_required = {} arg_to_group = {} # Determine if any mutex groups are required for arg in arguments: key = arg.args[0] if arg.mutex_group: is_required = arg.kwargs.pop('required', False) group_name = arg.mutex_group arg_to_group[key] = group_name if mutex_groups_to_required.get(group_name): continue # if marked as required then it stays required else: mutex_groups_to_required[group_name] = is_required name_to_group_parser = {} # Create mutually exclusive group parsers for group_name, is_required in mutex_groups_to_required.items(): mutex_group = parser_obj.add_mutually_exclusive_group(required=is_required) name_to_group_parser[group_name] = mutex_group for arg in arguments: # Add args via the appropriate parser key = arg.args[0] if arg_to_group.get(key): group_parser = name_to_group_parser[arg_to_group[key]] tsp = group_parser.add_argument(*arg.args, **arg.kwargs) else: tsp = parser_obj.add_argument(*arg.args, **arg.kwargs) self._add_completer(tsp, arg) def _add_completer(self, tsp, arg): """Helper function to add completers based on argument names""" myCompleter = None comparator = arg.args[0].lower() if comparator.startswith('machine'): myCompleter = complete_instance_machine elif comparator.startswith('id') or comparator.endswith('id'): myCompleter = complete_instance elif 
comparator.startswith('ssh'): myCompleter = complete_sshkeys if myCompleter: setattr(tsp, 'completer', myCompleter) class MyWideHelpFormatter(argparse.RawTextHelpFormatter): def __init__(self, prog): super().__init__(prog, width=128, max_help_position=50, indent_increment=1) parser = apwrap( epilog="Use 'vast COMMAND --help' for more info about a command", formatter_class=MyWideHelpFormatter ) def translate_null_strings_to_blanks(d: Dict) -> Dict: """Map over a dict and translate any null string values into ' '. Leave everything else as is. This is needed because you cannot add TableCell objects with only a null string or the client crashes. :param Dict d: dict of item values. :rtype Dict: """ # Beware: locally defined function. def translate_nulls(s): if s == "": return " " return s new_d = {k: translate_nulls(v) for k, v in d.items()} return new_d #req_url = apiurl(args, "/instances", {"owner": "me"}); def apiurl(args: argparse.Namespace, subpath: str, query_args: Dict = None) -> str: """Creates the endpoint URL for a given combination of parameters. :param argparse.Namespace args: Namespace with many fields relevant to the endpoint. :param str subpath: added to end of URL to further specify endpoint. :param typing.Dict query_args: specifics such as API key and search parameters that complete the URL. :rtype str: """ result = None if query_args is None: query_args = {} if args.api_key is not None: query_args["api_key"] = args.api_key if not re.match(r"^/api/v(\d)+/", subpath): subpath = "/api/v0" + subpath query_json = None if query_args: # a_list = [ for in ] ''' vector result; for (l_expression: expression) { result.push_back(expression); } ''' # an_iterator = ( for in ) query_json = "&".join( "{x}={y}".format(x=x, y=quote_plus(y if isinstance(y, str) else json.dumps(y))) for x, y in query_args.items()) result = args.url + subpath + "?" + query_json else: result = args.url + subpath if (args.explain): print("query args:") print(query_args) print("") print(f"base: {args.url + subpath + '?'} + query: ") print(result) print("") return result def apiheaders(args: argparse.Namespace) -> Dict: """Creates the headers for a given combination of parameters. :param argparse.Namespace args: Namespace with many fields relevant to the endpoint. :rtype Dict: """ result = {} if args.api_key is not None: result["Authorization"] = "Bearer " + args.api_key return result def deindent(message: str) -> str: """ Deindent a quoted string. Scans message and finds the smallest number of whitespace characters in any line and removes that many from the start of every line. :param str message: Message to deindent. 
:rtype str: """ message = re.sub(r" *$", "", message, flags=re.MULTILINE) indents = [len(x) for x in re.findall("^ *(?=[^ ])", message, re.MULTILINE) if len(x)] a = min(indents) message = re.sub(r"^ {," + str(a) + "}", "", message, flags=re.MULTILINE) return message.strip() # These are the fields that are displayed when a search is run displayable_fields = ( # ("bw_nvlink", "Bandwidth NVLink", "{}", None, True), ("id", "ID", "{}", None, True), ("cuda_max_good", "CUDA", "{:0.1f}", None, True), ("num_gpus", "N", "{}x", None, False), ("gpu_name", "Model", "{}", None, True), ("pcie_bw", "PCIE", "{:0.1f}", None, True), ("cpu_ghz", "cpu_ghz", "{:0.1f}", None, True), ("cpu_cores_effective", "vCPUs", "{:0.1f}", None, True), ("cpu_ram", "RAM", "{:0.1f}", lambda x: x / 1000, False), ("disk_space", "Disk", "{:.0f}", None, True), ("dph_total", "$/hr", "{:0.4f}", None, True), ("dlperf", "DLP", "{:0.1f}", None, True), ("dlperf_per_dphtotal", "DLP/$", "{:0.2f}", None, True), ("score", "score", "{:0.1f}", None, True), ("driver_version", "NV Driver", "{}", None, True), ("inet_up", "Net_up", "{:0.1f}", None, True), ("inet_down", "Net_down", "{:0.1f}", None, True), ("reliability", "R", "{:0.1f}", lambda x: x * 100, True), ("duration", "Max_Days", "{:0.1f}", lambda x: x / (24.0 * 60.0 * 60.0), True), ("machine_id", "mach_id", "{}", None, True), ("verification", "status", "{}", None, True), ("host_id", "host_id", "{}", None, True), ("direct_port_count", "ports", "{}", None, True), ("geolocation", "country", "{}", None, True), # ("direct_port_count", "Direct Port Count", "{}", None, True), ) displayable_fields_reserved = ( # ("bw_nvlink", "Bandwidth NVLink", "{}", None, True), ("id", "ID", "{}", None, True), ("cuda_max_good", "CUDA", "{:0.1f}", None, True), ("num_gpus", "N", "{}x", None, False), ("gpu_name", "Model", "{}", None, True), ("pcie_bw", "PCIE", "{:0.1f}", None, True), ("cpu_ghz", "cpu_ghz", "{:0.1f}", None, True), ("cpu_cores_effective", "vCPUs", "{:0.1f}", None, True), ("cpu_ram", "RAM", "{:0.1f}", lambda x: x / 1000, False), ("disk_space", "Disk", "{:.0f}", None, True), ("discounted_dph_total", "$/hr", "{:0.4f}", None, True), ("dlperf", "DLP", "{:0.1f}", None, True), ("dlperf_per_dphtotal", "DLP/$", "{:0.2f}", None, True), ("driver_version", "NV Driver", "{}", None, True), ("inet_up", "Net_up", "{:0.1f}", None, True), ("inet_down", "Net_down", "{:0.1f}", None, True), ("reliability", "R", "{:0.1f}", lambda x: x * 100, True), ("duration", "Max_Days", "{:0.1f}", lambda x: x / (24.0 * 60.0 * 60.0), True), ("machine_id", "mach_id", "{}", None, True), ("verification", "status", "{}", None, True), ("host_id", "host_id", "{}", None, True), ("direct_port_count", "ports", "{}", None, True), ("geolocation", "country", "{}", None, True), # ("direct_port_count", "Direct Port Count", "{}", None, True), ) vol_offers_fields = { "cpu_arch", "cuda_vers", "cluster_id", "nw_disk_min_bw", "nw_disk_avg_bw", "nw_disk_max_bw", "datacenter", "disk_bw", "disk_space", "driver_version", "duration", "geolocation", "gpu_arch", "has_avx", "host_id", "id", "inet_down", "inet_up", "machine_id", "pci_gen", "pcie_bw", "reliability", "storage_cost", "static_ip", "total_flops", "ubuntu_version", "verified", } vol_displayable_fields = ( ("id", "ID", "{}", None, True), ("cuda_max_good", "CUDA", "{:0.1f}", None, True), ("cpu_ghz", "cpu_ghz", "{:0.1f}", None, True), ("disk_bw", "Disk B/W", "{:0.1f}", None, True), ("disk_space", "Disk", "{:.0f}", None, True), ("disk_name", "Disk Name", "{}", None, True), ("storage_cost", "$/Gb/Month", 
"{:.2f}", None, True), ("driver_version", "NV Driver", "{}", None, True), ("inet_up", "Net_up", "{:0.1f}", None, True), ("inet_down", "Net_down", "{:0.1f}", None, True), ("reliability", "R", "{:0.1f}", lambda x: x * 100, True), ("duration", "Max_Days", "{:0.1f}", lambda x: x / (24.0 * 60.0 * 60.0), True), ("machine_id", "mach_id", "{}", None, True), ("verification", "status", "{}", None, True), ("host_id", "host_id", "{}", None, True), ("geolocation", "country", "{}", None, True), ) nw_vol_displayable_fields = ( ("id", "ID", "{}", None, True), ("disk_space", "Disk", "{:.0f}", None, True), ("storage_cost", "$/Gb/Month", "{:.2f}", None, True), ("inet_up", "Net_up", "{:0.1f}", None, True), ("inet_down", "Net_down", "{:0.1f}", None, True), ("reliability", "R", "{:0.1f}", lambda x: x * 100, True), ("duration", "Max_Days", "{:0.1f}", lambda x: x / (24.0 * 60.0 * 60.0), True), ("verification", "status", "{}", None, True), ("host_id", "host_id", "{}", None, True), ("cluster_id", "cluster_id", "{}", None, True), ("geolocation", "country", "{}", None, True), ("nw_disk_min_bw", "Min BW MiB/s", "{}", None, True), ("nw_disk_max_bw", "Max BW MiB/s", "{}", None, True), ("nw_disk_avg_bw", "Avg BW MiB/s", "{}", None, True), ) # Need to add bw_nvlink, machine_id, direct_port_count to output. # These fields are displayed when you do 'show instances' instance_fields = ( ("id", "ID", "{}", None, True), ("machine_id", "Machine", "{}", None, True), ("actual_status", "Status", "{}", None, True), ("num_gpus", "Num", "{}x", None, False), ("gpu_name", "Model", "{}", None, True), ("gpu_util", "Util. %", "{:0.1f}", None, True), ("cpu_cores_effective", "vCPUs", "{:0.1f}", None, True), ("cpu_ram", "RAM", "{:0.1f}", lambda x: x / 1000, False), ("disk_space", "Storage", "{:.0f}", None, True), ("ssh_host", "SSH Addr", "{}", None, True), ("ssh_port", "SSH Port", "{}", None, True), ("dph_total", "$/hr", "{:0.4f}", None, True), ("image_uuid", "Image", "{}", None, True), # ("dlperf", "DLPerf", "{:0.1f}", None, True), # ("dlperf_per_dphtotal", "DLP/$", "{:0.1f}", None, True), ("inet_up", "Net up", "{:0.1f}", None, True), ("inet_down", "Net down", "{:0.1f}", None, True), ("reliability2", "R", "{:0.1f}", lambda x: x * 100, True), ("label", "Label", "{}", None, True), ("duration", "age(hours)", "{:0.2f}", lambda x: x/(3600.0), True), ("uptime_mins", "uptime(mins)", "{:0.2f}", None, True), ) cluster_fields = ( ("id", "ID", "{}", None, True), ("subnet", "Subnet", "{}", None, True), ("node_count", "Nodes", "{}", None, True), ("manager_id", "Manager ID", "{}", None, True), ("manager_ip", "Manager IP", "{}", None, True), ("machine_ids", "Machine ID's", "{}", None, True) ) network_disk_fields = ( ("network_disk_id", "Network Disk ID", "{}", None, True), ("free_space", "Free Space (GB)", "{}", None, True), ("total_space", "Total Space (GB)", "{}", None, True), ) network_disk_machine_fields = ( ("machine_id", "Machine ID", "{}", None, True), ("mount_point", "Mount Point", "{}", None, True), ) overlay_fields = ( ("overlay_id", "Overlay ID", "{}", None, True), ("name", "Name", "{}", None, True), ("subnet", "Subnet", "{}", None, True), ("cluster_id", "Cluster ID", "{}", None, True), ("instance_count", "Instances", "{}", None, True), ("instances", "Instance IDs", "{}", None, True), ) volume_fields = ( ("id", "ID", "{}", None, True), ("cluster_id", "Cluster ID", "{}", None, True), ("label", "Name", "{}", None, True), ("disk_space", "Disk", "{:.0f}", None, True), ("status", "status", "{}", None, True), ("disk_name", "Disk Name", "{}", None, 
True), ("driver_version", "NV Driver", "{}", None, True), ("inet_up", "Net_up", "{:0.1f}", None, True), ("inet_down", "Net_down", "{:0.1f}", None, True), ("reliability2", "R", "{:0.1f}", lambda x: x * 100, True), ("duration", "age(hours)", "{:0.2f}", lambda x: x/(3600.0), True), ("machine_id", "mach_id", "{}", None, True), ("verification", "Verification", "{}", None, True), ("host_id", "host_id", "{}", None, True), ("geolocation", "country", "{}", None, True), ("instances", "instances","{}", None, True) ) # These fields are displayed when you do 'show machines' machine_fields = ( ("id", "ID", "{}", None, True), ("num_gpus", "#gpus", "{}", None, True), ("gpu_name", "gpu_name", "{}", None, True), ("disk_space", "disk", "{}", None, True), ("hostname", "hostname", "{}", lambda x: x[:16], True), ("driver_version", "driver", "{}", None, True), ("reliability2", "reliab", "{:0.4f}", None, True), ("verification", "veri", "{}", None, True), ("public_ipaddr", "ip", "{}", None, True), ("geolocation", "geoloc", "{}", None, True), ("num_reports", "reports", "{}", None, True), ("listed_gpu_cost", "gpuD_$/h", "{:0.2f}", None, True), ("min_bid_price", "gpuI$/h", "{:0.2f}", None, True), ("credit_discount_max", "rdisc", "{:0.2f}", None, True), ("listed_inet_up_cost", "netu_$/TB", "{:0.2f}", lambda x: x * 1024, True), ("listed_inet_down_cost", "netd_$/TB", "{:0.2f}", lambda x: x * 1024, True), ("gpu_occupancy", "occup", "{}", None, True), ) # These fields are displayed when you do 'show maints' maintenance_fields = ( ("machine_id", "Machine ID", "{}", None, True), ("start_time", "Start (Date/Time)", "{}", lambda x: datetime.fromtimestamp(x).strftime('%Y-%m-%d/%H:%M'), True), ("end_time", "End (Date/Time)", "{}", lambda x: datetime.fromtimestamp(x).strftime('%Y-%m-%d/%H:%M'), True), ("duration_hours", "Duration (Hrs)", "{}", None, True), ("maintenance_category", "Category", "{}", None, True), ) ipaddr_fields = ( ("ip", "ip", "{}", None, True), ("first_seen", "first_seen", "{}", None, True), ("first_location", "first_location", "{}", None, True), ) audit_log_fields = ( ("ip_address", "ip_address", "{}", None, True), ("api_key_id", "api_key_id", "{}", None, True), ("created_at", "created_at", "{}", None, True), ("api_route", "api_route", "{}", None, True), ("args", "args", "{}", None, True), ) scheduled_jobs_fields = ( ("id", "Scheduled Job ID", "{}", None, True), ("instance_id", "Instance ID", "{}", None, True), ("api_endpoint", "API Endpoint", "{}", None, True), ("start_time", "Start (Date/Time in UTC)", "{}", lambda x: datetime.fromtimestamp(x).strftime('%Y-%m-%d/%H:%M'), True), ("end_time", "End (Date/Time in UTC)", "{}", lambda x: datetime.fromtimestamp(x).strftime('%Y-%m-%d/%H:%M'), True), ("day_of_the_week", "Day of the Week", "{}", None, True), ("hour_of_the_day", "Hour of the Day in UTC", "{}", None, True), ("min_of_the_hour", "Minute of the Hour", "{}", None, True), ("frequency", "Frequency", "{}", None, True), ) invoice_fields = ( ("description", "Description", "{}", None, True), ("quantity", "Quantity", "{}", None, True), ("rate", "Rate", "{}", None, True), ("amount", "Amount", "{}", None, True), ("timestamp", "Timestamp", "{:0.1f}", None, True), ("type", "Type", "{}", None, True) ) user_fields = ( # ("api_key", "api_key", "{}", None, True), ("balance", "Balance", "{}", None, True), ("balance_threshold", "Bal. Thld", "{}", None, True), ("balance_threshold_enabled", "Bal. 
Thld Enabled", "{}", None, True), ("billaddress_city", "City", "{}", None, True), ("billaddress_country", "Country", "{}", None, True), ("billaddress_line1", "Addr Line 1", "{}", None, True), ("billaddress_line2", "Addr line 2", "{}", None, True), ("billaddress_zip", "Zip", "{}", None, True), ("billed_expected", "Billed Expected", "{}", None, True), ("billed_verified", "Billed Vfy", "{}", None, True), ("billing_creditonly", "Billing Creditonly", "{}", None, True), ("can_pay", "Can Pay", "{}", None, True), ("credit", "Credit", "{:0.2f}", None, True), ("email", "Email", "{}", None, True), ("email_verified", "Email Vfy", "{}", None, True), ("fullname", "Full Name", "{}", None, True), ("got_signup_credit", "Got Signup Credit", "{}", None, True), ("has_billing", "Has Billing", "{}", None, True), ("has_payout", "Has Payout", "{}", None, True), ("id", "Id", "{}", None, True), ("last4", "Last4", "{}", None, True), ("paid_expected", "Paid Expected", "{}", None, True), ("paid_verified", "Paid Vfy", "{}", None, True), ("password_resettable", "Pwd Resettable", "{}", None, True), ("paypal_email", "Paypal Email", "{}", None, True), ("ssh_key", "Ssh Key", "{}", None, True), ("user", "User", "{}", None, True), ("username", "Username", "{}", None, True) ) connection_fields = ( ("id", "ID", "{}", None, True), ("name", "NAME", "{}", None, True), ("cloud_type", "Cloud Type", "{}", None, True), ) def version_string_sort(a, b) -> int: """ Accepts two version strings and decides whether a > b, a == b, or a < b. This is meant as a sort function to be used for the driver versions in which only the == operator currently works correctly. Not quite finished... :param str a: :param str b: :return int: """ a_parts = a.split(".") b_parts = b.split(".") return 0 offers_fields = { "bw_nvlink", "compute_cap", "cpu_arch", "cpu_cores", "cpu_cores_effective", "cpu_ghz", "cpu_ram", "cuda_max_good", "datacenter", "direct_port_count", "driver_version", "disk_bw", "disk_space", "dlperf", "dlperf_per_dphtotal", "dph_total", "duration", "external", "flops_per_dphtotal", "gpu_arch", "gpu_display_active", "gpu_frac", # "gpu_ram_free_min", "gpu_mem_bw", "gpu_name", "gpu_ram", "gpu_total_ram", "gpu_display_active", "gpu_max_power", "gpu_max_temp", "has_avx", "host_id", "id", "inet_down", "inet_down_cost", "inet_up", "inet_up_cost", "machine_id", "min_bid", "mobo_name", "num_gpus", "pci_gen", "pcie_bw", "reliability", #"reliability2", "rentable", "rented", "storage_cost", "static_ip", "total_flops", "ubuntu_version", "verification", "verified", "vms_enabled", "geolocation", "cluster_id" } offers_alias = { "cuda_vers": "cuda_max_good", "display_active": "gpu_display_active", #"reliability": "reliability2", "dlperf_usd": "dlperf_per_dphtotal", "dph": "dph_total", "flops_usd": "flops_per_dphtotal", } offers_mult = { "cpu_ram": 1000, "gpu_ram": 1000, "gpu_total_ram" : 1000, "duration": 24.0 * 60.0 * 60.0, } def parse_query(query_str: str, res: Dict = None, fields = {}, field_alias = {}, field_multiplier = {}) -> Dict: """ Basically takes a query string (like the ones in the examples of commands for the search__offers function) and processes it into a dict of URL parameters to be sent to the server. 
:param str query_str: :param Dict res: :return Dict: """ if query_str is None: return res if res is None: res = {} if type(query_str) == list: query_str = " ".join(query_str) query_str = query_str.strip() # Revised regex pattern to accurately capture quoted strings, bracketed lists, and single words/numbers #pattern = r"([a-zA-Z0-9_]+)\s*(=|!=|<=|>=|<|>| in | nin | eq | neq | not eq | not in )?\s*(\"[^\"]*\"|\[[^\]]+\]|[^ ]+)" #pattern = "([a-zA-Z0-9_]+)( *[=>=": "gte", ">": "gt", "gt": "gt", "gte": "gte", "<=": "lte", "<": "lt", "lt": "lt", "lte": "lte", "!=": "neq", "==": "eq", "=": "eq", "eq": "eq", "neq": "neq", "noteq": "neq", "not eq": "neq", "notin": "notin", "not in": "notin", "nin": "notin", "in": "in", } joined = "".join("".join(x) for x in opts) if joined != query_str: raise ValueError( "Unconsumed text. Did you forget to quote your query? " + repr(joined) + " != " + repr(query_str)) for field, op, _, value, _ in opts: value = value.strip(",[]") v = res.setdefault(field, {}) op = op.strip() op_name = op_names.get(op) if field in field_alias: res.pop(field) field = field_alias[field] if (field == "driver_version") and ('.' in value): value = numeric_version(value) if not field in fields: print("Warning: Unrecognized field: {}, see list of recognized fields.".format(field), file=sys.stderr); if not op_name: raise ValueError("Unknown operator. Did you forget to quote your query? " + repr(op).strip("u")) if op_name in ["in", "notin"]: value = [x.strip() for x in value.split(",") if x.strip()] if not value: raise ValueError("Value cannot be blank. Did you forget to quote your query? " + repr((field, op, value))) if not field: raise ValueError("Field cannot be blank. Did you forget to quote your query? " + repr((field, op, value))) if value in ["?", "*", "any"]: if op_name != "eq": raise ValueError("Wildcard only makes sense with equals.") if field in v: del v[field] if field in res: del res[field] continue if isinstance(value, str): value = value.replace('_', ' ') value = value.strip('\"') elif isinstance(value, list): value = [x.replace('_', ' ') for x in value] value = [x.strip('\"') for x in value] if field in field_multiplier: value = float(value) * field_multiplier[field] v[op_name] = value else: #print(value) if (value == 'true') or (value == 'True'): v[op_name] = True elif (value == 'false') or (value == 'False'): v[op_name] = False elif (value == 'None') or (value == 'null'): v[op_name] = None else: v[op_name] = value if field not in res: res[field] = v else: res[field].update(v) #print(res) return res def display_table(rows: list, fields: Tuple, replace_spaces: bool = True) -> None: """Basically takes a set of field names and rows containing the corresponding data and prints a nice tidy table of it. :param list rows: Each row is a dict with keys corresponding to the field names (first element) in the fields tuple. :param Tuple fields: 5-tuple describing a field. First element is field name, second is human readable version, third is format string, fourth is a lambda function run on the data in that field, fifth is a bool determining text justification. True = left justify, False = right justify. Here is an example showing the tuples in action. 
:rtype None: Example of 5-tuple: ("cpu_ram", "RAM", "{:0.1f}", lambda x: x / 1000, False) """ header = [name for _, name, _, _, _ in fields] out_rows = [header] lengths = [len(x) for x in header] for instance in rows: row = [] out_rows.append(row) for key, name, fmt, conv, _ in fields: conv = conv or (lambda x: x) val = instance.get(key, None) if val is None: s = "-" else: val = conv(val) s = fmt.format(val) if replace_spaces: s = s.replace(' ', '_') idx = len(row) lengths[idx] = max(len(s), lengths[idx]) row.append(s) for row in out_rows: out = [] for l, s, f in zip(lengths, row, fields): _, _, _, _, ljust = f if ljust: s = s.ljust(l) else: s = s.rjust(l) out.append(s) print(" ".join(out)) def print_or_page(args, text): """ Print text to terminal, or pipe to pager_cmd if too long. """ line_threshold = shutil.get_terminal_size(fallback=(80, 24)).lines lines = text.splitlines() if not args.full and len(lines) > line_threshold: pager_cmd = ['less', '-R'] if shutil.which('less') else None if pager_cmd: proc = subprocess.Popen(pager_cmd, stdin=subprocess.PIPE) proc.communicate(input=text.encode()) return True else: print(text) return False else: print(text) return False class VRLException(Exception): pass def parse_vast_url(url_str): """ Breaks up a vast-style url in the form instance_id:path and does some basic sanity type-checking. :param url_str: :return: """ instance_id = None path = url_str #print(f'url_str: {url_str}') if (":" in url_str): url_parts = url_str.split(":", 2) if len(url_parts) == 2: (instance_id, path) = url_parts else: raise VRLException("Invalid VRL (Vast resource locator).") else: try: instance_id = int(path) path = "/" except: pass valid_unix_path_regex = re.compile('^(/)?([^/\0]+(/)?)+$') # Got this regex from https://stackoverflow.com/questions/537772/what-is-the-most-correct-regular-expression-for-a-unix-file-path if (path != "/") and (valid_unix_path_regex.match(path) is None): raise VRLException(f"Path component: {path} of VRL is not a valid Unix style path.") #print(f'instance_id: {instance_id}') #print(f'path: {path}') return (instance_id, path) def get_ssh_key(argstr): ssh_key = argstr # Including a path to a public key is pretty reasonable. if os.path.exists(argstr): with open(argstr) as f: ssh_key = f.read() if "PRIVATE KEY" in ssh_key: raise ValueError(deindent(""" 🐴 Woah, hold on there, partner! That's a *private* SSH key. You need to give the *public* one. It usually starts with 'ssh-rsa', is on a single line, has around 200 or so "base64" characters and ends with some-user@some-where. "Generate public ssh key" would be a good search term if you don't know how to do this. """)) if not ssh_key.lower().startswith('ssh'): raise ValueError(deindent(""" Are you sure that's an SSH public key? Usually it starts with the stanza 'ssh-(keytype)' where the keytype can be things such as rsa, ed25519-sk, or dsa. What you passed me was: {} And welp, that just don't look right. """.format(ssh_key))) return ssh_key @parser.command( argument("instance_id", help="id of instance to attach to", type=int), argument("ssh_key", help="ssh key to attach to instance", type=str), usage="vastai attach ssh instance_id ssh_key", help="Attach an ssh key to an instance. This will allow you to connect to the instance with the ssh key", epilog=deindent(""" Attach an ssh key to an instance. This will allow you to connect to the instance with the ssh key. Examples: vastai attach ssh 12371 ssh-rsa AAAAB3NzaC1yc2EAAA... 
vastai attach ssh 12371 ssh-rsa $(cat ~/.ssh/id_rsa) """), ) def attach__ssh(args): ssh_key = get_ssh_key(args.ssh_key) url = apiurl(args, "/instances/{id}/ssh/".format(id=args.instance_id)) req_json = {"ssh_key": ssh_key} r = http_post(args, url, headers=headers, json=req_json) r.raise_for_status() print(r.json()) @parser.command( argument("dst", help="instance_id:/path to target of copy operation", type=str), usage="vastai cancel copy DST", help="Cancel a remote copy in progress, specified by DST id", epilog=deindent(""" Use this command to cancel any/all current remote copy operations copying to a specific named instance, given by DST. Examples: vast cancel copy 12371 The first example cancels all copy operations currently copying data into instance 12371 """), ) def cancel__copy(args: argparse.Namespace): """ Cancel a remote copy in progress, specified by DST id" @param dst: ID of copy instance Target to cancel. """ url = apiurl(args, f"/commands/copy_direct/") dst_id = args.dst if (dst_id is None): print("invalid arguments") return print(f"canceling remote copies to {dst_id} ") req_json = { "client_id": "me", "dst_id": dst_id, } r = http_del(args, url, headers=headers,json=req_json) r.raise_for_status() if (r.status_code == 200): rj = r.json(); if (rj["success"]): print("Remote copy canceled - check instance status bar for progress updates (~30 seconds delayed).") else: print(rj["msg"]); else: print(r.text); print("failed with error {r.status_code}".format(**locals())); @parser.command( argument("dst", help="instance_id:/path to target of sync operation", type=str), usage="vastai cancel sync DST", help="Cancel a remote copy in progress, specified by DST id", epilog=deindent(""" Use this command to cancel any/all current remote cloud sync operations copying to a specific named instance, given by DST. Examples: vast cancel sync 12371 The first example cancels all copy operations currently copying data into instance 12371 """), ) def cancel__sync(args: argparse.Namespace): """ Cancel a remote cloud sync in progress, specified by DST id" @param dst: ID of cloud sync instance Target to cancel. """ url = apiurl(args, f"/commands/rclone/") dst_id = args.dst if (dst_id is None): print("invalid arguments") return print(f"canceling remote copies to {dst_id} ") req_json = { "client_id": "me", "dst_id": dst_id, } r = http_del(args, url, headers=headers,json=req_json) r.raise_for_status() if (r.status_code == 200): rj = r.json(); if (rj["success"]): print("Remote copy canceled - check instance status bar for progress updates (~30 seconds delayed).") else: print(rj["msg"]); else: print(r.text); print("failed with error {r.status_code}".format(**locals())); def default_start_date(): return datetime.now(timezone.utc).strftime("%Y-%m-%d") def default_end_date(): return (datetime.now(timezone.utc) + timedelta(days=7)).strftime("%Y-%m-%d") def convert_timestamp_to_date(unix_timestamp): utc_datetime = datetime.fromtimestamp(unix_timestamp, tz=timezone.utc) return utc_datetime.strftime("%Y-%m-%d") def parse_day_cron_style(value): """ Accepts an integer string 0-6 or '*' to indicate 'Every day'. Returns 0-6 as int, or None if '*'. """ val = str(value).strip() if val == "*": return None try: day = int(val) if 0 <= day <= 6: return day except ValueError: pass raise argparse.ArgumentTypeError("Day must be 0-6 (0=Sunday) or '*' for every day.") def parse_hour_cron_style(value): """ Accepts an integer string 0-23 or '*' to indicate 'Every hour'. Returns 0-23 as int, or None if '*'. 
""" val = str(value).strip() if val == "*": return None try: hour = int(val) if 0 <= hour <= 23: return hour except ValueError: pass raise argparse.ArgumentTypeError("Hour must be 0-23 or '*' for every hour.") @parser.command( argument("id", help="id of instance type to change bid", type=int), argument("--price", help="per machine bid price in $/hour", type=float), argument("--schedule", choices=["HOURLY", "DAILY", "WEEKLY"], help="try to schedule a command to run hourly, daily, or monthly. Valid values are HOURLY, DAILY, WEEKLY For ex. --schedule DAILY"), argument("--start_date", type=str, default=default_start_date(), help="Start date/time in format 'YYYY-MM-DD HH:MM:SS PM' (UTC). Default is now. (optional)"), argument("--end_date", type=str, default=default_end_date(), help="End date/time in format 'YYYY-MM-DD HH:MM:SS PM' (UTC). Default is 7 days from now. (optional)"), argument("--day", type=parse_day_cron_style, help="Day of week you want scheduled job to run on (0-6, where 0=Sunday) or \"*\". Default will be 0. For ex. --day 0", default=0), argument("--hour", type=parse_hour_cron_style, help="Hour of day you want scheduled job to run on (0-23) or \"*\" (UTC). Default will be 0. For ex. --hour 16", default=0), usage="vastai change bid id [--price PRICE]", help="Change the bid price for a spot/interruptible instance", epilog=deindent(""" Change the current bid price of instance id to PRICE. If PRICE is not specified, then a winning bid price is used as the default. """), ) def change__bid(args: argparse.Namespace): """Alter the bid with id contained in args. :param argparse.Namespace args: should supply all the command-line options :rtype int: """ url = apiurl(args, "/instances/bid_price/{id}/".format(id=args.id)) json_blob = {"client_id": "me", "price": args.price,} if (args.explain): print("request json: ") print(json_blob) if (args.schedule): validate_frequency_values(args.day, args.hour, args.schedule) cli_command = "change bid" api_endpoint = "/api/v0/instances/bid_price/{id}/".format(id=args.id) json_blob["instance_id"] = args.id add_scheduled_job(args, json_blob, cli_command, api_endpoint, "PUT", instance_id=args.id) return r = http_put(args, url, headers=headers, json=json_blob) r.raise_for_status() print("Per gpu bid price changed".format(r.json())) @parser.command( argument("source", help="id of volume contract being cloned", type=int), argument("dest", help="id of volume offer volume is being copied to", type=int), argument("-s", "--size", help="Size of new volume contract, in GB. Must be greater than or equal to the source volume, and less than or equal to the destination offer.", type=float), argument("-d", "--disable_compression", action="store_true", help="Do not compress volume data before copying."), usage="vastai copy volume [options]", help="Clone an existing volume", epilog=deindent(""" Create a new volume with the given offer, by copying the existing volume. Size defaults to the size of the existing volume, but can be increased if there is available space. """) ) def clone__volume(args: argparse.Namespace): json_blob={ "source" : args.source, "dest": args.dest, } if args.size: json_blob["size"] = args.size if args.disable_compression: json_blob["disable_compression"] = True url = apiurl(args, "/volumes/copy/") if (args.explain): print("request json: ") print(json_blob) r = http_post(args, url, headers=headers,json=json_blob) r.raise_for_status() if args.raw: return r else: print("Created. 
{}".format(r.json())) @parser.command( argument("src", help="Source location for copy operation (supports multiple formats)", type=str), argument("dst", help="Target location for copy operation (supports multiple formats)", type=str), argument("-i", "--identity", help="Location of ssh private key", type=str), usage="vastai copy SRC DST", help="Copy directories between instances and/or local", epilog=deindent(""" Copies a directory from a source location to a target location. Each of source and destination directories can be either local or remote, subject to appropriate read and write permissions required to carry out the action. Supported location formats: - [instance_id:]path (legacy format, still supported) - C.instance_id:path (container copy format) - cloud_service:path (cloud service format) - cloud_service.cloud_service_id:path (cloud service with ID) - local:path (explicit local path) - V.volume_id:path (volume copy, see restrictions) You should not copy to /root or / as a destination directory, as this can mess up the permissions on your instance ssh folder, breaking future copy operations (as they use ssh authentication) You can see more information about constraints here: https://vast.ai/docs/gpu-instances/data-movement#constraints Volume copy is currently only supported for copying to other volumes or instances, not cloud services or local. Examples: vast copy 6003036:/workspace/ 6003038:/workspace/ vast copy C.11824:/data/test local:data/test vast copy local:data/test C.11824:/data/test vast copy drive:/folder/file.txt C.6003036:/workspace/ vast copy s3.101:/data/ C.6003036:/workspace/ vast copy V.1234:/file C.5678:/workspace/ The first example copy syncs all files from the absolute directory '/workspace' on instance 6003036 to the directory '/workspace' on instance 6003038. The second example copy syncs files from container 11824 to the local machine using structured syntax. The third example copy syncs files from local to container 11824 using structured syntax. The fourth example copy syncs files from Google Drive to an instance. The fifth example copy syncs files from S3 bucket with id 101 to an instance. """), ) def copy(args: argparse.Namespace): """ Transfer data from one instance to another. @param src: Location of data object to be copied. @param dst: Target to copy object to. 
""" (src_id, src_path) = parse_vast_url(args.src) (dst_id, dst_path) = parse_vast_url(args.dst) if (src_id is None) and (dst_id is None): pass #print("invalid arguments") #return print(f"copying {str(src_id)+':' if src_id else ''}{src_path} {str(dst_id)+':' if dst_id else ''}{dst_path}") req_json = { "client_id": "me", "src_id": src_id, "dst_id": dst_id, "src_path": src_path, "dst_path": dst_path, } if (args.explain): print("request json: ") print(req_json) if (src_id is None) or (dst_id is None): url = apiurl(args, f"/commands/rsync/") else: url = apiurl(args, f"/commands/copy_direct/") r = http_put(args, url, headers=headers,json=req_json) r.raise_for_status() if (r.status_code == 200): rj = r.json() #print(json.dumps(rj, indent=1, sort_keys=True)) if (rj["success"]) and ((src_id is None or src_id == "local") or (dst_id is None or dst_id == "local")): homedir = subprocess.getoutput("echo $HOME") #print(f"homedir: {homedir}") remote_port = None identity = f"-i {args.identity}" if (args.identity is not None) else "" if (src_id is None or src_id == "local"): #result = subprocess.run(f"mkdir -p {src_path}", shell=True) remote_port = rj["dst_port"] remote_addr = rj["dst_addr"] cmd = f"rsync -arz -v --progress --rsh=ssh -e 'ssh {identity} -p {remote_port} -o StrictHostKeyChecking=no' {src_path} vastai_kaalia@{remote_addr}::{dst_id}/{dst_path}" print(cmd) result = subprocess.run(cmd, shell=True) #result = subprocess.run(["sudo", "rsync" "-arz", "-v", "--progress", "-rsh=ssh", "-e 'sudo ssh -i {homedir}/.ssh/id_rsa -p {remote_port} -o StrictHostKeyChecking=no'", src_path, "vastai_kaalia@{remote_addr}::{dst_id}"], shell=True) elif (dst_id is None or dst_id == "local"): result = subprocess.run(f"mkdir -p {dst_path}", shell=True) remote_port = rj["src_port"] remote_addr = rj["src_addr"] cmd = f"rsync -arz -v --progress --rsh=ssh -e 'ssh {identity} -p {remote_port} -o StrictHostKeyChecking=no' vastai_kaalia@{remote_addr}::{src_id}/{src_path} {dst_path}" print(cmd) result = subprocess.run(cmd, shell=True) #result = subprocess.run(["sudo", "rsync" "-arz", "-v", "--progress", "-rsh=ssh", "-e 'ssh -i {homedir}/.ssh/id_rsa -p {remote_port} -o StrictHostKeyChecking=no'", "vastai_kaalia@{remote_addr}::{src_id}", dst_path], shell=True) else: if (rj["success"]): print("Remote to Remote copy initiated - check instance status bar for progress updates (~30 seconds delayed).") else: if rj["msg"] == "src_path not supported VMs.": print("copy between VM instances does not currently support subpaths (only full disk copy)") elif rj["msg"] == "dst_path not supported for VMs.": print("copy between VM instances does not currently support subpaths (only full disk copy)") else: print(rj["msg"]) else: print(r.text) print("failed with error {r.status_code}".format(**locals())); ''' @parser.command( argument("src", help="instance_id of source VM.", type=int), argument("dst", help="instance_id of destination VM", type=int), usage="vastai vm copy SRC DST", help=" Copy VM image from one VM instance to another", epilog=deindent(""" Copies the entire VM image of from one instance to another. Note: destination VM must be stopped during copy. The source VM does not need to be stopped, but it's highly recommended that you keep the source VM stopped for the duration of the copy. """), ) def vm__copy(args: argparse.Namespace): """ Transfer VM image from one instance to another. @param src: instance_id of source. @param dst: instance_id of destination. 
""" src_id = args.src dst_id = args.dst print(f"copying from {src_id} to {dst_id}") req_json = { "client_id": "me", "src_id": src_id, "dst_id": dst_id, } url = apiurl(args, f"/commands/copy_direct/") if (args.explain): print("request json: ") print(req_json) r = http_put(args, url, headers=headers,json=req_json) r.raise_for_status() if (r.status_code == 200): rj = r.json(); if (rj["success"]): print("Remote to Remote copy initiated - check instance status bar for progress updates (~30 seconds delayed).") else: if rj["msg"] == "Invalid src_path.": print("src instance is not a VM") elif rj["msg"] == "Invalid dst_path.": print("dst instance is not a VM") else: print(rj["msg"]); else: print(r.text); print("failed with error {r.status_code}".format(**locals())); ''' @parser.command( argument("--src", help="path to source of object to copy", type=str), argument("--dst", help="path to target of copy operation", type=str, default="/workspace"), argument("--instance", help="id of the instance", type=str), argument("--connection", help="id of cloud connection on your account (get from calling 'vastai show connections')", type=str), argument("--transfer", help="type of transfer, possible options include Instance To Cloud and Cloud To Instance", type=str, default="Instance to Cloud"), argument("--dry-run", help="show what would have been transferred", action="store_true"), argument("--size-only", help="skip based on size only, not mod-time or checksum", action="store_true"), argument("--ignore-existing", help="skip all files that exist on destination", action="store_true"), argument("--update", help="skip files that are newer on the destination", action="store_true"), argument("--delete-excluded", help="delete files on dest excluded from transfer", action="store_true"), argument("--schedule", choices=["HOURLY", "DAILY", "WEEKLY"], help="try to schedule a command to run hourly, daily, or monthly. Valid values are HOURLY, DAILY, WEEKLY For ex. --schedule DAILY"), argument("--start_date", type=str, default=default_start_date(), help="Start date/time in format 'YYYY-MM-DD HH:MM:SS PM' (UTC). Default is now. (optional)"), argument("--end_date", type=str, help="End date/time in format 'YYYY-MM-DD HH:MM:SS PM' (UTC). Default is contract's end. (optional)"), argument("--day", type=parse_day_cron_style, help="Day of week you want scheduled job to run on (0-6, where 0=Sunday) or \"*\". Default will be 0. For ex. --day 0", default=0), argument("--hour", type=parse_hour_cron_style, help="Hour of day you want scheduled job to run on (0-23) or \"*\" (UTC). Default will be 0. For ex. --hour 16", default=0), usage="vastai cloud copy --src SRC --dst DST --instance INSTANCE_ID -connection CONNECTION_ID --transfer TRANSFER_TYPE", help="Copy files/folders to and from cloud providers", epilog=deindent(""" Copies a directory from a source location to a target location. Each of source and destination directories can be either local or remote, subject to appropriate read and write permissions required to carry out the action. The format for both src and dst is [instance_id:]path. You can find more information about the cloud copy operation here: https://vast.ai/docs/gpu-instances/cloud-sync Examples: vastai show connections ID NAME Cloud Type 1001 test_dir drive 1003 data_dir drive vastai cloud copy --src /folder --dst /workspace --instance 6003036 --connection 1001 --transfer "Instance To Cloud" The example copies all contents of /folder into /workspace on instance 6003036 from gdrive connection 'test_dir'. 
"""), ) def cloud__copy(args: argparse.Namespace): """ Transfer data from one instance to another. @param src: Location of data object to be copied. @param dst: Target to copy object to. """ url = apiurl(args, f"/commands/rclone/") #(src_id, src_path) = parse_vast_url(args.src) #(dst_id, dst_path) = parse_vast_url(args.dst) if (args.src is None) and (args.dst is None): print("invalid arguments") return # Initialize an empty list for flags flags = [] # Append flags to the list based on the argparse.Namespace if args.dry_run: flags.append("--dry-run") if args.size_only: flags.append("--size-only") if args.ignore_existing: flags.append("--ignore-existing") if args.update: flags.append("--update") if args.delete_excluded: flags.append("--delete-excluded") print(f"copying {args.src} {args.dst} {args.instance} {args.connection} {args.transfer}") req_json = { "src": args.src, "dst": args.dst, "instance_id": args.instance, "selected": args.connection, "transfer": args.transfer, "flags": flags } if (args.explain): print("request json: ") print(req_json) if (args.schedule): validate_frequency_values(args.day, args.hour, args.schedule) req_url = apiurl(args, "/instances/{id}/".format(id=args.instance) , {"owner": "me"} ) r = http_get(args, req_url) r.raise_for_status() row = r.json()["instances"] if args.transfer.lower() == "instance to cloud": if row: # Get the cost per TB of internet upload up_cost = row.get("internet_up_cost_per_tb", None) if up_cost is not None: confirm = input( f"Internet upload cost is ${up_cost} per TB. " "Are you sure you want to schedule a cloud backup? (y/n): " ).strip().lower() if confirm != "y": print("Cloud backup scheduling aborted.") return else: print("Warning: Could not retrieve internet upload cost. Proceeding without confirmation. You can use show scheduled-jobs and delete scheduled-job commands to delete scheduled cloud backup job.") cli_command = "cloud copy" api_endpoint = "/api/v0/commands/rclone/" contract_end_date = row.get("end_date", None) add_scheduled_job(args, req_json, cli_command, api_endpoint, "POST", instance_id=args.instance, contract_end_date=contract_end_date) return else: print("Instance not found. Please check the instance ID.") return r = http_post(args, url, headers=headers,json=req_json) r.raise_for_status() if (r.status_code == 200): print("Cloud Copy Started - check instance status bar for progress updates (~30 seconds delayed).") print("When the operation is finished you should see 'Cloud Copy Operation Finished' in the instance status bar.") else: print(r.text); print("failed with error {r.status_code}".format(**locals())); @parser.command( argument("instance_id", help="instance_id of the container instance to snapshot", type=str), argument("--container_registry", help="Container registry to push the snapshot to. Default will be docker.io", type=str, default="docker.io"), argument("--repo", help="repo to push the snapshot to", type=str), argument("--docker_login_user",help="Username for container registry with repo", type=str), argument("--docker_login_pass",help="Password or token for container registry with repo", type=str), argument("--pause", help="Pause container's processes being executed by the CPU to take snapshot (true/false). 
Default will be true", type=str, default="true"), usage="vastai take snapshot INSTANCE_ID " "--repo REPO --docker_login_user USER --docker_login_pass PASS" "[--container_registry REGISTRY] [--pause true|false]", help="Schedule a snapshot of a running container and push it to your repo in a container registry", epilog=deindent(""" Takes a snapshot of a running container instance and pushes snapshot to the specified repository in container registry. Use pause=true to pause the container during commit (safer but slower), or pause=false to leave it running (faster but may produce a filesystem- // safer snapshot). """), ) def take__snapshot(args: argparse.Namespace): """ Take a container snapshot and push. @param instance_id: instance identifier. @param repo: Docker repository for the snapshot. @param container_registry: Container registry @param docker_login_user: Docker registry username. @param docker_login_pass: Docker registry password/token. @param pause: "true" or "false" to pause the container during commit. """ instance_id = args.instance_id repo = args.repo container_registry = args.container_registry user = args.docker_login_user password = args.docker_login_pass pause_flag = args.pause print(f"Taking snapshot for instance {instance_id} and pushing to repo {repo} in container registry {container_registry}") req_json = { "id": instance_id, "container_registry": container_registry, "personal_repo": repo, "docker_login_user":user, "docker_login_pass":password, "pause": pause_flag } url = apiurl(args, f"/instances/take_snapshot/{instance_id}/") if args.explain: print("Request JSON:") print(json.dumps(req_json, indent=2)) # POST to the snapshot endpoint r = http_post(args, url, headers=headers, json=req_json) r.raise_for_status() if r.status_code == 200: data = r.json() if data.get("success"): print(f"Snapshot request sent successfully. Please check your repo {repo} in container registry {container_registry} in 5-10 mins. It can take longer than 5-10 mins to push your snapshot image to your repo depending on the size of your image.") else: print(data.get("msg", "Unknown error with snapshot request")) else: print(r.text); print("failed with error {r.status_code}".format(**locals())); def validate_frequency_values(day_of_the_week, hour_of_the_day, frequency): # Helper to raise an error with a consistent message. def raise_frequency_error(): msg = "" if frequency == "HOURLY": msg += "For HOURLY jobs, day and hour must both be \"*\"." elif frequency == "DAILY": msg += "For DAILY jobs, day must be \"*\" and hour must have a value between 0-23." elif frequency == "WEEKLY": msg += "For WEEKLY jobs, day must have a value between 0-6 and hour must have a value between 0-23." 
sys.exit(msg) if frequency == "HOURLY": if not (day_of_the_week is None and hour_of_the_day is None): raise_frequency_error() if frequency == "DAILY": if not (day_of_the_week is None and hour_of_the_day is not None): raise_frequency_error() if frequency == "WEEKLY": if not (day_of_the_week is not None and hour_of_the_day is not None): raise_frequency_error() def add_scheduled_job(args, req_json, cli_command, api_endpoint, request_method, instance_id, contract_end_date): start_timestamp, end_timestamp = convert_dates_to_timestamps(args) if args.end_date is None: end_timestamp=contract_end_date args.end_date = convert_timestamp_to_date(contract_end_date) if start_timestamp >= end_timestamp: raise ValueError("--start_date must be less than --end_date.") day, hour, frequency = args.day, args.hour, args.schedule schedule_job_url = apiurl(args, f"/commands/schedule_job/") request_body = { "start_time": start_timestamp, "end_time": end_timestamp, "api_endpoint": api_endpoint, "request_method": request_method, "request_body": req_json, "day_of_the_week": day, "hour_of_the_day": hour, "frequency": frequency, "instance_id": instance_id } # Send a POST request response = requests.post(schedule_job_url, headers=headers, json=request_body) if args.explain: print("request json: ") print(request_body) # Handle the response based on the status code if response.status_code == 200: print(f"add_scheduled_job insert: success - Scheduling {frequency} job to {cli_command} from {args.start_date} UTC to {args.end_date} UTC") elif response.status_code == 401: print(f"add_scheduled_job insert: failed status_code: {response.status_code}. It could be because you aren't using a valid api_key.") elif response.status_code == 422: user_input = input("Existing scheduled job found. Do you want to update it (y|n)? ") if user_input.strip().lower() == "y": scheduled_job_id = response.json()["scheduled_job_id"] schedule_job_url = apiurl(args, f"/commands/schedule_job/{scheduled_job_id}/") response = update_scheduled_job(cli_command, schedule_job_url, frequency, args.start_date, args.end_date, request_body) else: print("Job update aborted by the user.") else: # print(r.text) print(f"add_scheduled_job insert: failed error: {response.status_code}. Response body: {response.text}") def update_scheduled_job(cli_command, schedule_job_url, frequency, start_date, end_date, request_body): response = requests.put(schedule_job_url, headers=headers, json=request_body) # Raise an exception for HTTP errors response.raise_for_status() if response.status_code == 200: print(f"add_scheduled_job update: success - Scheduling {frequency} job to {cli_command} from {start_date} UTC to {end_date} UTC") print(response.json()) elif response.status_code == 401: print(f"add_scheduled_job update: failed status_code: {response.status_code}. It could be because you aren't using a valid api_key.") else: # print(r.text) print(f"add_scheduled_job update: failed status_code: {response.status_code}.") print(response.json()) return response @parser.command( argument("--name", help="name of the api-key", type=str), argument("--permission_file", help="file path for json encoded permissions, see https://vast.ai/docs/cli/roles-and-permissions for more information", type=str), argument("--key_params", help="optional wildcard key params for advanced keys", type=str), usage="vastai create api-key --name NAME --permission_file PERMISSIONS", help="Create a new api-key with restricted permissions. 
Can be sent to other users and teammates", epilog=deindent(""" In order to create api keys you must understand how permissions must be sent via json format. You can find more information about permissions here: https://vast.ai/docs/cli/roles-and-permissions """) ) def create__api_key(args): try: url = apiurl(args, "/auth/apikeys/") permissions = load_permissions_from_file(args.permission_file) r = http_post(args, url, headers=headers, json={"name": args.name, "permissions": permissions, "key_params": args.key_params}) r.raise_for_status() print("api-key created {}".format(r.json())) except FileNotFoundError: print("Error: Permission file '{}' not found.".format(args.permission_file)) except requests.exceptions.RequestException as e: print("Error: Failed to create api-key. Reason: {}".format(e)) except Exception as e: print("An unexpected error occurred:", e) @parser.command( argument("subnet", help="local subnet for cluster, ex: '0.0.0.0/24'", type=str), argument("manager_id", help="Machine ID of manager node in cluster. Must exist already.", type=int), usage="vastai create cluster SUBNET MANAGER_ID", help="Create Vast cluster", epilog=deindent(""" Create Vast Cluster by defining a local subnet and manager id.""") ) def create__cluster(args: argparse.Namespace): json_blob = { "subnet": args.subnet, "manager_id": args.manager_id } #TODO: this should happen at the decorator level for all CLI commands to reduce boilerplate if args.explain: print("request json: ") print(json_blob) req_url = apiurl(args, "/cluster/") r = http_post(args, req_url, json=json_blob) r.raise_for_status() if args.raw: return r print(r.json()["msg"]) @parser.command( argument("name", help="Environment variable name", type=str), argument("value", help="Environment variable value", type=str), usage="vastai create env-var ", help="Create a new user environment variable", ) def create__env_var(args): """Create a new environment variable for the current user.""" url = apiurl(args, "/secrets/") data = {"key": args.name, "value": args.value} r = http_post(args, url, headers=headers, json=data) r.raise_for_status() result = r.json() if result.get("success"): print(result.get("msg", "Environment variable created successfully.")) else: print(f"Failed to create environment variable: {result.get('msg', 'Unknown error')}") @parser.command( argument("ssh_key", help="add your existing ssh public key to your account (from the .pub file). If no public key is provided, a new key pair will be generated.", type=str, nargs='?'), argument("-y", "--yes", help="automatically answer yes to prompts", action="store_true"), usage="vastai create ssh-key [ssh_public_key] [-y]", help="Create a new ssh-key", epilog=deindent(""" You may use this command to add an existing public key, or create a new ssh key pair and add that public key, to your Vast account. If you provide an ssh_public_key.pub argument, that public key will be added to your Vast account. All ssh public keys should be in OpenSSH format. Example: $vastai create ssh-key 'ssh_public_key.pub' If you don't provide an ssh_public_key.pub argument, a new Ed25519 key pair will be generated. Example: $vastai create ssh-key The generated keys are saved as ~/.ssh/id_ed25519 (private) and ~/.ssh/id_ed25519.pub (public). Any existing id_ed25519 keys are backed up as .backup_. The public key will be added to your Vast account. All ssh public keys are stored in your Vast account and can be used to connect to instances they've been added to. 
""") ) def create__ssh_key(args): ssh_key_content = args.ssh_key # If no SSH key provided, generate one if not ssh_key_content: ssh_key_content = generate_ssh_key(args.yes) else: print("Adding provided SSH public key to account...") # Send the SSH key to the API url = apiurl(args, "/ssh/") r = http_post(args, url, headers=headers, json={"ssh_key": ssh_key_content}) r.raise_for_status() # Print json response print("ssh-key created {}\nNote: You may need to add the new public key to any pre-existing instances".format(r.json())) def generate_ssh_key(auto_yes=False): """ Generate a new SSH key pair using ssh-keygen and return the public key content. Args: auto_yes (bool): If True, automatically answer yes to prompts Returns: str: The content of the generated public key Raises: SystemExit: If ssh-keygen is not available or key generation fails """ print("No SSH key provided. Generating a new SSH key pair and adding public key to account...") # Define paths ssh_dir = Path.home() / '.ssh' private_key_path = ssh_dir / 'id_ed25519' public_key_path = ssh_dir / 'id_ed25519.pub' # Create .ssh directory if it doesn't exist try: ssh_dir.mkdir(mode=0o700, exist_ok=True) except OSError as e: print(f"Error creating .ssh directory: {e}", file=sys.stderr) sys.exit(1) # Check if any part of the key pair already exists and backup if needed if private_key_path.exists() or public_key_path.exists(): print(f"An SSH key pair 'id_ed25519' already exists in {ssh_dir}") if auto_yes: print("Auto-answering yes to backup existing key pair.") response = 'y' else: response = input("Would you like to generate a new key pair and backup your existing id_ed25519 key pair. [y/N]: ").lower() if response not in ['y', 'yes']: print("Aborted. No new key generated.") sys.exit(0) # Generate timestamp for backup timestamp = int(time.time()) backup_private_path = ssh_dir / f'id_ed25519.backup_{timestamp}' backup_public_path = ssh_dir / f'id_ed25519.pub.backup_{timestamp}' try: # Backup existing private key if it exists if private_key_path.exists(): private_key_path.rename(backup_private_path) print(f"Backed up existing private key to: {backup_private_path}") # Backup existing public key if it exists if public_key_path.exists(): public_key_path.rename(backup_public_path) print(f"Backed up existing public key to: {backup_public_path}") except OSError as e: print(f"Error backing up existing SSH keys: {e}", file=sys.stderr) sys.exit(1) print("Generating new SSH key pair and adding public key to account...") # Check if ssh-keygen is available try: subprocess.run(['ssh-keygen', '--help'], capture_output=True, check=False) except FileNotFoundError: print("Error: ssh-keygen not found. 
Please install OpenSSH client tools.", file=sys.stderr) sys.exit(1) # Generate the SSH key pair try: cmd = [ 'ssh-keygen', '-t', 'ed25519', # Ed25519 key type '-f', str(private_key_path), # Output file path '-N', '', # Empty passphrase '-C', f'{os.getenv("USER") or os.getenv("USERNAME", "user")}-vast.ai' # User ] result = subprocess.run( cmd, capture_output=True, text=True, input='y\n', # Automatically answer 'yes' to overwrite prompts check=True ) except subprocess.CalledProcessError as e: print(f"Error generating SSH key: {e}", file=sys.stderr) if e.stderr: print(f"ssh-keygen error: {e.stderr}", file=sys.stderr) sys.exit(1) except Exception as e: print(f"Unexpected error during key generation: {e}", file=sys.stderr) sys.exit(1) # Set proper permissions for the private key try: private_key_path.chmod(0o600) # Read/write for owner only except OSError as e: print(f"Warning: Could not set permissions for private key: {e}", file=sys.stderr) # Read and return the public key content try: with open(public_key_path, 'r') as f: public_key_content = f.read().strip() return public_key_content except IOError as e: print(f"Error reading generated public key: {e}", file=sys.stderr) sys.exit(1) @parser.command( argument("--template_hash", help="template hash (required, but **Note**: if you use this field, you can skip search_params, as they are automatically inferred from the template)", type=str), argument("--template_id", help="template id (optional)", type=int), argument("-n", "--no-default", action="store_true", help="Disable default search param query args"), argument("--launch_args", help="launch args string for create instance ex: \"--onstart onstart_wget.sh --env '-e ONSTART_PATH=https://s3.amazonaws.com/vast.ai/onstart_OOBA.sh' --image atinoda/text-generation-webui:default-nightly --disk 64\"", type=str), argument("--endpoint_name", help="deployment endpoint name (allows multiple workergroups to share same deployment endpoint)", type=str), argument("--endpoint_id", help="deployment endpoint id (allows multiple workergroups to share same deployment endpoint)", type=int), argument("--test_workers",help="number of workers to create to get an performance estimate for while initializing workergroup (default 3)", type=int, default=3), argument("--gpu_ram", help="estimated GPU RAM req (independent of search string)", type=float), argument("--search_params", help="search param string for search offers ex: \"gpu_ram>=23 num_gpus=2 gpu_name=RTX_4090 inet_down>200 direct_port_count>2 disk_space>=64\"", type=str), argument("--min_load", help="[NOTE: this field isn't currently used at the workergroup level] minimum floor load in perf units/s (token/s for LLms)", type=float), argument("--target_util", help="[NOTE: this field isn't currently used at the workergroup level] target capacity utilization (fraction, max 1.0, default 0.9)", type=float), argument("--cold_mult", help="[NOTE: this field isn't currently used at the workergroup level]cold/stopped instance capacity target as multiple of hot capacity target (default 2.0)", type=float), argument("--cold_workers", help="min number of workers to keep 'cold' for this workergroup", type=int), argument("--auto_instance", help=argparse.SUPPRESS, type=str, default="prod"), usage="vastai workergroup create [OPTIONS]", help="Create a new autoscale group", epilog=deindent(""" Create a new autoscaling group to manage a pool of worker instances. 
Example: vastai create workergroup --template_hash HASH --endpoint_name "LLama" --test_workers 5 """), ) def create__workergroup(args): url = apiurl(args, "/autojobs/" ) # if args.launch_args_dict: # launch_args_dict = json.loads(args.launch_args_dict) # json_blob = {"client_id": "me", "min_load": args.min_load, "target_util": args.target_util, "cold_mult": args.cold_mult, "template_hash": args.template_hash, "template_id": args.template_id, "search_params": args.search_params, "launch_args_dict": launch_args_dict, "gpu_ram": args.gpu_ram, "endpoint_name": args.endpoint_name} if args.no_default: query = "" else: query = " verified=True rentable=True rented=False" #query = {"verified": {"eq": True}, "external": {"eq": False}, "rentable": {"eq": True}, "rented": {"eq": False}} search_params = (args.search_params if args.search_params is not None else "" + query).strip() json_blob = {"client_id": "me", "min_load": args.min_load, "target_util": args.target_util, "cold_mult": args.cold_mult, "cold_workers" : args.cold_workers, "test_workers" : args.test_workers, "template_hash": args.template_hash, "template_id": args.template_id, "search_params": search_params, "launch_args": args.launch_args, "gpu_ram": args.gpu_ram, "endpoint_name": args.endpoint_name, "endpoint_id": args.endpoint_id, "autoscaler_instance": args.auto_instance} if (args.explain): print("request json: ") print(json_blob) r = http_post(args, url, headers=headers,json=json_blob) r.raise_for_status() if 'application/json' in r.headers.get('Content-Type', ''): try: print("workergroup create {}".format(r.json())) except requests.exceptions.JSONDecodeError: print("The response is not valid JSON.") print(r) print(r.text) # Print the raw response to help with debugging. else: print("The response is not JSON. 
Content-Type:", r.headers.get('Content-Type')) print(r.text) @parser.command( argument("--min_load", help="minimum floor load in perf units/s (token/s for LLms)", type=float, default=0.0), argument("--min_cold_load", help="minimum floor load in perf units/s (token/s for LLms), but allow handling with cold workers", type=float, default=0.0), argument("--target_util", help="target capacity utilization (fraction, max 1.0, default 0.9)", type=float, default=0.9), argument("--cold_mult", help="cold/stopped instance capacity target as multiple of hot capacity target (default 2.5)", type=float, default=2.5), argument("--cold_workers", help="min number of workers to keep 'cold' when you have no load (default 5)", type=int, default=5), argument("--max_workers", help="max number of workers your endpoint group can have (default 20)", type=int, default=20), argument("--endpoint_name", help="deployment endpoint name (allows multiple autoscale groups to share same deployment endpoint)", type=str), argument("--auto_instance", help=argparse.SUPPRESS, type=str, default="prod"), usage="vastai create endpoint [OPTIONS]", help="Create a new endpoint group", epilog=deindent(""" Create a new endpoint group to manage many autoscaling groups Example: vastai create endpoint --target_util 0.9 --cold_mult 2.0 --endpoint_name "LLama" """), ) def create__endpoint(args): url = apiurl(args, "/endptjobs/" ) json_blob = {"client_id": "me", "min_load": args.min_load, "min_cold_load":args.min_cold_load, "target_util": args.target_util, "cold_mult": args.cold_mult, "cold_workers" : args.cold_workers, "max_workers" : args.max_workers, "endpoint_name": args.endpoint_name, "autoscaler_instance": args.auto_instance} if (args.explain): print("request json: ") print(json_blob) r = requests.post(url, headers=headers,json=json_blob) r.raise_for_status() if 'application/json' in r.headers.get('Content-Type', ''): try: print("create endpoint {}".format(r.json())) except requests.exceptions.JSONDecodeError: print("The response is not valid JSON.") print(r) print(r.text) # Print the raw response to help with debugging. else: print("The response is not JSON. Content-Type:", r.headers.get('Content-Type')) print(r.text) def get_runtype(args): runtype = 'ssh' if args.args: runtype = 'args' if (args.args == '') or (args.args == ['']) or (args.args == []): runtype = 'args' args.args = None if not args.jupyter and (args.jupyter_dir or args.jupyter_lab): args.jupyter = True if args.jupyter and runtype == 'args': print("Error: Can't use --jupyter and --args together. Try --onstart or --onstart-cmd instead of --args.", file=sys.stderr) return 1 if args.jupyter: runtype = 'jupyter_direc ssh_direc ssh_proxy' if args.direct else 'jupyter_proxy ssh_proxy' elif args.ssh: runtype = 'ssh_direc ssh_proxy' if args.direct else 'ssh_proxy' return runtype def validate_volume_params(args): if args.volume_size and not args.create_volume: raise argparse.ArgumentTypeError("Error: --volume-size can only be used with --create-volume. 
Please specify a volume ask ID to create a new volume of that size.") if (args.create_volume or args.link_volume) and not args.mount_path: raise argparse.ArgumentTypeError("Error: --mount-path is required when creating or linking a volume.") # This regex matches absolute or relative Linux file paths (no null bytes) valid_linux_path_regex = re.compile(r'^(/)?([^/\0]+(/)?)+$') if not valid_linux_path_regex.match(args.mount_path): raise argparse.ArgumentTypeError(f"Error: --mount-path '{args.mount_path}' is not a valid Linux file path.") volume_info = { "mount_path": args.mount_path, "create_new": True if args.create_volume else False, "volume_id": args.create_volume if args.create_volume else args.link_volume } if args.volume_label: volume_info["name"] = args.volume_label if args.volume_size: volume_info["size"] = args.volume_size elif args.create_volume: # If creating a new volume and size is not passed in, default size is 15GB volume_info["size"] = 15 return volume_info def validate_portal_config(json_blob): # jupyter runtypes already self-correct if 'jupyter' in json_blob['runtype']: return # remove jupyter configs from portal_config if not a jupyter runtype portal_config = json_blob['env']['PORTAL_CONFIG'].split("|") filtered_config = [config_str for config_str in portal_config if 'jupyter' not in config_str.lower()] if not filtered_config: raise ValueError("Error: env variable PORTAL_CONFIG must contain at least one non-jupyter related config string if runtype is not jupyter") else: json_blob['env']['PORTAL_CONFIG'] = "|".join(filtered_config) @parser.command( argument("id", help="id of instance type to launch (returned from search offers)", type=int), argument("--template_hash", help="Create instance from template info", type=str), argument("--user", help="User to use with docker create. This breaks some images, so only use this if you are certain you need it.", type=str), argument("--disk", help="size of local disk partition in GB", type=float, default=10), argument("--image", help="docker container image to launch", type=str), argument("--login", help="docker login arguments for private repo authentication, surround with '' ", type=str), argument("--label", help="label to set on the instance", type=str), argument("--onstart", help="filename to use as onstart script", type=str), argument("--onstart-cmd", help="contents of onstart script as single argument", type=str), argument("--entrypoint", help="override entrypoint for args launch instance", type=str), argument("--ssh", help="Launch as an ssh instance type", action="store_true"), argument("--jupyter", help="Launch as a jupyter instance instead of an ssh instance", action="store_true"), argument("--direct", help="Use (faster) direct connections for jupyter & ssh", action="store_true"), argument("--jupyter-dir", help="For runtype 'jupyter', directory in instance to use to launch jupyter. 
Defaults to image's working directory", type=str), argument("--jupyter-lab", help="For runtype 'jupyter', Launch instance with jupyter lab", action="store_true"), argument("--lang-utf8", help="Workaround for images with locale problems: install and generate locales before instance launch, and set locale to C.UTF-8", action="store_true"), argument("--python-utf8", help="Workaround for images with locale problems: set python's locale to C.UTF-8", action="store_true"), argument("--extra", help=argparse.SUPPRESS), argument("--env", help="env variables and port mapping options, surround with '' ", type=str), argument("--args", nargs=argparse.REMAINDER, help="list of arguments passed to container ENTRYPOINT. Onstart is recommended for this purpose. (must be last argument)"), #argument("--create-from", help="Existing instance id to use as basis for new instance. Instance configuration should usually be identical, as only the difference from the base image is copied.", type=str), argument("--force", help="Skip sanity checks when creating from an existing instance", action="store_true"), argument("--cancel-unavail", help="Return error if scheduling fails (rather than creating a stopped instance)", action="store_true"), argument("--bid_price", help="(OPTIONAL) create an INTERRUPTIBLE instance with per machine bid price in $/hour", type=float), argument("--create-volume", metavar="VOLUME_ASK_ID", help="Create a new local volume using an ID returned from the \"search volumes\" command and link it to the new instance", type=int), argument("--link-volume", metavar="EXISTING_VOLUME_ID", help="ID of an existing rented volume to link to the instance during creation. (returned from \"show volumes\" cmd)", type=int), argument("--volume-size", help="Size of the volume to create in GB. Only usable with --create-volume (default 15GB)", type=int), argument("--mount-path", help="The path to the volume from within the new instance container. e.g. /root/volume", type=str), argument("--volume-label", help="(optional) A name to give the new volume. Only usable with --create-volume", type=str), usage="vastai create instance ID [OPTIONS] [--args ...]", help="Create a new instance", epilog=deindent(""" Performs the same action as pressing the "RENT" button on the website at https://console.vast.ai/create/ Creates an instance from an offer ID (which is returned from "search offers"). Each offer ID can only be used to create one instance. Besides the offer ID, you must pass in an '--image' argument as a minimum. If you use args/entrypoint launch mode, we create a container from your image as is, without attempting to inject ssh and or jupyter. If you use the args launch mode, you can override the entrypoint with --entrypoint, and pass arguments to the entrypoint with --args. If you use --args, that must be the last argument, as any following tokens are consumed into the args string. For ssh/jupyter launch types, use --onstart-cmd to pass in startup script, instead of --entrypoint and --args. 
Examples: # create an on-demand instance with the PyTorch (cuDNN Devel) template and 64GB of disk vastai create instance 384826 --template_hash 661d064bbda1f2a133816b6d55da07c3 --disk 64 # create an on-demand instance with the pytorch/pytorch image, 40GB of disk, open 8081 udp, direct ssh, set hostname to billybob, and a small onstart script vastai create instance 6995713 --image pytorch/pytorch --disk 40 --env '-p 8081:8081/udp -h billybob' --ssh --direct --onstart-cmd "env | grep _ >> /etc/environment; echo 'starting up'"; # create an on-demand instance with the bobsrepo/pytorch:latest image, 20GB of disk, open 22, 8080, jupyter ssh, and set some env variables vastai create instance 384827 --image bobsrepo/pytorch:latest --login '-u bob -p 9d8df!fd89ufZ docker.io' --jupyter --direct --env '-e TZ=PDT -e XNAME=XX4 -p 22:22 -p 8080:8080' --disk 20 # create an on-demand instance with the pytorch/pytorch image, 40GB of disk, override the entrypoint to bash and pass bash a simple command to keep the instance running. (args launch without ssh/jupyter) vastai create instance 5801802 --image pytorch/pytorch --disk 40 --onstart-cmd 'bash' --args -c 'echo hello; sleep infinity;' # create an interruptible (spot) instance with the PyTorch (cuDNN Devel) template, 64GB of disk, and a bid price of $0.10/hr vastai create instance 384826 --template_hash 661d064bbda1f2a133816b6d55da07c3 --disk 64 --bid_price 0.1 Return value: Returns a json reporting the instance ID of the newly created instance: {'success': True, 'new_contract': 7835610} """), ) def create__instance(args: argparse.Namespace): """Performs the same action as pressing the "RENT" button on the website at https://console.vast.ai/create/. :param argparse.Namespace args: Namespace with many fields relevant to the endpoint. """ if args.onstart: with open(args.onstart, "r") as reader: args.onstart_cmd = reader.read() if args.onstart_cmd is None: args.onstart_cmd = args.entrypoint runtype = None json_blob ={ "client_id": "me", "image": args.image, "env" : parse_env(args.env), "price": args.bid_price, "disk": args.disk, "label": args.label, "extra": args.extra, "onstart": args.onstart_cmd, "image_login": args.login, "python_utf8": args.python_utf8, "lang_utf8": args.lang_utf8, "use_jupyter_lab": args.jupyter_lab, "jupyter_dir": args.jupyter_dir, #"create_from": args.create_from, "force": args.force, "cancel_unavail": args.cancel_unavail, "template_hash_id" : args.template_hash, "user": args.user } if args.create_volume or args.link_volume: volume_info = validate_volume_params(args) json_blob["volume_info"] = volume_info if args.template_hash is None: runtype = get_runtype(args) if runtype == 1: return 1 json_blob["runtype"] = runtype if (args.args != None): json_blob["args"] = args.args if "PORTAL_CONFIG" in json_blob["env"]: validate_portal_config(json_blob) #print(f"put asks/{args.id}/ runtype:{runtype}") url = apiurl(args, "/asks/{id}/".format(id=args.id)) if (args.explain): print("request json: ") print(json_blob) r = http_put(args, url, headers=headers,json=json_blob) r.raise_for_status() if args.raw: return r else: print("Started. 
{}".format(r.json())) @parser.command( argument("--email", help="email address to use for login", type=str), argument("--username", help="username to use for login", type=str), argument("--password", help="password to use for login", type=str), argument("--type", help="host/client", type=str), usage="vastai create subaccount --email EMAIL --username USERNAME --password PASSWORD --type TYPE", help="Create a subaccount", epilog=deindent(""" Creates a new account that is considered a child of your current account as defined via the API key. vastai create subaccount --email bob@gmail.com --username bob --password password --type host vastai create subaccount --email vast@gmail.com --username vast --password password --type host """), ) def create__subaccount(args): """Creates a new account that is considered a child of your current account as defined via the API key. """ # Default value for host_only, can adjust based on expected default behavior host_only = False # Only process the --account_type argument if it's provided if args.type: host_only = args.type.lower() == "host" json_blob = { "email": args.email, "username": args.username, "password": args.password, "host_only": host_only, "parent_id": "me" } # Use --explain to print the request JSON and return early if getattr(args, 'explain', False): print("Request JSON would be: ") print(json_blob) return # Prevents execution of the actual API call # API call execution continues here if --explain is not used url = apiurl(args, "/users/") r = http_post(args, url, headers=headers, json=json_blob) r.raise_for_status() if r.status_code == 200: rj = r.json() print(rj) else: print(r.text) print(f"Failed with error {r.status_code}") @parser.command( argument("--team_name", help="name of the team", type=str), usage="vastai create-team --team_name TEAM_NAME", help="Create a new team", epilog=deindent(""" Creates a new team under your account. Unlike legacy teams, this command does NOT convert your personal account into a team. Each team is created as a separate account, and you can be a member of multiple teams. When you create a team: - You become the team owner. - The team starts as an independent account with its own billing, credits, and resources. - Default roles (owner, manager, member) are automatically created. - You can invite others, assign roles, and manage resources within the team. Optional: You can transfer a portion of your existing personal credits to the team by using the `--transfer_credit` flag. Example: vastai create-team --team_name myteam --transfer_credit 25 Notes: - You cannot create a team from within another team account. For more details, see: https://vast.ai/docs/teams-quickstart """) ) def create__team(args): url = apiurl(args, "/team/") r = http_post(args, url, headers=headers, json={"team_name": args.team_name}) r.raise_for_status() print(r.json()) @parser.command( argument("--name", help="name of the role", type=str), argument("--permissions", help="file path for json encoded permissions, look in the docs for more information", type=str), usage="vastai create team-role --name NAME --permissions PERMISSIONS", help="Add a new role to your team", epilog=deindent(""" Creating a new team role involves understanding how permissions must be sent via json format. 
You can find more information about permissions here: https://vast.ai/docs/cli/roles-and-permissions """) ) def create__team_role(args): url = apiurl(args, "/team/roles/") permissions = load_permissions_from_file(args.permissions) r = http_post(args, url, headers=headers, json={"name": args.name, "permissions": permissions}) r.raise_for_status() print(r.json()) def get_template_arguments(): return [ argument("--name", help="name of the template", type=str), argument("--image", help="docker container image to launch", type=str), argument("--image_tag", help="docker image tag (can also be appended to end of image_path)", type=str), argument("--href", help="link you want to provide", type=str), argument("--repo", help="link to repository", type=str), argument("--login", help="docker login arguments for private repo authentication, surround with ''", type=str), argument("--env", help="Contents of the 'Docker options' field", type=str), argument("--ssh", help="Launch as an ssh instance type", action="store_true"), argument("--jupyter", help="Launch as a jupyter instance instead of an ssh instance", action="store_true"), argument("--direct", help="Use (faster) direct connections for jupyter & ssh", action="store_true"), argument("--jupyter-dir", help="For runtype 'jupyter', directory in instance to use to launch jupyter. Defaults to image's working directory", type=str), argument("--jupyter-lab", help="For runtype 'jupyter', Launch instance with jupyter lab", action="store_true"), argument("--onstart-cmd", help="contents of onstart script as single argument", type=str), argument("--search_params", help="search offers filters", type=str), argument("-n", "--no-default", action="store_true", help="Disable default search param query args"), argument("--disk_space", help="disk storage space, in GB", type=str), argument("--readme", help="readme string", type=str), argument("--hide-readme", help="hide the readme from users", action="store_true"), argument("--desc", help="description string", type=str), argument("--public", help="make template available to public", action="store_true"), ] @parser.command( *get_template_arguments(), usage="vastai create template", help="Create a new template", epilog=deindent(""" Create a template that can be used to create instances with Example: vastai create template --name "tgi-llama2-7B-quantized" --image "ghcr.io/huggingface/text-generation-inference:1.0.3" --env "-p 3000:3000 -e MODEL_ARGS='--model-id TheBloke/Llama-2-7B-chat-GPTQ --quantize gptq'" --onstart-cmd 'wget -O - https://raw.githubusercontent.com/vast-ai/vast-pyworker/main/scripts/launch_tgi.sh | bash' --search_params "gpu_ram>=23 num_gpus=1 gpu_name=RTX_3090 inet_down>128 direct_port_count>3 disk_space>=192 driver_version>=535086005 rented=False" --disk_space 8.0 --ssh --direct """) ) def create__template(args): # url = apiurl(args, f"/users/0/templates/") url = apiurl(args, f"/template/") jup_direct = args.jupyter and args.direct ssh_direct = args.ssh and args.direct use_ssh = args.ssh or args.jupyter runtype = "jupyter" if args.jupyter else ("ssh" if args.ssh else "args") if args.login: login = args.login.split(" ") docker_login_repo = login[0] else: docker_login_repo = None default_search_query = {} if not args.no_default: default_search_query = {"verified": {"eq": True}, "external": {"eq": False}, "rentable": {"eq": True}, "rented": {"eq": False}} extra_filters = parse_query(args.search_params, default_search_query, offers_fields, offers_alias, offers_mult) template = { "name" : args.name, "image" : 
args.image, "tag" : args.image_tag, "href": args.href, "repo" : args.repo, "env" : args.env, #str format "onstart" : args.onstart_cmd, #don't accept file name for now "jup_direct" : jup_direct, "ssh_direct" : ssh_direct, "use_jupyter_lab" : args.jupyter_lab, "runtype" : runtype, "use_ssh" : use_ssh, "jupyter_dir" : args.jupyter_dir, "docker_login_repo" : docker_login_repo, #can't store username/password with template for now "extra_filters" : extra_filters, "recommended_disk_space" : args.disk_space, "readme": args.readme, "readme_visible": not args.hide_readme, "desc": args.desc, "private": not args.public, } if (args.explain): print("request json: ") print(template) r = http_post(args, url, headers=headers, json=template) r.raise_for_status() try: rj = r.json() if rj["success"]: print(f"New Template: {rj['template']}") else: print(rj['msg']) except requests.exceptions.JSONDecodeError: print("The response is not valid JSON.") @parser.command( argument("id", help="id of volume offer", type=int), argument("-s", "--size", help="size in GB of volume. Default %(default)s GB.", default=15, type=float), argument("-n", "--name", help="Optional name of volume.", type=str), usage="vastai create volume ID [options]", help="Create a new volume", epilog=deindent(""" Creates a volume from an offer ID (which is returned from "search volumes"). Each offer ID can be used to create multiple volumes, provided the size of all volumes does not exceed the size of the offer. """) ) def create__volume(args: argparse.Namespace): json_blob ={ "size": int(args.size), "id": int(args.id) } if args.name: json_blob["name"] = args.name url = apiurl(args, "/volumes/") if (args.explain): print("request json: ") print(json_blob) r = http_put(args, url, headers=headers,json=json_blob) r.raise_for_status() if args.raw: return r else: print("Created. {}".format(r.json())) @parser.command( argument("id", help="id of network volume offer", type=int), argument("-s", "--size", help="size in GB of network volume. Default %(default)s GB.", default=15, type=float), argument("-n", "--name", help="Optional name of network volume.", type=str), usage="vastai create network volume ID [options]", help="Create a new network volume", epilog=deindent(""" Creates a network volume from an offer ID (which is returned from "search network volumes"). Each offer ID can be used to create multiple volumes, provided the size of all volumes does not exceed the size of the offer. """) ) def create__network_volume(args: argparse.Namespace): json_blob ={ "size": int(args.size), "id": int(args.id) } if args.name: json_blob["name"] = args.name url = apiurl(args, "/network_volumes/") if (args.explain): print("request json: ") print(json_blob) r = http_put(args, url, headers=headers,json=json_blob) r.raise_for_status() if args.raw: return r else: print("Created. 
{}".format(r.json())) @parser.command( argument("cluster_id", help="ID of cluster to create overlay on top of", type=int), argument("name", help="overlay network name"), usage="vastai create overlay CLUSTER_ID OVERLAY_NAME", help="Creates overlay network on top of a physical cluster", epilog=deindent(""" Creates an overlay network to allow local networking between instances on a physical cluster""") ) def create__overlay(args: argparse.Namespace): json_blob = { "cluster_id": args.cluster_id, "name": args.name } if args.explain: print("request json:", json_blob) req_url = apiurl(args, "/overlay/") r = http_post(args, req_url, json=json_blob) r.raise_for_status() if args.raw: return r print(r.json()["msg"]) @parser.command( argument("id", help="id of apikey to remove", type=int), usage="vastai delete api-key ID", help="Remove an api-key", ) def delete__api_key(args): url = apiurl(args, "/auth/apikeys/{id}/".format(id=args.id)) r = http_del(args, url, headers=headers) r.raise_for_status() print(r.json()) @parser.command( argument("id", help="id ssh key to delete", type=int), usage="vastai delete ssh-key ID", help="Remove an ssh-key", ) def delete__ssh_key(args): url = apiurl(args, "/ssh/{id}/".format(id=args.id)) r = http_del(args, url, headers=headers) r.raise_for_status() print(r.json()) @parser.command( argument("id", help="id of scheduled job to remove", type=int), usage="vastai delete scheduled-job ID", help="Delete a scheduled job", ) def delete__scheduled_job(args): url = apiurl(args, "/commands/schedule_job/{id}/".format(id=args.id)) r = http_del(args, url, headers=headers) r.raise_for_status() print(r.json()) @parser.command( argument("cluster_id", help="ID of cluster to delete", type=int), usage="vastai delete cluster CLUSTER_ID", help="Delete Cluster", epilog=deindent(""" Delete Vast Cluster""") ) def delete__cluster(args: argparse.Namespace): json_blob = { "cluster_id": args.cluster_id } if args.explain: print("request json:", json_blob) req_url = apiurl(args, "/cluster/") r = http_del(args, req_url, json=json_blob) r.raise_for_status() if args.raw: return r print(r.json()["msg"]) @parser.command( argument("id", help="id of group to delete", type=int), usage="vastai delete workergroup ID ", help="Delete a workergroup group", epilog=deindent(""" Note that deleting a workergroup doesn't automatically destroy all the instances that are associated with your workergroup. Example: vastai delete workergroup 4242 """), ) def delete__workergroup(args): id = args.id url = apiurl(args, f"/autojobs/{id}/" ) json_blob = {"client_id": "me", "autojob_id": args.id} if (args.explain): print("request json: ") print(json_blob) r = http_del(args, url, headers=headers,json=json_blob) r.raise_for_status() if 'application/json' in r.headers.get('Content-Type', ''): try: print("workergroup delete {}".format(r.json())) except requests.exceptions.JSONDecodeError: print("The response is not valid JSON.") print(r) print(r.text) # Print the raw response to help with debugging. else: print("The response is not JSON. 
Content-Type:", r.headers.get('Content-Type')) print(r.text) @parser.command( argument("id", help="id of endpoint group to delete", type=int), usage="vastai delete endpoint ID ", help="Delete an endpoint group", epilog=deindent(""" Example: vastai delete endpoint 4242 """), ) def delete__endpoint(args): id = args.id url = apiurl(args, f"/endptjobs/{id}/" ) json_blob = {"client_id": "me", "endptjob_id": args.id} if (args.explain): print("request json: ") print(json_blob) r = http_del(args, url, headers=headers,json=json_blob) r.raise_for_status() if 'application/json' in r.headers.get('Content-Type', ''): try: print("delete endpoint {}".format(r.json())) except requests.exceptions.JSONDecodeError: print("The response is not valid JSON.") print(r) print(r.text) # Print the raw response to help with debugging. else: print("The response is not JSON. Content-Type:", r.headers.get('Content-Type')) print(r.text) @parser.command( argument("name", help="Environment variable name to delete", type=str), usage="vastai delete env-var ", help="Delete a user environment variable", ) def delete__env_var(args): """Delete an environment variable for the current user.""" url = apiurl(args, "/secrets/") data = {"key": args.name} r = http_del(args, url, headers=headers, json=data) r.raise_for_status() result = r.json() if result.get("success"): print(result.get("msg", "Environment variable deleted successfully.")) else: print(f"Failed to delete environment variable: {result.get('msg', 'Unknown error')}") @parser.command( argument("overlay_identifier", help="ID (int) or name (str) of overlay to delete", nargs="?"), usage="vastai delete overlay OVERLAY_IDENTIFIER", help="Deletes overlay and removes all of its associated instances" ) def delete__overlay(args: argparse.Namespace): identifier = args.overlay_identifier try: overlay_id = int(identifier) json_blob = { "overlay_id": overlay_id } except (ValueError, TypeError): json_blob = { "overlay_name": identifier } if args.explain: print("request json:", json_blob) req_url = apiurl(args, "/overlay/") r = http_del(args, req_url, json=json_blob) r.raise_for_status() if args.raw: return r print(r.json()["msg"]) @parser.command( argument("--template-id", help="Template ID of Template to Delete", type=int), argument("--hash-id", help="Hash ID of Template to Delete", type=str), usage="vastai delete template [--template-id | --hash-id ]", help="Delete a Template", epilog=deindent(""" Note: Deleting a template only removes the user's replationship to a template. It does not get destroyed Example: vastai delete template --template-id 12345 Example: vastai delete template --hash-id 49c538d097ad6437413b83711c9f61e8 """), ) def delete__template(args): url = apiurl(args, f"/template/" ) if args.hash_id: json_blob = { "hash_id": args.hash_id } elif args.template_id: json_blob = { "template_id": args.template_id } else: print('ERROR: Must Specify either Template ID or Hash ID to delete a template') return if (args.explain): print("request json: ") print(json_blob) print(args) print(url) r = http_del(args, url, headers=headers,json=json_blob) print(r) # r.raise_for_status() if 'application/json' in r.headers.get('Content-Type', ''): try: print(r.json()['msg']) except requests.exceptions.JSONDecodeError: print("The response is not valid JSON.") print(r) print(r.text) # Print the raw response to help with debugging. else: print("The response is not JSON. 
Content-Type:", r.headers.get('Content-Type')) print(r.text) @parser.command( argument("id", help="id of volume contract", type=int), usage="vastai delete volume ID", help="Delete a volume", epilog=deindent(""" Deletes volume with the given ID. All instances using the volume must be destroyed before the volume can be deleted. """) ) def delete__volume(args: argparse.Namespace): url = apiurl(args, "/volumes/", query_args={"id": args.id}) r = http_del(args, url, headers=headers) r.raise_for_status() if args.raw: return r else: print("Deleted. {}".format(r.json())) def destroy_instance(id,args): url = apiurl(args, "/instances/{id}/".format(id=id)) r = http_del(args, url, headers=headers,json={}) r.raise_for_status() if args.raw: return r elif (r.status_code == 200): rj = r.json(); if (rj["success"]): print("destroying instance {id}.".format(**(locals()))); else: print(rj["msg"]); else: print(r.text); print("failed with error {r.status_code}".format(**locals())); @parser.command( argument("id", help="id of instance to delete", type=int), usage="vastai destroy instance id [-h] [--api-key API_KEY] [--raw]", help="Destroy an instance (irreversible, deletes data)", epilog=deindent(""" Perfoms the same action as pressing the "DESTROY" button on the website at https://console.vast.ai/instances/ Example: vastai destroy instance 4242 """), ) def destroy__instance(args): """Perfoms the same action as pressing the "DESTROY" button on the website at https://console.vast.ai/instances/. :param argparse.Namespace args: should supply all the command-line options """ destroy_instance(args.id,args) @parser.command( argument("ids", help="ids of instance to destroy", type=int, nargs='+'), usage="vastai destroy instances [--raw] ", help="Destroy a list of instances (irreversible, deletes data)", ) def destroy__instances(args): """ """ for id in args.ids: destroy_instance(id, args) @parser.command( usage="vastai destroy team", help="Destroy your team", ) def destroy__team(args): url = apiurl(args, "/team/") r = http_del(args, url, headers=headers) r.raise_for_status() print(r.json()) @parser.command( argument("instance_id", help="id of the instance", type=int), argument("ssh_key_id", help="id of the key to detach to the instance", type=str), usage="vastai detach instance_id ssh_key_id", help="Detach an ssh key from an instance", epilog=deindent(""" Example: vastai detach 99999 12345 """) ) def detach__ssh(args): url = apiurl(args, "/instances/{id}/ssh/{ssh_key_id}/".format(id=args.instance_id, ssh_key_id=args.ssh_key_id)) r = http_del(args, url, headers=headers) r.raise_for_status() print(r.json()) @parser.command( argument("id", help="id of instance to execute on", type=int), argument("COMMAND", help="bash command surrounded by single quotes", type=str), argument("--schedule", choices=["HOURLY", "DAILY", "WEEKLY"], help="try to schedule a command to run hourly, daily, or monthly. Valid values are HOURLY, DAILY, WEEKLY For ex. --schedule DAILY"), argument("--start_date", type=str, default=default_start_date(), help="Start date/time in format 'YYYY-MM-DD HH:MM:SS PM' (UTC). Default is now. (optional)"), argument("--end_date", type=str, default=default_end_date(), help="End date/time in format 'YYYY-MM-DD HH:MM:SS PM' (UTC). Default is 7 days from now. (optional)"), argument("--day", type=parse_day_cron_style, help="Day of week you want scheduled job to run on (0-6, where 0=Sunday) or \"*\". Default will be 0. For ex. 
--day 0", default=0), argument("--hour", type=parse_hour_cron_style, help="Hour of day you want scheduled job to run on (0-23) or \"*\" (UTC). Default will be 0. For ex. --hour 16", default=0), usage="vastai execute id COMMAND", help="Execute a (constrained) remote command on a machine", epilog=deindent(""" Examples: vastai execute 99999 'ls -l -o -r' vastai execute 99999 'rm -r home/delete_this.txt' vastai execute 99999 'du -d2 -h' available commands: ls List directory contents rm Remote files or directories du Summarize device usage for a set of files Return value: Returns the output of the command which was executed on the instance, if successful. May take a few seconds to retrieve the results. """), ) def execute(args): """Execute a (constrained) remote command on a machine. :param argparse.Namespace args: should supply all the command-line options """ url = apiurl(args, "/instances/command/{id}/".format(id=args.id)) json_blob={"command": args.COMMAND} if (args.explain): print("request json: ") print(json_blob) r = http_put(args, url, headers=headers,json=json_blob ) r.raise_for_status() if (args.schedule): validate_frequency_values(args.day, args.hour, args.schedule) cli_command = "execute" api_endpoint = "/api/v0/instances/command/{id}/".format(id=args.id) json_blob["instance_id"] = args.id add_scheduled_job(args, json_blob, cli_command, api_endpoint, "PUT", instance_id=args.id) return if (r.status_code == 200): rj = r.json() if (rj["success"]): for i in range(0,30): time.sleep(0.3) url = rj["result_url"] r = requests.get(url) if (r.status_code == 200): filtered_text = r.text.replace(rj["writeable_path"], ''); print(filtered_text) break else: print(rj); else: print(r.text); print("failed with error {r.status_code}".format(**locals())); @parser.command( argument("id", help="id of endpoint group to fetch logs from", type=int), argument("--level", help="log detail level (0 to 3)", type=int, default=1), argument("--tail", help="", type=int, default=None), usage="vastai get endpt-logs ID [--api-key API_KEY]", help="Fetch logs for a specific serverless endpoint group", epilog=deindent(""" Example: vastai get endpt-logs 382 """), ) def get__endpt_logs(args): #url = apiurl(args, "/endptjobs/" ) if args.url == server_url_default: args.url = None url = (args.url or "https://run.vast.ai") + "/get_endpoint_logs/" json_blob = {"id": args.id, "api_key": args.api_key} if args.tail: json_blob["tail"] = args.tail if (args.explain): print(f"{url} with request json: ") print(json_blob) r = http_post(args, url, headers=headers,json=json_blob) r.raise_for_status() levels = {0 : "info0", 1: "info1", 2: "trace", 3: "debug"} if (r.status_code == 200): rj = None try: rj = r.json() except Exception as e: print(str(e)) print(r.text) if args.raw: # sort_keys return rj or r.text else: dbg_lvl = levels[args.level] if rj and dbg_lvl: print(rj[dbg_lvl]) #print(json.dumps(rj, indent=1, sort_keys=True)) else: print(r.text) @parser.command( argument("id", help="id of endpoint group to fetch logs from", type=int), argument("--level", help="log detail level (0 to 3)", type=int, default=1), argument("--tail", help="", type=int, default=None), usage="vastai get wrkgrp-logs ID [--api-key API_KEY]", help="Fetch logs for a specific serverless worker group group", epilog=deindent(""" Example: vastai get endpt-logs 382 """), ) def get__wrkgrp_logs(args): #url = apiurl(args, "/endptjobs/" ) if args.url == server_url_default: args.url = None url = (args.url or "https://run.vast.ai") + "/get_autogroup_logs/" json_blob = {"id": 
args.id, "api_key": args.api_key} if args.tail: json_blob["tail"] = args.tail if (args.explain): print(f"{url} with request json: ") print(json_blob) r = http_post(args, url, headers=headers,json=json_blob) r.raise_for_status() levels = {0 : "info0", 1: "info1", 2: "trace", 3: "debug"} if (r.status_code == 200): rj = None try: rj = r.json() except Exception as e: print(str(e)) print(r.text) if args.raw: # sort_keys return rj or r.text else: dbg_lvl = levels[args.level] if rj and dbg_lvl: print(rj[dbg_lvl]) #print(json.dumps(rj, indent=1, sort_keys=True)) else: print(r.text) @parser.command( argument("--email", help="email of user to be invited", type=str), argument("--role", help="role of user to be invited", type=str), usage="vastai invite member --email EMAIL --role ROLE", help="Invite a team member", ) def invite__member(args): url = apiurl(args, "/team/invite/", query_args={"email": args.email, "role": args.role}) r = http_post(args, url, headers=headers) r.raise_for_status() if (r.status_code == 200): print(f"successfully invited {args.email} to your current team") else: print(r.text); print(f"failed with error {r.status_code}") @parser.command( argument("cluster_id", help="ID of cluster to add machine to", type=int), argument("machine_ids", help="machine id(s) to join cluster", type=int, nargs="+"), usage="vastai join cluster CLUSTER_ID MACHINE_IDS", help="Join Machine to Cluster", epilog=deindent(""" Join's Machine to Vast Cluster """) ) def join__cluster(args: argparse.Namespace): json_blob = { "cluster_id": args.cluster_id, "machine_ids": args.machine_ids } if args.explain: print("request json:", json_blob) req_url = apiurl(args, "/cluster/") r = http_put(args, req_url, json=json_blob) r.raise_for_status() if args.raw: return r print(r.json()["msg"]) @parser.command( argument("name", help="Overlay network name to join instance to.", type=str), argument("instance_id", help="Instance ID to add to overlay.", type=int), usage="vastai join overlay OVERLAY_NAME INSTANCE_ID", help="Adds instance to an overlay network", epilog=deindent(""" Adds an instance to a compatible overlay network.""") ) def join__overlay(args: argparse.Namespace): json_blob = { "name": args.name, "instance_id": args.instance_id } if args.explain: print("request json:", json_blob) req_url = apiurl(args, "/overlay/") r = http_put(args, req_url, json=json_blob) r.raise_for_status() if args.raw: return r print(r.json()["msg"]) @parser.command( argument("id", help="id of instance to label", type=int), argument("label", help="label to set", type=str), usage="vastai label instance