Skip to content

vllm.tool_parsers.minimax_m2_tool_parser

MinimaxM2ToolParser

Bases: ToolParser

Source code in vllm/tool_parsers/minimax_m2_tool_parser.py
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
class MinimaxM2ToolParser(ToolParser):
    """Tool-call parser for MiniMax M2 formatted model output.

    The expected wire format, as encoded by the regular expressions below, is::

        <minimax:tool_call>
        <invoke name="fn">
        <parameter name="arg">value</parameter>
        </invoke>
        </minimax:tool_call>

    Parameter values arrive as raw text and are coerced to JSON types using
    the ``properties`` schema of the matching tool in the request, when one
    is available. Both one-shot (``extract_tool_calls``) and streaming
    (``extract_tool_calls_streaming``) extraction are supported.
    """

    def __init__(self, tokenizer: TokenizerLike):
        """Set up sentinel tokens, regexes and streaming state.

        Args:
            tokenizer: Tokenizer forwarded to the ``ToolParser`` base class.

        Raises:
            ValueError: If ``self.model_tokenizer`` is falsy after base
                initialisation.
            RuntimeError: If either sentinel token is missing from
                ``self.vocab``.
        """
        super().__init__(tokenizer)

        self.prev_tool_call_arr: list[dict] = []

        # Sentinel tokens
        self.tool_call_start_token: str = "<minimax:tool_call>"
        self.tool_call_end_token: str = "</minimax:tool_call>"

        # Streaming state
        self.is_tool_call_started: bool = False
        # Number of complete <invoke> blocks already examined while streaming.
        self.current_tool_index: int = 0

        # Regex patterns for complete parsing. The capture group of the
        # invoke/parameter patterns starts right after ``name=`` so it still
        # contains the (possibly quoted) name, the closing ``>`` and the body.
        self.tool_call_complete_regex = re.compile(
            r"<minimax:tool_call>(.*?)</minimax:tool_call>", re.DOTALL
        )
        self.invoke_complete_regex = re.compile(
            r"<invoke name=(.*?)</invoke>", re.DOTALL
        )
        self.parameter_complete_regex = re.compile(
            r"<parameter name=(.*?)</parameter>", re.DOTALL
        )

        if not self.model_tokenizer:
            raise ValueError(
                "The model tokenizer must be passed to the ToolParser "
                "constructor during construction."
            )

        # NOTE(review): ``self.vocab`` is presumably provided by ToolParser.
        self.tool_call_start_token_id = self.vocab.get(self.tool_call_start_token)
        self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)

        if self.tool_call_start_token_id is None or self.tool_call_end_token_id is None:
            raise RuntimeError(
                "MiniMax M2 Tool parser could not locate tool call start/end "
                "tokens in the tokenizer!"
            )

        logger.debug(
            "vLLM Successfully import tool parser %s !", self.__class__.__name__
        )

    def _generate_tool_call_id(self) -> str:
        """Generate a unique tool call ID (``call_`` + 24 hex characters)."""
        return f"call_{uuid.uuid4().hex[:24]}"

    def _extract_name(self, name_str: str) -> str:
        """Strip surrounding whitespace and one pair of matching quotes.

        Args:
            name_str: Raw attribute text, e.g. ``'"get_weather"'``.

        Returns:
            The text without the enclosing single or double quotes; text
            that is not quoted is returned stripped but otherwise unchanged.
        """
        name_str = name_str.strip()
        if (name_str.startswith('"') and name_str.endswith('"')) or (
            name_str.startswith("'") and name_str.endswith("'")
        ):
            return name_str[1:-1]
        return name_str

    def _extract_types_from_schema(self, schema: Any) -> list[str]:
        """
        Extract all possible types from a JSON schema definition.
        Handles anyOf, oneOf, allOf, type arrays, and enum fields.

        Args:
            schema: The JSON schema definition for a parameter

        Returns:
            List of type strings (e.g., ["string", "integer", "null"]).
            The list is built from a set, so its order is unspecified.
            ``["string"]`` is returned when nothing can be determined.
        """
        if schema is None:
            return ["string"]

        if not isinstance(schema, dict):
            return ["string"]

        types: set[str] = set()

        # Handle direct "type" field
        if "type" in schema:
            type_value = schema["type"]
            if isinstance(type_value, str):
                types.add(type_value)
            elif isinstance(type_value, list):
                for t in type_value:
                    if isinstance(t, str):
                        types.add(t)

        # Handle enum - infer types from enum values
        if "enum" in schema and isinstance(schema["enum"], list) and schema["enum"]:
            for value in schema["enum"]:
                if value is None:
                    types.add("null")
                # bool must be tested before int: bool is a subclass of int.
                elif isinstance(value, bool):
                    types.add("boolean")
                elif isinstance(value, int):
                    types.add("integer")
                elif isinstance(value, float):
                    types.add("number")
                elif isinstance(value, str):
                    types.add("string")
                elif isinstance(value, list):
                    types.add("array")
                elif isinstance(value, dict):
                    types.add("object")

        # Handle anyOf, oneOf, allOf - recursively extract types
        for choice_field in ("anyOf", "oneOf", "allOf"):
            if choice_field in schema and isinstance(schema[choice_field], list):
                for choice in schema[choice_field]:
                    extracted = self._extract_types_from_schema(choice)
                    types.update(extracted)

        # If no types found, default to string
        if not types:
            return ["string"]

        return list(types)

    def _convert_param_value_with_types(
        self, value: str, param_types: list[str]
    ) -> Any:
        """
        Convert parameter value to the correct type based on a list of possible types.
        Candidate types are tried in a fixed priority order (not the order of
        ``param_types``) until one succeeds.

        Args:
            value: The string value to convert
            param_types: List of possible type strings

        Returns:
            The converted value. ``None`` for the literals null/none/nil
            (case-insensitive). If no listed type accepts the value, the
            result of ``json.loads(value)`` is returned, or the original
            string when it is not valid JSON.
        """
        # Check if the VALUE itself indicates null (not just if null is allowed)
        if value.lower() in ("null", "none", "nil"):
            return None

        # Normalize types
        normalized_types = [t.lower() for t in param_types]

        # Try each type in order of preference (most specific first, string as fallback)
        # Priority: integer > number > boolean > object > array > string
        type_priority = [
            "integer",
            "int",
            "number",
            "float",
            "boolean",
            "bool",
            "object",
            "array",
            "string",
            "str",
            "text",
        ]

        for param_type in type_priority:
            if param_type not in normalized_types:
                continue

            if param_type in ["string", "str", "text"]:
                return value
            elif param_type in ["integer", "int"]:
                try:
                    return int(value)
                except (ValueError, TypeError):
                    continue
            elif param_type in ["number", "float"]:
                try:
                    val = float(value)
                    # Whole-valued floats are returned as int. int() raises
                    # ValueError for NaN and OverflowError for +/-infinity
                    # (e.g. "inf" or "1e400"); both mean "not a usable
                    # number" here, so fall through to the next candidate.
                    return val if val != int(val) else int(val)
                except (ValueError, TypeError, OverflowError):
                    continue
            elif param_type in ["boolean", "bool"]:
                lower_val = value.lower().strip()
                if lower_val in ["true", "1", "yes", "on"]:
                    return True
                elif lower_val in ["false", "0", "no", "off"]:
                    return False
                continue
            elif param_type in ["object", "array"]:
                try:
                    return json.loads(value)
                except json.JSONDecodeError:
                    continue

        # Fallback: try JSON parse, then return as string
        try:
            return json.loads(value)
        except json.JSONDecodeError:
            return value

    def _get_param_types_from_config(
        self, param_name: str, param_config: dict
    ) -> list[str]:
        """
        Get parameter types from parameter configuration.
        Handles anyOf, oneOf, allOf, and direct type definitions.

        Args:
            param_name: The name of the parameter
            param_config: The properties dict from the tool schema

        Returns:
            List of type strings; ``["string"]`` when the parameter is not
            declared or its schema is not a dict.
        """
        if param_name not in param_config:
            return ["string"]

        param_schema = param_config[param_name]
        if not isinstance(param_schema, dict):
            return ["string"]

        return self._extract_types_from_schema(param_schema)

    def _parse_single_invoke(
        self, invoke_str: str, tools: list | None
    ) -> ToolCall | None:
        """Parse a single <invoke> block.

        Args:
            invoke_str: Text captured by ``invoke_complete_regex``, i.e. the
                part after ``<invoke name=`` up to (not including)
                ``</invoke>``.
            tools: Tool definitions from the request, used to look up the
                parameter schema of the invoked function; may be ``None``.

        Returns:
            A ``ToolCall`` whose arguments are a JSON-encoded object, or
            ``None`` when no function name can be found.
        """
        # Extract function name: everything up to the first ">".
        name_match = re.search(r"^([^>]+)", invoke_str)
        if not name_match:
            return None

        function_name = self._extract_name(name_match.group(1))

        # Get parameter configuration
        param_config = {}
        if tools:
            for tool in tools:
                if (
                    hasattr(tool, "function")
                    and tool.function.name == function_name
                    and hasattr(tool.function, "parameters")
                ):
                    params = tool.function.parameters
                    if isinstance(params, dict) and "properties" in params:
                        param_config = params["properties"]
                    # Only the first tool with a matching name is consulted.
                    break

        # Extract parameters
        param_dict = {}
        for match in self.parameter_complete_regex.findall(invoke_str):
            # Split "<name>>" from the parameter body.
            param_match = re.search(r"^([^>]+)>(.*)", match, re.DOTALL)
            if param_match:
                param_name = self._extract_name(param_match.group(1))
                param_value = param_match.group(2).strip()

                # Get parameter types (supports anyOf/oneOf/allOf)
                param_type = self._get_param_types_from_config(param_name, param_config)

                # Convert value
                param_dict[param_name] = self._convert_param_value_with_types(
                    param_value, param_type
                )

        return ToolCall(
            type="function",
            function=FunctionCall(
                name=function_name,
                arguments=json.dumps(param_dict, ensure_ascii=False),
            ),
        )

    def _extract_delta_tool_calls(
        self,
        current_text: str,
        request: ChatCompletionRequest | None,
    ) -> list[DeltaToolCall]:
        """Extract DeltaToolCalls from newly completed <invoke> blocks.

        Tracks progress via ``current_tool_index`` so each block is
        extracted exactly once across successive streaming calls.

        Args:
            current_text: All text generated so far.
            request: The originating request; its ``tools`` supply
                parameter schemas. May be ``None``.

        Returns:
            One ``DeltaToolCall`` per newly completed, parseable block.
        """
        complete_invokes = self.invoke_complete_regex.findall(current_text)
        delta_tool_calls: list[DeltaToolCall] = []

        while len(complete_invokes) > self.current_tool_index:
            invoke_str = complete_invokes[self.current_tool_index]
            tool_call = self._parse_single_invoke(
                invoke_str,
                request.tools if request else None,
            )
            if not tool_call:
                # Skip unparseable blocks but still advance past them.
                self.current_tool_index += 1
                continue

            args_json = tool_call.function.arguments
            idx = self.current_tool_index
            self.current_tool_index += 1

            self.prev_tool_call_arr.append(
                {
                    "name": tool_call.function.name,
                    "arguments": json.loads(args_json),
                }
            )
            # NOTE(review): ``streamed_args_for_tool`` is not created in this
            # class; presumably it is initialised by ToolParser.
            self.streamed_args_for_tool.append(args_json)
            delta_tool_calls.append(
                DeltaToolCall(
                    index=idx,
                    id=self._generate_tool_call_id(),
                    function=DeltaFunctionCall(
                        name=tool_call.function.name,
                        arguments=args_json,
                    ),
                    type="function",
                )
            )

        return delta_tool_calls

    def extract_tool_calls(
        self,
        model_output: str,
        request: ChatCompletionRequest,
    ) -> ExtractedToolCallInformation:
        """Extract tool calls from complete model output (non-streaming).

        Args:
            model_output: The full generated text.
            request: The originating request; its ``tools`` supply
                parameter schemas.

        Returns:
            Extraction result. When no tool call is found, or any error
            occurs while parsing, ``tools_called`` is ``False`` and the
            whole output is returned as content. Otherwise the content is
            the text preceding the first start token (``None`` if empty).
        """
        # Quick check
        if self.tool_call_start_token not in model_output:
            return ExtractedToolCallInformation(
                tools_called=False, tool_calls=[], content=model_output
            )

        try:
            tool_calls = []

            # Find all complete tool_call blocks
            for tool_call_match in self.tool_call_complete_regex.findall(model_output):
                # Find all invokes within this tool_call
                for invoke_match in self.invoke_complete_regex.findall(tool_call_match):
                    tool_call = self._parse_single_invoke(
                        invoke_match, request.tools if request else None
                    )
                    if tool_call:
                        tool_calls.append(tool_call)

            if not tool_calls:
                return ExtractedToolCallInformation(
                    tools_called=False, tool_calls=[], content=model_output
                )

            # Update prev_tool_call_arr. Here "arguments" holds the JSON
            # string, whereas the streaming path stores the decoded object.
            self.prev_tool_call_arr.clear()
            for tool_call in tool_calls:
                self.prev_tool_call_arr.append(
                    {
                        "name": tool_call.function.name,
                        "arguments": tool_call.function.arguments,
                    }
                )

            # Extract content before first tool call
            first_tool_idx = model_output.find(self.tool_call_start_token)
            content = model_output[:first_tool_idx] if first_tool_idx > 0 else None

            return ExtractedToolCallInformation(
                tools_called=True, tool_calls=tool_calls, content=content
            )

        except Exception:
            # Best effort: degrade to plain content rather than failing
            # the whole response.
            logger.exception("Error extracting tool calls")
            return ExtractedToolCallInformation(
                tools_called=False, tool_calls=[], content=model_output
            )

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],  # pylint: disable=unused-argument
        current_token_ids: Sequence[int],  # pylint: disable=unused-argument
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> DeltaMessage | None:
        """Extract tool calls from streaming model output.

        Uses a buffer-until-complete-invoke strategy: tokens are buffered
        until a complete ``<invoke>...</invoke>`` block is available, then
        parsed and emitted in one shot.

        Args:
            previous_text: Text generated before this step.
            current_text: All text generated so far, including this step.
            delta_text: Text produced by this step.
            previous_token_ids: Unused.
            current_token_ids: Unused.
            delta_token_ids: Token ids produced by this step.
            request: The originating request.

        Returns:
            A ``DeltaMessage`` carrying plain content and/or completed tool
            calls, an empty-content message for a text-less non-end token
            once tool calls have been emitted, or ``None`` when there is
            nothing to send for this step.
        """

        start_in_text = self.tool_call_start_token in delta_text
        start_in_ids = self.tool_call_start_token_id in delta_token_ids
        tool_call_starting = start_in_text or start_in_ids
        # Reset state on new request (parser is reused) or new tool-call block.
        # NOTE(review): the reset sets current_tool_index to 0 while
        # _extract_delta_tool_calls scans all of current_text, so a second
        # start token in one response looks like it would re-emit earlier
        # invokes - confirm whether multiple blocks per response can occur.
        if not previous_text or tool_call_starting:
            self.current_tool_index = 0
            self.prev_tool_call_arr.clear()
            self.streamed_args_for_tool.clear()
            self.is_tool_call_started = tool_call_starting

        # Pass through content before any tool call.
        if not self.is_tool_call_started:
            return DeltaMessage(content=delta_text) if delta_text else None

        # Capture content before the start token.
        content_before = None
        if start_in_text:
            before = delta_text[: delta_text.index(self.tool_call_start_token)]
            content_before = before or None

        # Extract newly completed <invoke> blocks as DeltaToolCalls.
        delta_tool_calls = self._extract_delta_tool_calls(current_text, request)

        if delta_tool_calls or content_before:
            return DeltaMessage(
                content=content_before,
                tool_calls=delta_tool_calls,
            )

        # EOS and </minimax:tool_call> both arrive as special tokens with
        # no decoded text. Return non-None for EOS so the serving framework
        # reaches the finish-reason handling path instead of skipping.
        if (
            not delta_text
            and delta_token_ids
            and self.prev_tool_call_arr
            and self.tool_call_end_token_id not in delta_token_ids
        ):
            return DeltaMessage(content="")

        return None

