/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
using System.Runtime.CompilerServices;
using System.Text;
using Google.Apis.Util;
using Google.GenAI;
using Google.GenAI.Types;
namespace Microsoft.Extensions.AI;
/// <summary>Provides an <see cref="IChatClient"/> implementation based on the Google GenAI SDK.</summary>
internal sealed class GoogleGenAIChatClient : IChatClient
{
    /// <summary>A thought signature that can be used to skip thought validation when sending foreign function calls.</summary>
    /// <remarks>
    /// See https://ai.google.dev/gemini-api/docs/thought-signatures#faqs.
    /// This is more common in agentic scenarios, where a chat history is built up across multiple providers/models.
    /// </remarks>
    private static readonly byte[] s_skipThoughtValidation = Encoding.UTF8.GetBytes("skip_thought_signature_validator");

    /// <summary>The wrapped <see cref="Client"/> instance (optional).</summary>
    private readonly Client? _client;

    /// <summary>The wrapped <see cref="Models"/> instance.</summary>
    private readonly Models _models;

    /// <summary>The default model that should be used when no override is specified.</summary>
    private readonly string? _defaultModelId;

    /// <summary>Lazily-initialized metadata describing the implementation.</summary>
    private ChatClientMetadata? _metadata;

    /// <summary>Initializes a new instance wrapping a <see cref="Client"/>.</summary>
    /// <param name="client">The client whose <see cref="Client.Models"/> is used to send requests.</param>
    /// <param name="defaultModelId">The model to use when <see cref="ChatOptions.ModelId"/> is not specified.</param>
    public GoogleGenAIChatClient(Client client, string? defaultModelId)
    {
        _client = client;
        _models = client.Models;
        _defaultModelId = defaultModelId;
    }

    /// <summary>Initializes a new instance wrapping a <see cref="Models"/> instance directly.</summary>
    /// <param name="client">The models API used to send requests.</param>
    /// <param name="defaultModelId">The model to use when <see cref="ChatOptions.ModelId"/> is not specified.</param>
    public GoogleGenAIChatClient(Models client, string? defaultModelId)
    {
        _models = client;
        _defaultModelId = defaultModelId;
    }

    /// <inheritdoc />
    public async Task<ChatResponse> GetResponseAsync(IEnumerable<ChatMessage> messages, ChatOptions? options = null, CancellationToken cancellationToken = default)
    {
        Utilities.ThrowIfNull(messages, nameof(messages));

        // Create the request.
        (string? modelId, List<Content> contents, GenerateContentConfig config) = CreateRequest(messages, options);

        // Send it. NOTE(review): the underlying SDK call does not appear to accept a CancellationToken,
        // so cancellation is not observed for the non-streaming path — confirm against the SDK surface.
        GenerateContentResponse generateResult = await _models.GenerateContentAsync(modelId!, contents, config).ConfigureAwait(false);

        // Create the response.
        ChatResponse chatResponse = new(new ChatMessage(ChatRole.Assistant, new List<AIContent>()))
        {
            CreatedAt = generateResult.CreateTime is { } dt ? new DateTimeOffset(dt) : null,
            ModelId = !string.IsNullOrWhiteSpace(generateResult.ModelVersion) ? generateResult.ModelVersion : modelId,
            RawRepresentation = generateResult,
            ResponseId = generateResult.ResponseId,
        };

        // Populate the response messages.
        chatResponse.FinishReason = PopulateResponseContents(generateResult, chatResponse.Messages[0].Contents);

        // Populate usage information if there is any.
        if (generateResult.UsageMetadata is { } usageMetadata)
        {
            chatResponse.Usage = ExtractUsageDetails(usageMetadata);
        }

        // Return the response.
        return chatResponse;
    }

