import * as grpc from '@grpc/grpc-js'; import { TransportProtocolName } from '../../../core.js'; import { A2AServiceClient } from '../../../grpc/pb/a2a_services.js'; import { MessageSendParams, TaskPushNotificationConfig, TaskIdParams, ListTaskPushNotificationConfigParams, DeleteTaskPushNotificationConfigParams, TaskQueryParams, Task, AgentCard, GetTaskPushNotificationConfigParams, } from '../../../types.js'; import { A2AStreamEventData, SendMessageResult } from '../../client.js'; import { RequestOptions } from '../../multitransport-client.js'; import { Transport, TransportFactory } from '../transport.js'; import { ToProto } from '../../../types/converters/to_proto.js'; import { FromProto } from '../../../types/converters/from_proto.js'; import { AuthenticatedExtendedCardNotConfiguredError, PushNotificationNotSupportedError, TaskNotFoundError, TaskNotCancelableError, UnsupportedOperationError, } from '../../../errors.js'; type GrpcUnaryCall = ( request: TReq, metadata: grpc.Metadata, options: Partial, callback: (error: grpc.ServiceError | null, response: TRes) => void ) => grpc.ClientUnaryCall; type GrpcStreamCall = ( request: TReq, metadata?: grpc.Metadata, options?: Partial ) => grpc.ClientReadableStream; export interface GrpcTransportOptions { endpoint: string; grpcChannelCredentials?: grpc.ChannelCredentials; grpcCallOptions?: Partial; } export class GrpcTransport implements Transport { private readonly grpcCallOptions?: Partial; private readonly grpcClient: A2AServiceClient; constructor(options: GrpcTransportOptions) { this.grpcCallOptions = options.grpcCallOptions; this.grpcClient = new A2AServiceClient( options.endpoint, options.grpcChannelCredentials ?? grpc.credentials.createInsecure() ); } async getExtendedAgentCard(options?: RequestOptions): Promise { const rpcResponse = await this._sendGrpcRequest( 'getAgentCard', undefined, options, this.grpcClient.getAgentCard.bind(this.grpcClient), ToProto.getAgentCardRequest, FromProto.agentCard ); return rpcResponse; } async sendMessage( params: MessageSendParams, options?: RequestOptions ): Promise { const rpcResponse = await this._sendGrpcRequest( 'sendMessage', params, options, this.grpcClient.sendMessage.bind(this.grpcClient), ToProto.messageSendParams, FromProto.sendMessageResult ); return rpcResponse; } async *sendMessageStream( params: MessageSendParams, options?: RequestOptions ): AsyncGenerator { yield* this._sendGrpcStreamingRequest( 'sendStreamingMessage', params, options, this.grpcClient.sendStreamingMessage.bind(this.grpcClient), ToProto.messageSendParams ); } async setTaskPushNotificationConfig( params: TaskPushNotificationConfig, options?: RequestOptions ): Promise { const rpcResponse = await this._sendGrpcRequest( 'createTaskPushNotificationConfig', params, options, this.grpcClient.createTaskPushNotificationConfig.bind(this.grpcClient), ToProto.taskPushNotificationConfigCreate, FromProto.taskPushNotificationConfig ); return rpcResponse; } async getTaskPushNotificationConfig( params: GetTaskPushNotificationConfigParams, options?: RequestOptions ): Promise { const rpcResponse = await this._sendGrpcRequest( 'getTaskPushNotificationConfig', params, options, this.grpcClient.getTaskPushNotificationConfig.bind(this.grpcClient), ToProto.getTaskPushNotificationConfigParams, FromProto.taskPushNotificationConfig ); return rpcResponse; } async listTaskPushNotificationConfig( params: ListTaskPushNotificationConfigParams, options?: RequestOptions ): Promise { const rpcResponse = await this._sendGrpcRequest( 'listTaskPushNotificationConfig', params, options, this.grpcClient.listTaskPushNotificationConfig.bind(this.grpcClient), ToProto.listTaskPushNotificationConfigParams, FromProto.listTaskPushNotificationConfig ); return rpcResponse; } async deleteTaskPushNotificationConfig( params: DeleteTaskPushNotificationConfigParams, options?: RequestOptions ): Promise { await this._sendGrpcRequest( 'deleteTaskPushNotificationConfig', params, options, this.grpcClient.deleteTaskPushNotificationConfig.bind(this.grpcClient), ToProto.deleteTaskPushNotificationConfigParams, () => {} ); } async getTask(params: TaskQueryParams, options?: RequestOptions): Promise { const rpcResponse = await this._sendGrpcRequest( 'getTask', params, options, this.grpcClient.getTask.bind(this.grpcClient), ToProto.taskQueryParams, FromProto.task ); return rpcResponse; } async cancelTask(params: TaskIdParams, options?: RequestOptions): Promise { const rpcResponse = await this._sendGrpcRequest( 'cancelTask', params, options, this.grpcClient.cancelTask.bind(this.grpcClient), ToProto.cancelTaskRequest, FromProto.task ); return rpcResponse; } async *resubscribeTask( params: TaskIdParams, options?: RequestOptions ): AsyncGenerator { yield* this._sendGrpcStreamingRequest( 'taskSubscription', params, options, this.grpcClient.taskSubscription.bind(this.grpcClient), ToProto.taskIdParams ); } private async _sendGrpcRequest( method: keyof A2AServiceClient, params: TParams, options: RequestOptions | undefined, call: GrpcUnaryCall, parser: (req: TParams) => TReq, converter: (res: TRes) => TResponse ): Promise { return new Promise((resolve, reject) => { let onAbort: (() => void) | undefined; const clientCall = call( parser(params), this._buildMetadata(options), this.grpcCallOptions ?? {}, (error, response) => { if (options?.signal && onAbort) { options.signal.removeEventListener('abort', onAbort); } if (error) { return reject(GrpcTransport.mapToError(error, method)); } resolve(converter(response)); } ); if (options?.signal) { if (options.signal.aborted) { clientCall.cancel(); } else { onAbort = () => clientCall.cancel(); options.signal.addEventListener('abort', onAbort); } } }); } private async *_sendGrpcStreamingRequest( method: 'sendStreamingMessage' | 'taskSubscription', params: TParams, options: RequestOptions | undefined, call: GrpcStreamCall, parser: (req: TParams) => TReq ): AsyncGenerator { const streamResponse = call( parser(params), this._buildMetadata(options), this.grpcCallOptions ?? {} ); let onAbort: (() => void) | undefined; if (options?.signal) { if (options.signal.aborted) { streamResponse.cancel(); } else { onAbort = () => streamResponse.cancel(); options.signal.addEventListener('abort', onAbort); } } try { for await (const response of streamResponse) { yield FromProto.messageStreamResult(response); } } catch (error) { if (this.isServiceError(error)) { throw GrpcTransport.mapToError(error, method); } else { throw new Error(`GRPC error for ${String(method)}!`, { cause: error, }); } } finally { if (options?.signal && onAbort) { options.signal.removeEventListener('abort', onAbort); } streamResponse.cancel(); } } private isServiceError(error: unknown): error is grpc.ServiceError { return typeof error === 'object' && error !== null && 'code' in error; } private _buildMetadata(options: RequestOptions | undefined): grpc.Metadata { const metadata = new grpc.Metadata(); if (options?.serviceParameters) { for (const [key, value] of Object.entries(options.serviceParameters)) { metadata.set(key, value); } } return metadata; } // TODO: the logic of mapToError will be removed in v1.0.0 with the enriched error model (https://a2a-protocol.org/latest/specification/#106-error-handling) private static mapToError(error: grpc.ServiceError, method: keyof A2AServiceClient): Error { switch (error.code) { case grpc.status.NOT_FOUND: return new TaskNotFoundError(error.details); case grpc.status.FAILED_PRECONDITION: if (method === 'cancelTask') { return new TaskNotCancelableError(error.details); } if (method === 'getAgentCard') { return new AuthenticatedExtendedCardNotConfiguredError(error.details); } break; case grpc.status.UNIMPLEMENTED: if ( [ 'getTaskPushNotificationConfig', 'createTaskPushNotificationConfig', 'deleteTaskPushNotificationConfig', 'listTaskPushNotificationConfig', ].includes(method) ) { return new PushNotificationNotSupportedError(error.details); } if (['getAgentCard', 'taskSubscription'].includes(method)) { return new UnsupportedOperationError(error.details); } break; //TODO: add case for grpc.status.INVALID_ARGUMENT and grpc.status.INTERNAL (the respective a2a errors are not implemented yet) default: break; } return new Error(`GRPC error for ${String(method)}! ${error.code} ${error.details}`, { cause: error, }); } } export class GrpcTransportFactoryOptions { grpcChannelCredentials?: grpc.ChannelCredentials; grpcCallOptions?: Partial; } export class GrpcTransportFactory implements TransportFactory { public static readonly name: TransportProtocolName = 'GRPC'; constructor(private readonly options?: GrpcTransportFactoryOptions) {} get protocolName(): string { return GrpcTransportFactory.name; } async create(url: string, _agentCard: AgentCard): Promise { return new GrpcTransport({ endpoint: url, grpcChannelCredentials: this.options?.grpcChannelCredentials, grpcCallOptions: this.options?.grpcCallOptions, }); } }