_convert_param_value_with_types

_convert_param_value_with_types(
    value: str, param_types: list[str]
) -> Any

Convert parameter value to the correct type based on a list of possible types. Tries each type in order until one succeeds.

Parameters:

Name Type Description Default
value str

The string value to convert

required
param_types list[str]

List of possible type strings

required

Returns:

Type Description
Any

The converted value

Source code in vllm/tool_parsers/minimax_m2_tool_parser.py
def _convert_param_value_with_types(
    self, value: str, param_types: list[str]
) -> Any:
    """
    Convert parameter value to the correct type based on a list of possible types.
    Candidate types are tried in a fixed priority order (not the order of
    ``param_types``) until one succeeds.

    Args:
        value: The string value to convert
        param_types: List of possible type strings

    Returns:
        The converted value. ``None`` for the literals null/none/nil
        (case-insensitive). If no listed type accepts the value, the result
        of ``json.loads(value)`` is returned, or the original string when it
        is not valid JSON.
    """
    # Check if the VALUE itself indicates null (not just if null is allowed)
    if value.lower() in ("null", "none", "nil"):
        return None

    # Normalize types
    normalized_types = [t.lower() for t in param_types]

    # Try each type in order of preference (most specific first, string as fallback)
    # Priority: integer > number > boolean > object > array > string
    type_priority = [
        "integer",
        "int",
        "number",
        "float",
        "boolean",
        "bool",
        "object",
        "array",
        "string",
        "str",
        "text",
    ]

    for param_type in type_priority:
        if param_type not in normalized_types:
            continue

        if param_type in ["string", "str", "text"]:
            return value
        elif param_type in ["integer", "int"]:
            try:
                return int(value)
            except (ValueError, TypeError):
                continue
        elif param_type in ["number", "float"]:
            try:
                val = float(value)
                # Whole-valued floats are returned as int. int() raises
                # ValueError for NaN and OverflowError for +/-infinity
                # (e.g. "inf" or "1e400"); both mean "not a usable number"
                # here, so fall through to the next candidate type.
                return val if val != int(val) else int(val)
            except (ValueError, TypeError, OverflowError):
                continue
        elif param_type in ["boolean", "bool"]:
            lower_val = value.lower().strip()
            if lower_val in ["true", "1", "yes", "on"]:
                return True
            elif lower_val in ["false", "0", "no", "off"]:
                return False
            continue
        elif param_type in ["object", "array"]:
            try:
                return json.loads(value)
            except json.JSONDecodeError:
                continue

    # Fallback: try JSON parse, then return as string
    try:
        return json.loads(value)
    except json.JSONDecodeError:
        return value

