# Source code for nuc.nilauth

"""
nilauth client.
"""

import logging
from enum import StrEnum
import hashlib
from datetime import datetime, timedelta, timezone
from dataclasses import dataclass
import secrets
import json
from time import sleep
from typing import Any, Dict, List
import requests
from secp256k1 import PrivateKey, PublicKey

from nuc.payer import Payer
from nuc.envelope import NucTokenEnvelope
from nuc.builder import NucTokenBuilder
from nuc.token import Command, Did, InvocationBody

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


# Default for the client's `timeout_seconds`, which is forwarded as the
# `timeout` of every HTTP request.
DEFAULT_REQUEST_TIMEOUT: float = 10
# Delays passed to `sleep` between successive payment validation attempts; the
# number of entries is also the number of attempts made.
PAYMENT_TX_RETRIES: List[int] = [1, 2, 3, 5, 10, 10, 10]
# Error code on a failed payment validation that makes the client retry
# instead of giving up.
TX_NOT_COMMITTED_ERROR_CODE: str = "TRANSACTION_NOT_COMMITTED"


[docs] class BlindModule(StrEnum): """ A Nillion blind module. """ NILAI = "nilai" """The nilai blind module""" NILDB = "nildb" """The nildb blind module"""
@dataclass
class NilauthAbout:
    """
    What a nilauth server reports about itself.
    """

    public_key: PublicKey
    """
    Public key of the server.
    """
@dataclass
class SubscriptionDetails:
    """
    Timing details of a subscription.
    """

    expires_at: datetime
    """
    Moment at which the subscription expires.
    """

    renewable_at: datetime
    """
    Moment from which the subscription can be renewed.
    """
@dataclass
class Subscription:
    """
    The subscription state reported for a public key and blind module.
    """

    subscribed: bool
    """
    True when a subscription is currently active.
    """

    details: SubscriptionDetails | None
    """
    Details of the subscription, when there are any.
    """
@dataclass
class RevokedToken:
    """
    A token that has been revoked.
    """

    token_hash: bytes
    """
    Hash of the revoked token.
    """

    revoked_at: datetime
    """
    Moment at which the token was revoked.
    """