    /// <inheritdoc />
    public async IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseAsync(IEnumerable<ChatMessage> messages, ChatOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
    {
        Utilities.ThrowIfNull(messages, nameof(messages));

        // Create the request.
        (string? modelId, List<Content> contents, GenerateContentConfig config) = CreateRequest(messages, options);

        // Send it, and process the results.
        await foreach (GenerateContentResponse generateResult in _models.GenerateContentStreamAsync(modelId!, contents, config).WithCancellation(cancellationToken).ConfigureAwait(false))
        {
            // Create a response update for each result in the stream.
            ChatResponseUpdate responseUpdate = new(ChatRole.Assistant, new List<AIContent>())
            {
                CreatedAt = generateResult.CreateTime is { } dt ? new DateTimeOffset(dt) : null,
                ModelId = !string.IsNullOrWhiteSpace(generateResult.ModelVersion) ? generateResult.ModelVersion : modelId,
                RawRepresentation = generateResult,
                ResponseId = generateResult.ResponseId,
            };

            // Populate the response update contents.
            responseUpdate.FinishReason = PopulateResponseContents(generateResult, responseUpdate.Contents);

            // Populate usage information if there is any. Unlike the non-streaming path, usage is
            // carried as a UsageContent item inside the update's contents.
            if (generateResult.UsageMetadata is { } usageMetadata)
            {
                responseUpdate.Contents.Add(new UsageContent(ExtractUsageDetails(usageMetadata)));
            }

            // Yield the update.
            yield return responseUpdate;
        }
    }

    /// <inheritdoc />
    public object? GetService(System.Type serviceType, object? serviceKey = null)
    {
        Utilities.ThrowIfNull(serviceType, nameof(serviceType));

        if (serviceKey is null)
        {
            // If there's a request for metadata, lazily-initialize it and return it. We don't need to worry about race conditions,
            // as there's no requirement that the same instance be returned each time, and creation is idempotent.
            if (serviceType == typeof(ChatClientMetadata))
            {
                return _metadata ??= new("gcp.gen_ai", new("https://generativelanguage.googleapis.com/"), defaultModelId: _defaultModelId);
            }

            // Allow a consumer to "break glass" and access the underlying client if they need it.
            if (serviceType.IsInstanceOfType(_models))
            {
                return _models;
            }

            if (_client is not null && serviceType.IsInstanceOfType(_client))
            {
                return _client;
            }

            if (serviceType.IsInstanceOfType(this))
            {
                return this;
            }
        }

        return null;
    }

    /// <inheritdoc />
    void IDisposable.Dispose() { /* nop; this client owns no disposable resources */ }

    /// <summary>Creates the request parameters (model, contents, and config) from the supplied messages and options.</summary>
    /// <param name="messages">The chat messages to convert into request contents.</param>
    /// <param name="options">Optional chat options to map onto the request configuration.</param>
    private (string? ModelId, List<Content> Contents, GenerateContentConfig Config) CreateRequest(IEnumerable<ChatMessage> messages, ChatOptions? options)
    {
        // Create the GenerateContentConfig object. If the options contains a RawRepresentationFactory, try to use it to
        // create the request instance, allowing the caller to populate it with GenAI-specific options. Otherwise, create
        // a new instance directly. Options are applied with ??= so that values the caller already set win.
        string? model = _defaultModelId;
        List<Content> contents = new();
        GenerateContentConfig config = options?.RawRepresentationFactory?.Invoke(this) as GenerateContentConfig ?? new();
        if (options is not null)
        {
            if (options.FrequencyPenalty is { } frequencyPenalty)
            {
                config.FrequencyPenalty ??= frequencyPenalty;
            }

            if (options.Instructions is { } instructions)
            {
                ((config.SystemInstruction ??= new()).Parts ??= new()).Add(new() { Text = instructions });
            }

            if (options.MaxOutputTokens is { } maxOutputTokens)
            {
                config.MaxOutputTokens ??= maxOutputTokens;
            }

            if (!string.IsNullOrWhiteSpace(options.ModelId))
            {
                model = options.ModelId;
            }

            if (options.PresencePenalty is { } presencePenalty)
            {
                config.PresencePenalty ??= presencePenalty;
            }

            if (options.Seed is { } seed)
            {
                config.Seed ??= (int)seed;
            }

            if (options.StopSequences is { } stopSequences)
            {
                (config.StopSequences ??= new()).AddRange(stopSequences);
            }

            if (options.Temperature is { } temperature)
            {
                config.Temperature ??= temperature;
            }

            if (options.TopP is { } topP)
            {
                config.TopP ??= topP;
            }

            if (options.TopK is { } topK)
            {
                config.TopK ??= topK;
            }

            // Populate tools. Each kind of tool is added on its own, except for function declarations,
            // which are grouped into a single Tool containing all FunctionDeclarations.
            List<FunctionDeclaration>? functionDeclarations = null;
            if (options.Tools is { } tools)
            {
                foreach (var tool in tools)
                {
                    switch (tool)
                    {
                        case AIFunctionDeclaration af:
                            functionDeclarations ??= new();
                            functionDeclarations.Add(new()
                            {
                                Name = af.Name,
                                Description = af.Description ?? "",
                                ParametersJsonSchema = af.JsonSchema,
                            });
                            break;

                        case HostedCodeInterpreterTool:
                            (config.Tools ??= new()).Add(new() { CodeExecution = new() });
                            break;

                        case HostedFileSearchTool:
                            (config.Tools ??= new()).Add(new() { Retrieval = new() });
                            break;

                        case HostedWebSearchTool:
                            (config.Tools ??= new()).Add(new() { GoogleSearch = new() });
                            break;
                    }
                }
            }

            if (functionDeclarations is { Count: > 0 })
            {
                Tool functionTools = new();
                (functionTools.FunctionDeclarations ??= new()).AddRange(functionDeclarations);
                (config.Tools ??= new()).Add(functionTools);
            }

            // Transfer over the tool mode if there are any tools.
            if (options.ToolMode is { } toolMode && config.Tools?.Count > 0)
            {
                switch (toolMode)
                {
                    case NoneChatToolMode:
                        config.ToolConfig = new() { FunctionCallingConfig = new() { Mode = FunctionCallingConfigMode.NONE } };
                        break;

                    case AutoChatToolMode:
                        config.ToolConfig = new() { FunctionCallingConfig = new() { Mode = FunctionCallingConfigMode.AUTO } };
                        break;

                    case RequiredChatToolMode required:
                        config.ToolConfig = new() { FunctionCallingConfig = new() { Mode = FunctionCallingConfigMode.ANY } };
                        if (required.RequiredFunctionName is not null)
                        {
                            ((config.ToolConfig.FunctionCallingConfig ??= new()).AllowedFunctionNames ??= new()).Add(required.RequiredFunctionName);
                        }

                        break;
                }
            }

            // Set the response format if specified.
            if (options.ResponseFormat is ChatResponseFormatJson responseFormat)
            {
                config.ResponseMimeType = "application/json";
                if (responseFormat.Schema is { } schema)
                {
                    config.ResponseJsonSchema = schema;
                }
            }
        }

        // Transfer messages to request, handling system messages specially: system text becomes
        // part of the SystemInstruction rather than a Content entry.
        Dictionary<string, string>? callIdToFunctionNames = null;
        foreach (var message in messages)
        {
            if (message.Role == ChatRole.System)
            {
                string instruction = message.Text;
                if (!string.IsNullOrWhiteSpace(instruction))
                {
                    ((config.SystemInstruction ??= new()).Parts ??= new()).Add(new() { Text = instruction });
                }

                continue;
            }

            Content content = new() { Role = message.Role == ChatRole.Assistant ? "model" : "user" };
            content.Parts ??= new();
            AddPartsForAIContents(ref callIdToFunctionNames, message.Contents, content.Parts);
            contents.Add(content);
        }

        // Make sure the request contains at least one content part (the request would always fail if empty).
        if (!contents.SelectMany(c => c.Parts ?? Enumerable.Empty<Part>()).Any())
        {
            contents.Add(new() { Role = "user", Parts = new() { { new() { Text = "" } } } });
        }

        return (model, contents, config);
    }

    /// <summary>Creates <see cref="Part"/>s for the supplied <see cref="AIContent"/>s and adds them to <paramref name="parts"/>.</summary>
    /// <param name="callIdToFunctionNames">Lazily-created map from function call ID to function name, used to name function responses.</param>
    /// <param name="contents">The contents to convert.</param>
    /// <param name="parts">The destination list of parts.</param>
    private static void AddPartsForAIContents(ref Dictionary<string, string>? callIdToFunctionNames, IList<AIContent> contents, List<Part> parts)
    {
        for (int i = 0; i < contents.Count; i++)
        {
            var content = contents[i];

            // If the next content item is a text-less reasoning content carrying only protected data,
            // treat it as the thought signature for this item and consume it.
            byte[]? thoughtSignature = null;
            if (content is not TextReasoningContent { ProtectedData: not null } &&
                i + 1 < contents.Count &&
                contents[i + 1] is TextReasoningContent nextReasoning &&
                string.IsNullOrWhiteSpace(nextReasoning.Text) &&
                nextReasoning.ProtectedData is { } protectedData)
            {
                i++;
                thoughtSignature = Convert.FromBase64String(protectedData);
            }

            // Before the main switch, do any necessary state tracking. We want to do this
            // even if the AIContent includes a Part as its RawRepresentation.
            if (content is FunctionCallContent fcc)
            {
                (callIdToFunctionNames ??= new())[fcc.CallId] = fcc.Name;
                callIdToFunctionNames[""] = fcc.Name; // track last function name in case calls don't have IDs
            }

            Part? part = null;
            switch (content)
            {
                case AIContent aic when aic.RawRepresentation is Part rawPart:
                    part = rawPart;
                    break;

                case TextContent textContent:
                    part = new() { Text = textContent.Text };
                    break;

                case TextReasoningContent reasoningContent:
                    part = new()
                    {
                        Thought = true,
                        Text = !string.IsNullOrWhiteSpace(reasoningContent.Text) ? reasoningContent.Text : null,
                        ThoughtSignature = reasoningContent.ProtectedData is not null ? Convert.FromBase64String(reasoningContent.ProtectedData) : null,
                    };
                    break;

                case DataContent dataContent:
                    part = new()
                    {
                        InlineData = new()
                        {
                            MimeType = dataContent.MediaType,
                            Data = dataContent.Data.ToArray(),
                            DisplayName = dataContent.Name,
                        }
                    };
                    break;

                case UriContent uriContent:
                    part = new()
                    {
                        FileData = new()
                        {
                            FileUri = uriContent.Uri.AbsoluteUri,
                            MimeType = uriContent.MediaType,
                        }
                    };
                    break;

                case FunctionCallContent functionCallContent:
                    part = new()
                    {
                        FunctionCall = new()
                        {
                            Id = functionCallContent.CallId,
                            Name = functionCallContent.Name,
                            Args = functionCallContent.Arguments is null ? null : functionCallContent.Arguments as Dictionary<string, object?> ?? new(functionCallContent.Arguments!),
                        },

                        // Foreign function calls (built up across providers) won't have a signature; use the
                        // documented sentinel so the service skips thought-signature validation for them.
                        ThoughtSignature = thoughtSignature ?? s_skipThoughtValidation,
                    };
                    break;

                case FunctionResultContent functionResultContent:
                    FunctionResponse funcResponse = new()
                    {
                        Id = functionResultContent.CallId,
                    };

                    // Prefer the name recorded for this call ID; fall back to the last seen function name
                    // (tracked under "") for results whose calls lacked IDs.
                    if (callIdToFunctionNames?.TryGetValue(functionResultContent.CallId, out string? functionName) is true ||
                        callIdToFunctionNames?.TryGetValue("", out functionName) is true)
                    {
                        funcResponse.Name = functionName;
                    }

                    switch (functionResultContent.Result)
                    {
                        case null:
                            break;

                        case AIContent aic when ToFunctionResponsePart(aic) is { } singleContentBlob:
                            funcResponse.Parts = new() { singleContentBlob };
                            break;

                        case IEnumerable<AIContent> aiContents:
                            // Media contents become response parts; everything else is bundled into the "result" payload.
                            List<AIContent>? nonBlobContent = null;
                            foreach (var aiContent in aiContents)
                            {
                                if (ToFunctionResponsePart(aiContent) is { } contentBlob)
                                {
                                    (funcResponse.Parts ??= new()).Add(contentBlob);
                                }
                                else
                                {
                                    (nonBlobContent ??= new()).Add(aiContent);
                                }
                            }

                            if (nonBlobContent is not null)
                            {
                                funcResponse.Response = new() { ["result"] = nonBlobContent };
                            }

                            break;

                        case TextContent textContent:
                            funcResponse.Response = new() { ["result"] = textContent.Text };
                            break;

                        default:
                            funcResponse.Response = new() { ["result"] = functionResultContent.Result };
                            break;
                    }

                    part = new()
                    {
                        FunctionResponse = funcResponse,
                    };

                    // Converts an AIContent into a FunctionResponsePart when it represents supported media; otherwise null.
                    static FunctionResponsePart? ToFunctionResponsePart(AIContent content)
                    {
                        switch (content)
                        {
                            case AIContent when content.RawRepresentation is FunctionResponsePart functionResponsePart:
                                return functionResponsePart;

                            case DataContent dc when IsSupportedMediaType(dc.MediaType):
                                FunctionResponseBlob dataBlob = new()
                                {
                                    MimeType = dc.MediaType,
                                    Data = dc.Data.Span.ToArray(),
                                };
                                if (!string.IsNullOrWhiteSpace(dc.Name))
                                {
                                    dataBlob.DisplayName = dc.Name;
                                }

                                return new() { InlineData = dataBlob };

                            case UriContent uc when IsSupportedMediaType(uc.MediaType):
                                return new()
                                {
                                    FileData = new()
                                    {
                                        MimeType = uc.MediaType,
                                        FileUri = uc.Uri.AbsoluteUri,
                                    }
                                };

                            default:
                                return null;
                        }

                        // https://docs.cloud.google.com/vertex-ai/generative-ai/docs/multimodal/function-calling#mm-fr
                        static bool IsSupportedMediaType(string mediaType) =>
                            // images
                            mediaType.Equals("image/png", StringComparison.OrdinalIgnoreCase) ||
                            mediaType.Equals("image/jpeg", StringComparison.OrdinalIgnoreCase) ||
                            mediaType.Equals("image/webp", StringComparison.OrdinalIgnoreCase) ||
                            // documents
                            mediaType.Equals("application/pdf", StringComparison.OrdinalIgnoreCase) ||
                            mediaType.Equals("text/plain", StringComparison.OrdinalIgnoreCase);
                    }

                    break;
            }

            if (part is not null)
            {
                part.ThoughtSignature ??= thoughtSignature;
                parts.Add(part);
            }
        }
    }

    /// <summary>Creates <see cref="AIContent"/>s for the supplied <see cref="Part"/>s and adds them to <paramref name="contents"/>.</summary>
    /// <param name="parts">The response parts to convert.</param>
    /// <param name="contents">The destination list of contents.</param>
    private static void AddAIContentsForParts(List<Part> parts, IList<AIContent> contents)
    {
        foreach (var part in parts)
        {
            AIContent content;
            if (!string.IsNullOrEmpty(part.Text))
            {
                content = part.Thought is true ?
                    new TextReasoningContent(part.Text) :
                    new TextContent(part.Text);
            }
            else if (part.InlineData is { } inlineData)
            {
                content = new DataContent(inlineData.Data, inlineData.MimeType ?? "application/octet-stream")
                {
                    Name = inlineData.DisplayName,
                };
            }
            else if (part.FileData is { FileUri: not null } fileData)
            {
                content = new UriContent(new Uri(fileData.FileUri), fileData.MimeType ?? "application/octet-stream");
            }
            else if (part.FunctionCall is { Name: not null } functionCall)
            {
                content = new FunctionCallContent(functionCall.Id ?? "", functionCall.Name, functionCall.Args!);
            }
            else if (part.FunctionResponse is { } functionResponse)
            {
                content = new FunctionResultContent(
                    functionResponse.Id ?? "",
                    functionResponse.Response?.TryGetValue("output", out var output) is true ? output :
                    functionResponse.Response?.TryGetValue("error", out var error) is true ? error :
                    null);
            }
            else if (part.ExecutableCode is { Code: not null } executableCode)
            {
                content = new CodeInterpreterToolCallContent()
                {
                    Inputs = new List<AIContent>()
                    {
                        new DataContent(Encoding.UTF8.GetBytes(executableCode.Code), executableCode.Language switch
                        {
                            Language.PYTHON => "text/x-python",
                            _ => "text/x-source-code",
                        })
                    },
                };
            }
            else if (part.CodeExecutionResult is { Output: { } codeOutput } codeExecutionResult)
            {
                content = new CodeInterpreterToolResultContent()
                {
                    Outputs = new List<AIContent>()
                    {
                        codeExecutionResult.Outcome is Outcome.OUTCOME_OK ?
                            new TextContent(codeOutput) :
                            new ErrorContent(codeOutput) { ErrorCode = codeExecutionResult.Outcome.ToString() }
                    },
                };
            }
            else
            {
                content = new AIContent();
            }

            content.RawRepresentation = part;
            contents.Add(content);

            // Round-trip the thought signature as a trailing text-less reasoning content so that it can be
            // reattached to the preceding part when the history is sent back (see AddPartsForAIContents).
            if (part.ThoughtSignature is { } thoughtSignature)
            {
                contents.Add(new TextReasoningContent(null)
                {
                    ProtectedData = Convert.ToBase64String(thoughtSignature),
                });
            }
        }
    }

    /// <summary>Populates <paramref name="responseContents"/> from the first candidate of <paramref name="generateResult"/>.</summary>
    /// <returns>The converted finish reason, if one was reported.</returns>
    private static ChatFinishReason? PopulateResponseContents(GenerateContentResponse generateResult, IList<AIContent> responseContents)
    {
        ChatFinishReason? finishReason = null;

        // Populate the response messages. There should only be at most one candidate, but if there are more, ignore all but the first.
        if (generateResult.Candidates is { Count: > 0 } &&
            generateResult.Candidates[0] is { Content: { } candidateContent } candidate)
        {
            // Grab the finish reason if one exists.
            finishReason = ConvertFinishReason(candidate.FinishReason);

            // Add all of the response content parts as AIContents.
            if (candidateContent.Parts is { } parts)
            {
                AddAIContentsForParts(parts, responseContents);
            }

            // Add any citation metadata, attaching all citations as annotations on the first text content.
            if (candidate.CitationMetadata is { Citations: { Count: > 0 } citations } &&
                responseContents.OfType<TextContent>().FirstOrDefault() is TextContent textContent)
            {
                foreach (var citation in citations)
                {
                    // Append rather than assign so that every citation is preserved, not just the last.
                    (textContent.Annotations ??= new List<AIAnnotation>()).Add(new CitationAnnotation()
                    {
                        Title = citation.Title,
                        Url = Uri.TryCreate(citation.Uri, UriKind.Absolute, out Uri? uri) ? uri : null,
                        AnnotatedRegions = new List<AnnotatedRegion>()
                        {
                            new TextSpanAnnotatedRegion()
                            {
                                StartIndex = citation.StartIndex,
                                EndIndex = citation.EndIndex,
                            }
                        },
                    });
                }
            }
        }

        // Populate error information if there is any.
        if (generateResult.PromptFeedback is { } promptFeedback)
        {
            responseContents.Add(new ErrorContent(promptFeedback.BlockReasonMessage));
        }

        return finishReason;
    }

    /// <summary>Creates an M.E.AI <see cref="ChatFinishReason"/> from a Google <see cref="FinishReason"/>.</summary>
    private static ChatFinishReason? ConvertFinishReason(FinishReason? finishReason)
    {
        return finishReason switch
        {
            null => null,

            FinishReason.MAX_TOKENS =>
                ChatFinishReason.Length,

            FinishReason.MALFORMED_FUNCTION_CALL or
            FinishReason.UNEXPECTED_TOOL_CALL =>
                ChatFinishReason.ToolCalls,

            FinishReason.FINISH_REASON_UNSPECIFIED or
            FinishReason.STOP =>
                ChatFinishReason.Stop,

            // All remaining reasons (safety, recitation, blocklist, etc.) map to content filtering.
            _ => ChatFinishReason.ContentFilter,
        };
    }

    /// <summary>Creates a populated <see cref="UsageDetails"/> from the supplied usage metadata.</summary>
    private static UsageDetails ExtractUsageDetails(GenerateContentResponseUsageMetadata usageMetadata)
    {
        UsageDetails details = new()
        {
            InputTokenCount = usageMetadata.PromptTokenCount,
            OutputTokenCount = usageMetadata.CandidatesTokenCount,
            TotalTokenCount = usageMetadata.TotalTokenCount,
        };

        AddIfPresent(nameof(usageMetadata.CachedContentTokenCount), usageMetadata.CachedContentTokenCount);
        AddIfPresent(nameof(usageMetadata.ThoughtsTokenCount), usageMetadata.ThoughtsTokenCount);
        AddIfPresent(nameof(usageMetadata.ToolUsePromptTokenCount), usageMetadata.ToolUsePromptTokenCount);

        return details;

        // Records an additional count only when the metadata actually supplied a value.
        void AddIfPresent(string key, int? value)
        {
            if (value is int i)
            {
                (details.AdditionalCounts ??= new())[key] = i;
            }
        }
    }
}