_extract_delta_tool_calls

_extract_delta_tool_calls(
    current_text: str, request: ChatCompletionRequest | None
) -> list[DeltaToolCall]

Extract DeltaToolCalls from newly completed `<invoke>` blocks.

Tracks progress via current_tool_index so each block is extracted exactly once across successive streaming calls.

Source code in vllm/tool_parsers/minimax_m2_tool_parser.py
def _extract_delta_tool_calls(
    self,
    current_text: str,
    request: ChatCompletionRequest | None,
) -> list[DeltaToolCall]:
    """Turn not-yet-reported complete <invoke> blocks into DeltaToolCalls.

    ``self.current_tool_index`` is the count of complete blocks already
    examined; it is advanced here so that repeated calls over a growing
    ``current_text`` report every block only once.

    Args:
        current_text: All text generated so far.
        request: The originating request, whose ``tools`` provide the
            parameter schemas; may be ``None``.

    Returns:
        One ``DeltaToolCall`` for each newly completed block that parses.
    """
    finished_blocks = self.invoke_complete_regex.findall(current_text)
    emitted: list[DeltaToolCall] = []

    for position in range(self.current_tool_index, len(finished_blocks)):
        parsed = self._parse_single_invoke(
            finished_blocks[position],
            request.tools if request else None,
        )
        # The cursor moves past the block whether or not it parsed.
        self.current_tool_index = position + 1
        if not parsed:
            continue

        fn_name = parsed.function.name
        serialized_args = parsed.function.arguments

        # Bookkeeping: decoded arguments in one list, raw JSON in the other.
        self.prev_tool_call_arr.append(
            {"name": fn_name, "arguments": json.loads(serialized_args)}
        )
        self.streamed_args_for_tool.append(serialized_args)

        delta_function = DeltaFunctionCall(name=fn_name, arguments=serialized_args)
        emitted.append(
            DeltaToolCall(
                index=position,
                id=self._generate_tool_call_id(),
                function=delta_function,
                type="function",
            )
        )

    return emitted