class NilauthClient:
    """
    A class to interact with nilauth.

    Example
    -------

    .. code-block:: py3

        from secp256k1 import PrivateKey
        from nuc.nilauth import BlindModule, NilauthClient

        # Create a client to talk to nilauth at the given url.
        client = NilauthClient(base_url)

        # Create a private key.
        key = PrivateKey()

        # Request a token for it.
        token = client.request_token(key, BlindModule.NILDB)
    """

    def __init__(
        self, base_url: str, timeout_seconds: float = DEFAULT_REQUEST_TIMEOUT
    ) -> None:
        """
        Construct a new client to talk to nilauth.

        Arguments
        ---------

        base_url
            nilauth's URL.
        timeout_seconds
            The timeout to use for all requests.
        """

        # Endpoint paths are appended as "/...", so a trailing slash here would
        # produce "//" in every request URL.
        self._base_url = base_url.rstrip("/")
        self._timeout_seconds = timeout_seconds

    def request_token(self, key: PrivateKey, blind_module: BlindModule) -> str:
        """
        Request a token, issued to the public key tied to the given private key.

        Requesting tokens can only be done if a subscription has been paid for
        the blind module ahead of time.

        Arguments
        ---------

        key
            The key the token should be issued to.
        blind_module
            The blind module to get a token for.

        Returns
        -------

        The ``token`` field of the server's reply.

        Raises
        ------

        RequestException
            If the server replies with a non-2xx status.

        .. note:: The private key is only used to sign a payload to prove
            ownership and is never transmitted anywhere.
        """

        public_key = self.about().public_key
        expires_at = datetime.now(timezone.utc) + timedelta(minutes=1)
        payload = {
            "nonce": secrets.token_bytes(16).hex(),
            "target_public_key": public_key.serialize().hex(),
            "expires_at": int(expires_at.timestamp()),
            "blind_module": str(blind_module),
        }
        request = self._create_signed_request(payload, key)
        response = self._post(
            f"{self._base_url}/api/v1/nucs/create",
            request,
        )
        return response["token"]

    def pay_subscription(
        self, pubkey: PublicKey, payer: Payer, blind_module: BlindModule
    ) -> None:
        """
        Pay for a subscription for a blind module.

        The payment is made through the payer and then submitted to nilauth
        for validation. Validation is retried, following
        ``PAYMENT_TX_RETRIES``, for as long as the server answers with
        ``TX_NOT_COMMITTED_ERROR_CODE``.

        Arguments
        ---------

        pubkey
            The public key the subscription is for.
        payer
            The payer that will be used.
        blind_module
            The blind module that the subscription is for.

        Raises
        ------

        CannotRenewSubscription
            If the current subscription cannot be renewed yet.
        PaymentValidationException
            If every validation attempt failed because the transaction was
            not committed.
        RequestException
            If the server rejects a request for any other reason.
        """

        subscription = self.subscription_status(pubkey, blind_module)
        details = subscription.details
        if details and details.renewable_at > datetime.now(timezone.utc):
            raise CannotRenewSubscription(details.renewable_at)

        public_key = self.about().public_key.serialize()
        cost = self.subscription_cost(blind_module)
        payload = json.dumps(
            {
                "nonce": secrets.token_bytes(16).hex(),
                "service_public_key": public_key.hex(),
                "blind_module": str(blind_module),
            }
        ).encode("utf8")
        # The payment carries the digest of the payload; the payload itself is
        # sent to nilauth afterwards, together with the transaction hash.
        digest = hashlib.sha256(payload).digest()
        logger.info(
            "Making nilchain payment with payload=%s, digest=%s",
            payload.hex(),
            digest.hex(),
        )
        tx_hash = payer.pay(digest, amount_unil=cost)

        logger.info("Submitting payment to nilauth with tx hash %s", tx_hash)
        request = {
            "tx_hash": tx_hash,
            "payload": payload.hex(),
            "public_key": pubkey.serialize().hex(),
        }

        last_error: RequestException | None = None
        final_attempt = len(PAYMENT_TX_RETRIES) - 1
        for attempt, sleep_time in enumerate(PAYMENT_TX_RETRIES):
            try:
                self._post(
                    f"{self._base_url}/api/v1/payments/validate",
                    request,
                )
            except RequestException as e:
                if e.error_code != TX_NOT_COMMITTED_ERROR_CODE:
                    raise
                last_error = e
                # Waiting after the last attempt would only delay the failure.
                if attempt < final_attempt:
                    logger.warning(
                        "Server couldn't process payment transaction, retrying in %s",
                        sleep_time,
                    )
                    sleep(sleep_time)
            else:
                return
        # Keep the last server error attached as the cause.
        raise PaymentValidationException(tx_hash, payload) from last_error

    def subscription_status(
        self, pubkey: PublicKey, blind_module: BlindModule
    ) -> Subscription:
        """
        Get the status of a subscription to a blind module.

        Arguments
        ---------

        pubkey
            The public key for which to get the subscription information.
        blind_module
            The blind module to get the subscription status for.

        Returns
        -------

        The subscription state; ``details`` is ``None`` when the server sent
        no details.

        Raises
        ------

        RequestException
            If the server replies with a non-2xx status.
        """

        response = self._get(
            f"{self._base_url}/api/v1/subscriptions/status",
            params={
                "public_key": pubkey.serialize().hex(),
                "blind_module": str(blind_module),
            },
        )
        subscribed = response["subscribed"]
        raw_details = response["details"]
        details: SubscriptionDetails | None = None
        if raw_details:
            details = SubscriptionDetails(
                expires_at=datetime.fromtimestamp(
                    raw_details["expires_at"], timezone.utc
                ),
                renewable_at=datetime.fromtimestamp(
                    raw_details["renewable_at"], timezone.utc
                ),
            )
        return Subscription(subscribed, details)

    def about(self) -> NilauthAbout:
        """
        Get information about the nilauth server.

        Raises
        ------

        RequestException
            If the server replies with a non-2xx status.
        """

        about = self._get(f"{self._base_url}/about")
        raw_public_key = bytes.fromhex(about["public_key"])
        public_key = PublicKey(raw_public_key, raw=True)
        return NilauthAbout(public_key=public_key)

    def subscription_cost(self, blind_module: BlindModule) -> int:
        """
        Get the subscription cost in unils.

        Arguments
        ---------

        blind_module
            The blind module to get the subscription cost for.

        Raises
        ------

        RequestException
            If the server replies with a non-2xx status.
        """

        response = self._get(
            f"{self._base_url}/api/v1/payments/cost",
            params={"blind_module": str(blind_module)},
        )
        return response["cost_unils"]

    def revoke_token(
        self, auth_token: NucTokenEnvelope, token: NucTokenEnvelope, key: PrivateKey
    ) -> None:
        """
        Revoke a token.

        Arguments
        ---------

        auth_token
            The token to be used as a base for authentication.
        token
            The token to be revoked.
        key
            The private key to use to mint the token.

        Raises
        ------

        RequestException
            If the server replies with a non-2xx status.
        """

        about = self.about()
        serialized_token = token.serialize()
        auth_token.validate_signatures()
        args = {"token": serialized_token}
        invocation = (
            NucTokenBuilder.extending(auth_token)
            .body(InvocationBody(args))
            .command(Command(["nuc", "revoke"]))
            .audience(Did(about.public_key.serialize()))
            .build(key)
        )
        # The request body is empty: the revocation is carried by the
        # invocation sent as the bearer token.
        self._post(
            f"{self._base_url}/api/v1/revocations/revoke",
            {},
            headers={"Authorization": f"Bearer {invocation}"},
        )

    def lookup_revoked_tokens(self, envelope: NucTokenEnvelope) -> List[RevokedToken]:
        """
        Lookup revoked tokens that would invalidate the given token.

        The hashes of the token and of all of its proofs are sent to the
        server.

        Arguments
        ---------

        envelope
            The token envelope to do lookups for.

        Raises
        ------

        RequestException
            If the server replies with a non-2xx status.
        """

        hashes = [t.compute_hash().hex() for t in (envelope.token, *envelope.proofs)]
        response = self._post(
            f"{self._base_url}/api/v1/revocations/lookup",
            {"hashes": hashes},
        )
        revoked: List[RevokedToken] = []
        for entry in response["revoked"]:
            raw_hash = entry["token_hash"]
            # `RevokedToken.token_hash` is declared as bytes, which JSON cannot
            # carry. The request sends hashes hex-encoded, so the reply is
            # assumed to do the same.
            # NOTE(review): confirm the encoding against the nilauth API.
            token_hash = (
                bytes.fromhex(raw_hash) if isinstance(raw_hash, str) else raw_hash
            )
            revoked.append(
                RevokedToken(
                    token_hash=token_hash,
                    revoked_at=datetime.fromtimestamp(
                        entry["revoked_at"], timezone.utc
                    ),
                )
            )
        return revoked

    def _get(self, url: str, **kwargs) -> Any:
        """Send a GET request and return the decoded JSON reply."""
        response = requests.get(url, timeout=self._timeout_seconds, **kwargs)
        return self._response_as_json(response)

    def _post(self, url: str, body: Any, **kwargs) -> Any:
        """Send a POST request with a JSON body and return the decoded JSON reply."""
        response = requests.post(
            url, timeout=self._timeout_seconds, json=body, **kwargs
        )
        return self._response_as_json(response)

    @staticmethod
    def _response_as_json(response: requests.Response) -> Any:
        """
        Decode a reply, turning any non-2xx status into a `RequestException`.

        A 2xx reply whose body is not valid JSON still raises the decoder's own
        error.
        """
        succeeded = 200 <= response.status_code < 300
        try:
            body_json = response.json()
        except ValueError as e:
            # `Response.json` raises a ValueError subclass on a body that is
            # not JSON. For a failed request that must still surface as a
            # RequestException so callers have one error type to handle.
            if succeeded:
                raise
            raise RequestException(
                "server did not reply with any error messages", "UNKNOWN"
            ) from e
        if succeeded:
            return body_json

        # An error body that is not a JSON object carries no usable fields.
        fields = body_json if isinstance(body_json, dict) else {}
        message = fields.get("message")
        error_code = fields.get("error_code")
        if not message or not error_code:
            raise RequestException(
                "server did not reply with any error messages", "UNKNOWN"
            )
        raise RequestException(message, error_code)

    @staticmethod
    def _create_signed_request(payload: Any, key: PrivateKey) -> Dict[str, Any]:
        """
        Serialize the payload to JSON, sign it with the key, and return the
        hex-encoded public key, signature and payload.
        """
        encoded = json.dumps(payload).encode("utf8")
        signature = key.ecdsa_serialize_compact(key.ecdsa_sign(encoded))
        return {
            "public_key": key.pubkey.serialize().hex(),  # type: ignore
            "signature": signature.hex(),
            "payload": encoded.hex(),
        }