_extract_name

_extract_name(name_str: str) -> str

Extract name from quoted string.

Source code in vllm/tool_parsers/minimax_m2_tool_parser.py
def _extract_name(self, name_str: str) -> str:
    """Return ``name_str`` stripped of whitespace and one pair of matching quotes.

    Text that is not wrapped in the same quote character (single or double)
    on both ends is returned stripped but otherwise unchanged.
    """
    trimmed = name_str.strip()
    for quote in ('"', "'"):
        if trimmed.startswith(quote) and trimmed.endswith(quote):
            # Drop the first and last character, i.e. the quotes.
            return trimmed[1:-1]
    return trimmed

_extract_types_from_schema

_extract_types_from_schema(schema: Any) -> list[str]

Extract all possible types from a JSON schema definition. Handles anyOf, oneOf, allOf, type arrays, and enum fields.

Parameters:

Name Type Description Default
schema Any

The JSON schema definition for a parameter

required

Returns:

Type Description
list[str]

List of type strings (e.g., ["string", "integer", "null"])

Source code in vllm/tool_parsers/minimax_m2_tool_parser.py
def _extract_types_from_schema(self, schema: Any) -> list[str]:
    """
    Collect every JSON type name a schema fragment may describe.

    Looks at the ``type`` keyword (string or list of strings), infers types
    from ``enum`` members, and recurses into ``anyOf`` / ``oneOf`` /
    ``allOf`` branches.

    Args:
        schema: The JSON schema definition for a parameter

    Returns:
        List of type strings (e.g., ["string", "integer", "null"]) in no
        particular order; ``["string"]`` when the schema is not a dict or
        yields no type information.
    """
    # Covers None as well as any other non-dict value.
    if not isinstance(schema, dict):
        return ["string"]

    found: set[str] = set()

    # Explicit "type": a single name or a list of names.
    declared = schema.get("type")
    if isinstance(declared, str):
        found.add(declared)
    elif isinstance(declared, list):
        found.update(entry for entry in declared if isinstance(entry, str))

    # "enum": map each member's Python type onto its JSON type name.
    # bool is listed ahead of int because bool is a subclass of int.
    python_to_json = (
        (bool, "boolean"),
        (int, "integer"),
        (float, "number"),
        (str, "string"),
        (list, "array"),
        (dict, "object"),
    )
    enum_members = schema.get("enum")
    if isinstance(enum_members, list):
        for member in enum_members:
            if member is None:
                found.add("null")
                continue
            label = next(
                (json_name for py_type, json_name in python_to_json
                 if isinstance(member, py_type)),
                None,
            )
            if label is not None:
                found.add(label)

    # Composition keywords: union of whatever each branch yields.
    for keyword in ("anyOf", "oneOf", "allOf"):
        branches = schema.get(keyword)
        if isinstance(branches, list):
            for branch in branches:
                found.update(self._extract_types_from_schema(branch))

    return list(found) if found else ["string"]

_generate_tool_call_id

_generate_tool_call_id() -> str

Generate a unique tool call ID.

Source code in vllm/tool_parsers/minimax_m2_tool_parser.py
def _generate_tool_call_id(self) -> str:
    """Return a fresh identifier: ``call_`` plus 24 hex chars of a random UUID."""
    suffix = uuid.uuid4().hex[:24]
    return "call_" + suffix

_get_param_types_from_config

_get_param_types_from_config(
    param_name: str, param_config: dict
) -> list[str]

Get parameter types from parameter configuration. Handles anyOf, oneOf, allOf, and direct type definitions.

Parameters:

Name Type Description Default
param_name str

The name of the parameter

required
param_config dict

The properties dict from the tool schema

required

Returns:

Type Description
list[str]

List of type strings

Source code in vllm/tool_parsers/minimax_m2_tool_parser.py
def _get_param_types_from_config(
    self, param_name: str, param_config: dict
) -> list[str]:
    """
    Look up the candidate JSON types for one parameter.

    Args:
        param_name: The name of the parameter
        param_config: The properties dict from the tool schema

    Returns:
        The types reported by ``_extract_types_from_schema`` for the
        parameter's schema, or ``["string"]`` when the parameter is not
        declared or its schema is not a dict.
    """
    declared_schema = param_config[param_name] if param_name in param_config else None
    if isinstance(declared_schema, dict):
        return self._extract_types_from_schema(declared_schema)
    # Undeclared parameter or malformed schema entry: treat as plain text.
    return ["string"]