class RequestException(Exception):
    """
    An exception raised when a request fails.

    ``str(exc)`` is ``"<error_code>: <message>"``.
    """

    # Human readable description of the failure.
    message: str
    # Machine readable code identifying the failure.
    error_code: str

    def __init__(self, message: str, error_code: str) -> None:
        # Only the formatted text goes to the base class: passing `self` as
        # well would make `args` self-referential and `str(exc)` a tuple repr.
        super().__init__(f"{error_code}: {message}")
        self.message = message
        self.error_code = error_code
class PaymentValidationException(Exception):
    """
    An exception raised when the validation for a payment fails.
    """

    # Hash of the payment transaction that could not be validated.
    tx_hash: str
    # Payload the payment was made for.
    payload: bytes

    def __init__(self, tx_hash: str, payload: bytes) -> None:
        # Only the formatted text goes to the base class: passing `self` as
        # well would make `args` self-referential and `str(exc)` a tuple repr.
        super().__init__(
            f"failed to validate payment: tx_hash='{tx_hash}', payload='{payload.hex()}'"
        )
        self.tx_hash = tx_hash
        self.payload = payload
class CannotRenewSubscription(Exception):
    """
    An exception raised when a subscription cannot be renewed yet.
    """

    # Earliest moment at which the subscription can be renewed.
    renewable_at: datetime

    def __init__(self, renewable_at: datetime) -> None:
        # Only the formatted text goes to the base class: passing `self` as
        # well would make `args` self-referential and `str(exc)` a tuple repr.
        super().__init__(f"cannot renew before {renewable_at.isoformat()}")
        self.renewable_at = renewable_at