_parse_single_invoke

_parse_single_invoke(
    invoke_str: str, tools: list | None
) -> ToolCall | None

Parse a single `<invoke>` block.

Source code in vllm/tool_parsers/minimax_m2_tool_parser.py
def _parse_single_invoke(
    self, invoke_str: str, tools: list | None
) -> ToolCall | None:
    """Build a ``ToolCall`` from the captured text of one <invoke> block.

    Args:
        invoke_str: The text following ``<invoke name=`` up to, but not
            including, ``</invoke>``.
        tools: Tool definitions used to find the invoked function's
            parameter schema; may be ``None`` or empty.

    Returns:
        A ``ToolCall`` with JSON-encoded arguments, or ``None`` if no
        function name precedes the first ``>``.
    """
    # The function name is everything before the first ">".
    header = re.search(r"^([^>]+)", invoke_str)
    if not header:
        return None
    function_name = self._extract_name(header.group(1))

    # Find the first tool that carries this name and a parameters attribute;
    # only that tool is consulted for a "properties" schema.
    matching_tool = next(
        (
            candidate
            for candidate in (tools or ())
            if hasattr(candidate, "function")
            and candidate.function.name == function_name
            and hasattr(candidate.function, "parameters")
        ),
        None,
    )
    properties = {}
    if matching_tool is not None:
        declared = matching_tool.function.parameters
        if isinstance(declared, dict) and "properties" in declared:
            properties = declared["properties"]

    # Convert each <parameter> element using its declared types.
    arguments = {}
    for raw_param in self.parameter_complete_regex.findall(invoke_str):
        pieces = re.search(r"^([^>]+)>(.*)", raw_param, re.DOTALL)
        if not pieces:
            continue
        key = self._extract_name(pieces.group(1))
        raw_value = pieces.group(2).strip()
        candidate_types = self._get_param_types_from_config(key, properties)
        arguments[key] = self._convert_param_value_with_types(
            raw_value, candidate_types
        )

    encoded_arguments = json.dumps(arguments, ensure_ascii=False)
    return ToolCall(
        type="function",
        function=FunctionCall(name=function_name, arguments=encoded_arguments),
    )

extract_tool_calls

extract_tool_calls(
    model_output: str, request: ChatCompletionRequest
) -> ExtractedToolCallInformation

Extract tool calls from complete model output (non-streaming).

Source code in vllm/tool_parsers/minimax_m2_tool_parser.py
def extract_tool_calls(
    self,
    model_output: str,
    request: ChatCompletionRequest,
) -> ExtractedToolCallInformation:
    """Parse every tool call out of a finished (non-streaming) model output.

    Args:
        model_output: The full generated text.
        request: The originating request, whose ``tools`` provide the
            parameter schemas.

    Returns:
        With at least one parsed call: ``tools_called=True``, the calls,
        and the text before the first start token as content (``None`` if
        that text is empty). Otherwise, including when parsing raises,
        ``tools_called=False`` with the whole output as content.
    """
    start_token = self.tool_call_start_token

    # Cheap exit when the output cannot contain a tool call.
    if start_token not in model_output:
        return ExtractedToolCallInformation(
            tools_called=False, tool_calls=[], content=model_output
        )

    try:
        parsed_calls = []
        for block_match in self.tool_call_complete_regex.finditer(model_output):
            block_body = block_match.group(1)
            for invoke_match in self.invoke_complete_regex.finditer(block_body):
                parsed = self._parse_single_invoke(
                    invoke_match.group(1), request.tools if request else None
                )
                if parsed:
                    parsed_calls.append(parsed)

        if not parsed_calls:
            return ExtractedToolCallInformation(
                tools_called=False, tool_calls=[], content=model_output
            )

        # Replace the remembered calls in place (arguments stay JSON text).
        self.prev_tool_call_arr.clear()
        self.prev_tool_call_arr.extend(
            {"name": call.function.name, "arguments": call.function.arguments}
            for call in parsed_calls
        )

        # Text ahead of the first start token; an empty prefix becomes None.
        leading_text, _, _ = model_output.partition(start_token)
        return ExtractedToolCallInformation(
            tools_called=True,
            tool_calls=parsed_calls,
            content=leading_text or None,
        )

    except Exception:
        # Best effort: log and hand the raw output back as content.
        logger.exception("Error extracting tool calls")
        return ExtractedToolCallInformation(
            tools_called=False, tool_calls=[], content=model_output
        )

extract_tool_calls_streaming

extract_tool_calls_streaming(
    previous_text: str,
    current_text: str,
    delta_text: str,
    previous_token_ids: Sequence[int],
    current_token_ids: Sequence[int],
    delta_token_ids: Sequence[int],
    request: ChatCompletionRequest,
) -> DeltaMessage | None

Extract tool calls from streaming model output.

Uses a buffer-until-complete-invoke strategy: tokens are buffered until a complete <invoke>...</invoke> block is available, then parsed and emitted in one shot.

Source code in vllm/tool_parsers/minimax_m2_tool_parser.py
def extract_tool_calls_streaming(
    self,
    previous_text: str,
    current_text: str,
    delta_text: str,
    previous_token_ids: Sequence[int],  # pylint: disable=unused-argument
    current_token_ids: Sequence[int],  # pylint: disable=unused-argument
    delta_token_ids: Sequence[int],
    request: ChatCompletionRequest,
) -> DeltaMessage | None:
    """Extract tool calls from streaming model output.

    Uses a buffer-until-complete-invoke strategy: tokens are buffered
    until a complete ``<invoke>...</invoke>`` block is available, then
    parsed and emitted in one shot.

    Args:
        previous_text: Text generated before this step; an empty value is
            treated as the start of a new request.
        current_text: All text generated so far, including this step.
        delta_text: Text produced by this step.
        previous_token_ids: Unused.
        current_token_ids: Unused.
        delta_token_ids: Token ids produced by this step.
        request: The originating request, forwarded to the invoke parser.

    Returns:
        A ``DeltaMessage`` with plain content and/or completed tool calls,
        an empty-content message for a text-less step that is not the
        tool-call end token once at least one call has been recorded, or
        ``None`` when there is nothing to send for this step.
    """

    # The start sentinel may show up as decoded text or only as a token id.
    start_in_text = self.tool_call_start_token in delta_text
    start_in_ids = self.tool_call_start_token_id in delta_token_ids
    tool_call_starting = start_in_text or start_in_ids
    # Reset state on new request (parser is reused) or new tool-call block.
    if not previous_text or tool_call_starting:
        self.current_tool_index = 0
        self.prev_tool_call_arr.clear()
        self.streamed_args_for_tool.clear()
        self.is_tool_call_started = tool_call_starting

    # Pass through content before any tool call.
    if not self.is_tool_call_started:
        return DeltaMessage(content=delta_text) if delta_text else None

    # Capture content before the start token.
    content_before = None
    if start_in_text:
        before = delta_text[: delta_text.index(self.tool_call_start_token)]
        # An empty prefix is normalised to None so it is not sent as content.
        content_before = before or None

    # Extract newly completed <invoke> blocks as DeltaToolCalls.
    delta_tool_calls = self._extract_delta_tool_calls(current_text, request)

    if delta_tool_calls or content_before:
        return DeltaMessage(
            content=content_before,
            tool_calls=delta_tool_calls,
        )

    # EOS and </minimax:tool_call> both arrive as special tokens with
    # no decoded text. Return non-None for EOS so the serving framework
    # reaches the finish-reason handling path instead of skipping.
    if (
        not delta_text
        and delta_token_ids
        and self.prev_tool_call_arr
        and self.tool_call_end_token_id not in delta_token_ids
    ):
        return DeltaMessage(content="")

    # Still buffering an incomplete <invoke> block: nothing to emit yet.
    return None