Line data Source code
1 : //
2 : // Copyright (c) 2002-2015 The ANGLE Project Authors. All rights reserved.
3 : // Use of this source code is governed by a BSD-style license that can be
4 : // found in the LICENSE file.
5 : //
6 : // RemoveDynamicIndexing is an AST traverser to remove dynamic indexing of vectors and matrices,
7 : // replacing them with calls to functions that choose which component to return or write.
8 : //
9 :
10 : #include "compiler/translator/RemoveDynamicIndexing.h"
11 :
12 : #include "compiler/translator/InfoSink.h"
13 : #include "compiler/translator/IntermNode.h"
14 : #include "compiler/translator/IntermNodePatternMatcher.h"
15 : #include "compiler/translator/SymbolTable.h"
16 :
17 : namespace sh
18 : {
19 :
20 : namespace
21 : {
22 :
23 0 : TName GetIndexFunctionName(const TType &type, bool write)
24 : {
25 0 : TInfoSinkBase nameSink;
26 0 : nameSink << "dyn_index_";
27 0 : if (write)
28 : {
29 0 : nameSink << "write_";
30 : }
31 0 : if (type.isMatrix())
32 : {
33 0 : nameSink << "mat" << type.getCols() << "x" << type.getRows();
34 : }
35 : else
36 : {
37 0 : switch (type.getBasicType())
38 : {
39 : case EbtInt:
40 0 : nameSink << "ivec";
41 0 : break;
42 : case EbtBool:
43 0 : nameSink << "bvec";
44 0 : break;
45 : case EbtUInt:
46 0 : nameSink << "uvec";
47 0 : break;
48 : case EbtFloat:
49 0 : nameSink << "vec";
50 0 : break;
51 : default:
52 0 : UNREACHABLE();
53 : }
54 0 : nameSink << type.getNominalSize();
55 : }
56 0 : TString nameString = TFunction::mangleName(nameSink.c_str());
57 0 : TName name(nameString);
58 0 : name.setInternal(true);
59 0 : return name;
60 : }
61 :
62 0 : TIntermSymbol *CreateBaseSymbol(const TType &type, TQualifier qualifier)
63 : {
64 0 : TIntermSymbol *symbol = new TIntermSymbol(0, "base", type);
65 0 : symbol->setInternal(true);
66 0 : symbol->getTypePointer()->setQualifier(qualifier);
67 0 : return symbol;
68 : }
69 :
70 0 : TIntermSymbol *CreateIndexSymbol()
71 : {
72 0 : TIntermSymbol *symbol = new TIntermSymbol(0, "index", TType(EbtInt, EbpHigh));
73 0 : symbol->setInternal(true);
74 0 : symbol->getTypePointer()->setQualifier(EvqIn);
75 0 : return symbol;
76 : }
77 :
78 0 : TIntermSymbol *CreateValueSymbol(const TType &type)
79 : {
80 0 : TIntermSymbol *symbol = new TIntermSymbol(0, "value", type);
81 0 : symbol->setInternal(true);
82 0 : symbol->getTypePointer()->setQualifier(EvqIn);
83 0 : return symbol;
84 : }
85 :
86 0 : TIntermConstantUnion *CreateIntConstantNode(int i)
87 : {
88 0 : TConstantUnion *constant = new TConstantUnion();
89 0 : constant->setIConst(i);
90 0 : return new TIntermConstantUnion(constant, TType(EbtInt, EbpHigh));
91 : }
92 :
93 0 : TIntermBinary *CreateIndexDirectBaseSymbolNode(const TType &indexedType,
94 : const TType &fieldType,
95 : const int index,
96 : TQualifier baseQualifier)
97 : {
98 0 : TIntermSymbol *baseSymbol = CreateBaseSymbol(indexedType, baseQualifier);
99 : TIntermBinary *indexNode =
100 0 : new TIntermBinary(EOpIndexDirect, baseSymbol, TIntermTyped::CreateIndexNode(index));
101 0 : return indexNode;
102 : }
103 :
104 0 : TIntermBinary *CreateAssignValueSymbolNode(TIntermTyped *targetNode, const TType &assignedValueType)
105 : {
106 0 : return new TIntermBinary(EOpAssign, targetNode, CreateValueSymbol(assignedValueType));
107 : }
108 :
109 0 : TIntermTyped *EnsureSignedInt(TIntermTyped *node)
110 : {
111 0 : if (node->getBasicType() == EbtInt)
112 0 : return node;
113 :
114 0 : TIntermAggregate *convertedNode = new TIntermAggregate(EOpConstructInt);
115 0 : convertedNode->setType(TType(EbtInt));
116 0 : convertedNode->getSequence()->push_back(node);
117 0 : convertedNode->setPrecisionFromChildren();
118 0 : return convertedNode;
119 : }
120 :
121 0 : TType GetFieldType(const TType &indexedType)
122 : {
123 0 : if (indexedType.isMatrix())
124 : {
125 0 : TType fieldType = TType(indexedType.getBasicType(), indexedType.getPrecision());
126 0 : fieldType.setPrimarySize(static_cast<unsigned char>(indexedType.getRows()));
127 0 : return fieldType;
128 : }
129 : else
130 : {
131 0 : return TType(indexedType.getBasicType(), indexedType.getPrecision());
132 : }
133 : }
134 :
135 : // Generate a read or write function for one field in a vector/matrix.
136 : // Out-of-range indices are clamped. This is consistent with how ANGLE handles out-of-range
137 : // indices in other places.
138 : // Note that indices can be either int or uint. We create only int versions of the functions,
139 : // and convert uint indices to int at the call site.
140 : // read function example:
141 : // float dyn_index_vec2(in vec2 base, in int index)
142 : // {
143 : // switch(index)
144 : // {
145 : // case (0):
146 : // return base[0];
147 : // case (1):
148 : // return base[1];
149 : // default:
150 : // break;
151 : // }
152 : // if (index < 0)
153 : // return base[0];
154 : // return base[1];
155 : // }
156 : // write function example:
157 : // void dyn_index_write_vec2(inout vec2 base, in int index, in float value)
158 : // {
159 : // switch(index)
160 : // {
161 : // case (0):
162 : // base[0] = value;
163 : // return;
164 : // case (1):
165 : // base[1] = value;
166 : // return;
167 : // default:
168 : // break;
169 : // }
170 : // if (index < 0)
171 : // {
172 : // base[0] = value;
173 : // return;
174 : // }
175 : // base[1] = value;
176 : // }
177 : // Note that else is not used in above functions to avoid the RewriteElseBlocks transformation.
178 0 : TIntermFunctionDefinition *GetIndexFunctionDefinition(TType type, bool write)
179 : {
180 0 : ASSERT(!type.isArray());
181 : // Conservatively use highp here, even if the indexed type is not highp. That way the code can't
182 : // end up using mediump version of an indexing function for a highp value, if both mediump and
183 : // highp values are being indexed in the shader. For HLSL precision doesn't matter, but in
184 : // principle this code could be used with multiple backends.
185 0 : type.setPrecision(EbpHigh);
186 :
187 0 : TType fieldType = GetFieldType(type);
188 0 : int numCases = 0;
189 0 : if (type.isMatrix())
190 : {
191 0 : numCases = type.getCols();
192 : }
193 : else
194 : {
195 0 : numCases = type.getNominalSize();
196 : }
197 :
198 0 : TIntermAggregate *paramsNode = new TIntermAggregate(EOpParameters);
199 0 : TQualifier baseQualifier = EvqInOut;
200 0 : if (!write)
201 0 : baseQualifier = EvqIn;
202 0 : TIntermSymbol *baseParam = CreateBaseSymbol(type, baseQualifier);
203 0 : paramsNode->getSequence()->push_back(baseParam);
204 0 : TIntermSymbol *indexParam = CreateIndexSymbol();
205 0 : paramsNode->getSequence()->push_back(indexParam);
206 0 : if (write)
207 : {
208 0 : TIntermSymbol *valueParam = CreateValueSymbol(fieldType);
209 0 : paramsNode->getSequence()->push_back(valueParam);
210 : }
211 :
212 0 : TIntermBlock *statementList = new TIntermBlock();
213 0 : for (int i = 0; i < numCases; ++i)
214 : {
215 0 : TIntermCase *caseNode = new TIntermCase(CreateIntConstantNode(i));
216 0 : statementList->getSequence()->push_back(caseNode);
217 :
218 : TIntermBinary *indexNode =
219 0 : CreateIndexDirectBaseSymbolNode(type, fieldType, i, baseQualifier);
220 0 : if (write)
221 : {
222 0 : TIntermBinary *assignNode = CreateAssignValueSymbolNode(indexNode, fieldType);
223 0 : statementList->getSequence()->push_back(assignNode);
224 0 : TIntermBranch *returnNode = new TIntermBranch(EOpReturn, nullptr);
225 0 : statementList->getSequence()->push_back(returnNode);
226 : }
227 : else
228 : {
229 0 : TIntermBranch *returnNode = new TIntermBranch(EOpReturn, indexNode);
230 0 : statementList->getSequence()->push_back(returnNode);
231 : }
232 : }
233 :
234 : // Default case
235 0 : TIntermCase *defaultNode = new TIntermCase(nullptr);
236 0 : statementList->getSequence()->push_back(defaultNode);
237 0 : TIntermBranch *breakNode = new TIntermBranch(EOpBreak, nullptr);
238 0 : statementList->getSequence()->push_back(breakNode);
239 :
240 0 : TIntermSwitch *switchNode = new TIntermSwitch(CreateIndexSymbol(), statementList);
241 :
242 0 : TIntermBlock *bodyNode = new TIntermBlock();
243 0 : bodyNode->getSequence()->push_back(switchNode);
244 :
245 : TIntermBinary *cond =
246 0 : new TIntermBinary(EOpLessThan, CreateIndexSymbol(), CreateIntConstantNode(0));
247 0 : cond->setType(TType(EbtBool, EbpUndefined));
248 :
249 : // Two blocks: one accesses (either reads or writes) the first element and returns,
250 : // the other accesses the last element.
251 0 : TIntermBlock *useFirstBlock = new TIntermBlock();
252 0 : TIntermBlock *useLastBlock = new TIntermBlock();
253 : TIntermBinary *indexFirstNode =
254 0 : CreateIndexDirectBaseSymbolNode(type, fieldType, 0, baseQualifier);
255 : TIntermBinary *indexLastNode =
256 0 : CreateIndexDirectBaseSymbolNode(type, fieldType, numCases - 1, baseQualifier);
257 0 : if (write)
258 : {
259 0 : TIntermBinary *assignFirstNode = CreateAssignValueSymbolNode(indexFirstNode, fieldType);
260 0 : useFirstBlock->getSequence()->push_back(assignFirstNode);
261 0 : TIntermBranch *returnNode = new TIntermBranch(EOpReturn, nullptr);
262 0 : useFirstBlock->getSequence()->push_back(returnNode);
263 :
264 0 : TIntermBinary *assignLastNode = CreateAssignValueSymbolNode(indexLastNode, fieldType);
265 0 : useLastBlock->getSequence()->push_back(assignLastNode);
266 : }
267 : else
268 : {
269 0 : TIntermBranch *returnFirstNode = new TIntermBranch(EOpReturn, indexFirstNode);
270 0 : useFirstBlock->getSequence()->push_back(returnFirstNode);
271 :
272 0 : TIntermBranch *returnLastNode = new TIntermBranch(EOpReturn, indexLastNode);
273 0 : useLastBlock->getSequence()->push_back(returnLastNode);
274 : }
275 0 : TIntermIfElse *ifNode = new TIntermIfElse(cond, useFirstBlock, nullptr);
276 0 : bodyNode->getSequence()->push_back(ifNode);
277 0 : bodyNode->getSequence()->push_back(useLastBlock);
278 :
279 0 : TIntermFunctionDefinition *indexingFunction = nullptr;
280 0 : if (write)
281 : {
282 0 : indexingFunction = new TIntermFunctionDefinition(TType(EbtVoid), paramsNode, bodyNode);
283 : }
284 : else
285 : {
286 0 : indexingFunction = new TIntermFunctionDefinition(fieldType, paramsNode, bodyNode);
287 : }
288 0 : indexingFunction->getFunctionSymbolInfo()->setNameObj(GetIndexFunctionName(type, write));
289 0 : return indexingFunction;
290 : }
291 :
292 0 : class RemoveDynamicIndexingTraverser : public TLValueTrackingTraverser
293 : {
294 : public:
295 : RemoveDynamicIndexingTraverser(const TSymbolTable &symbolTable, int shaderVersion);
296 :
297 : bool visitBinary(Visit visit, TIntermBinary *node) override;
298 :
299 : void insertHelperDefinitions(TIntermNode *root);
300 :
301 : void nextIteration();
302 :
303 0 : bool usedTreeInsertion() const { return mUsedTreeInsertion; }
304 :
305 : protected:
306 : // Sets of types that are indexed. Note that these can not store multiple variants
307 : // of the same type with different precisions - only one precision gets stored.
308 : std::set<TType> mIndexedVecAndMatrixTypes;
309 : std::set<TType> mWrittenVecAndMatrixTypes;
310 :
311 : bool mUsedTreeInsertion;
312 :
313 : // When true, the traverser will remove side effects from any indexing expression.
314 : // This is done so that in code like
315 : // V[j++][i]++.
316 : // where V is an array of vectors, j++ will only be evaluated once.
317 : bool mRemoveIndexSideEffectsInSubtree;
318 : };
319 :
320 0 : RemoveDynamicIndexingTraverser::RemoveDynamicIndexingTraverser(const TSymbolTable &symbolTable,
321 0 : int shaderVersion)
322 : : TLValueTrackingTraverser(true, false, false, symbolTable, shaderVersion),
323 : mUsedTreeInsertion(false),
324 0 : mRemoveIndexSideEffectsInSubtree(false)
325 : {
326 0 : }
327 :
328 0 : void RemoveDynamicIndexingTraverser::insertHelperDefinitions(TIntermNode *root)
329 : {
330 0 : TIntermBlock *rootBlock = root->getAsBlock();
331 0 : ASSERT(rootBlock != nullptr);
332 0 : TIntermSequence insertions;
333 0 : for (TType type : mIndexedVecAndMatrixTypes)
334 : {
335 0 : insertions.push_back(GetIndexFunctionDefinition(type, false));
336 : }
337 0 : for (TType type : mWrittenVecAndMatrixTypes)
338 : {
339 0 : insertions.push_back(GetIndexFunctionDefinition(type, true));
340 : }
341 0 : mInsertions.push_back(NodeInsertMultipleEntry(rootBlock, 0, insertions, TIntermSequence()));
342 0 : }
343 :
344 : // Create a call to dyn_index_*() based on an indirect indexing op node
345 0 : TIntermAggregate *CreateIndexFunctionCall(TIntermBinary *node,
346 : TIntermTyped *indexedNode,
347 : TIntermTyped *index)
348 : {
349 0 : ASSERT(node->getOp() == EOpIndexIndirect);
350 0 : TIntermAggregate *indexingCall = new TIntermAggregate(EOpFunctionCall);
351 0 : indexingCall->setLine(node->getLine());
352 0 : indexingCall->setUserDefined();
353 0 : indexingCall->getFunctionSymbolInfo()->setNameObj(
354 0 : GetIndexFunctionName(indexedNode->getType(), false));
355 0 : indexingCall->getSequence()->push_back(indexedNode);
356 0 : indexingCall->getSequence()->push_back(index);
357 :
358 0 : TType fieldType = GetFieldType(indexedNode->getType());
359 0 : indexingCall->setType(fieldType);
360 0 : return indexingCall;
361 : }
362 :
363 0 : TIntermAggregate *CreateIndexedWriteFunctionCall(TIntermBinary *node,
364 : TIntermTyped *index,
365 : TIntermTyped *writtenValue)
366 : {
367 : // Deep copy the left node so that two pointers to the same node don't end up in the tree.
368 0 : TIntermNode *leftCopy = node->getLeft()->deepCopy();
369 0 : ASSERT(leftCopy != nullptr && leftCopy->getAsTyped() != nullptr);
370 : TIntermAggregate *indexedWriteCall =
371 0 : CreateIndexFunctionCall(node, leftCopy->getAsTyped(), index);
372 0 : indexedWriteCall->getFunctionSymbolInfo()->setNameObj(
373 0 : GetIndexFunctionName(node->getLeft()->getType(), true));
374 0 : indexedWriteCall->setType(TType(EbtVoid));
375 0 : indexedWriteCall->getSequence()->push_back(writtenValue);
376 0 : return indexedWriteCall;
377 : }
378 :
379 0 : bool RemoveDynamicIndexingTraverser::visitBinary(Visit visit, TIntermBinary *node)
380 : {
381 0 : if (mUsedTreeInsertion)
382 0 : return false;
383 :
384 0 : if (node->getOp() == EOpIndexIndirect)
385 : {
386 0 : if (mRemoveIndexSideEffectsInSubtree)
387 : {
388 0 : ASSERT(node->getRight()->hasSideEffects());
389 : // In case we're just removing index side effects, convert
390 : // v_expr[index_expr]
391 : // to this:
392 : // int s0 = index_expr; v_expr[s0];
393 : // Now v_expr[s0] can be safely executed several times without unintended side effects.
394 :
395 : // Init the temp variable holding the index
396 0 : TIntermDeclaration *initIndex = createTempInitDeclaration(node->getRight());
397 0 : insertStatementInParentBlock(initIndex);
398 0 : mUsedTreeInsertion = true;
399 :
400 : // Replace the index with the temp variable
401 0 : TIntermSymbol *tempIndex = createTempSymbol(node->getRight()->getType());
402 0 : queueReplacementWithParent(node, node->getRight(), tempIndex, OriginalNode::IS_DROPPED);
403 : }
404 0 : else if (IntermNodePatternMatcher::IsDynamicIndexingOfVectorOrMatrix(node))
405 : {
406 0 : bool write = isLValueRequiredHere();
407 :
408 : #if defined(ANGLE_ENABLE_ASSERTS)
409 : // Make sure that IntermNodePatternMatcher is consistent with the slightly differently
410 : // implemented checks in this traverser.
411 : IntermNodePatternMatcher matcher(
412 0 : IntermNodePatternMatcher::kDynamicIndexingOfVectorOrMatrixInLValue);
413 0 : ASSERT(matcher.match(node, getParentNode(), isLValueRequiredHere()) == write);
414 : #endif
415 :
416 0 : TType type = node->getLeft()->getType();
417 0 : mIndexedVecAndMatrixTypes.insert(type);
418 :
419 0 : if (write)
420 : {
421 : // Convert:
422 : // v_expr[index_expr]++;
423 : // to this:
424 : // int s0 = index_expr; float s1 = dyn_index(v_expr, s0); s1++;
425 : // dyn_index_write(v_expr, s0, s1);
426 : // This works even if index_expr has some side effects.
427 0 : if (node->getLeft()->hasSideEffects())
428 : {
429 : // If v_expr has side effects, those need to be removed before proceeding.
430 : // Otherwise the side effects of v_expr would be evaluated twice.
431 : // The only case where an l-value can have side effects is when it is
432 : // indexing. For example, it can be V[j++] where V is an array of vectors.
433 0 : mRemoveIndexSideEffectsInSubtree = true;
434 0 : return true;
435 : }
436 : // TODO(oetuaho@nvidia.com): This is not optimal if the expression using the value
437 : // only writes it and doesn't need the previous value. http://anglebug.com/1116
438 :
439 0 : mWrittenVecAndMatrixTypes.insert(type);
440 0 : TType fieldType = GetFieldType(type);
441 :
442 0 : TIntermSequence insertionsBefore;
443 0 : TIntermSequence insertionsAfter;
444 :
445 : // Store the index in a temporary signed int variable.
446 0 : TIntermTyped *indexInitializer = EnsureSignedInt(node->getRight());
447 0 : TIntermDeclaration *initIndex = createTempInitDeclaration(indexInitializer);
448 0 : initIndex->setLine(node->getLine());
449 0 : insertionsBefore.push_back(initIndex);
450 :
451 0 : TIntermAggregate *indexingCall = CreateIndexFunctionCall(
452 0 : node, node->getLeft(), createTempSymbol(indexInitializer->getType()));
453 :
454 : // Create a node for referring to the index after the nextTemporaryIndex() call
455 : // below.
456 0 : TIntermSymbol *tempIndex = createTempSymbol(indexInitializer->getType());
457 :
458 0 : nextTemporaryIndex(); // From now on, creating temporary symbols that refer to the
459 : // field value.
460 0 : insertionsBefore.push_back(createTempInitDeclaration(indexingCall));
461 :
462 : TIntermAggregate *indexedWriteCall =
463 0 : CreateIndexedWriteFunctionCall(node, tempIndex, createTempSymbol(fieldType));
464 0 : insertionsAfter.push_back(indexedWriteCall);
465 0 : insertStatementsInParentBlock(insertionsBefore, insertionsAfter);
466 0 : queueReplacement(node, createTempSymbol(fieldType), OriginalNode::IS_DROPPED);
467 0 : mUsedTreeInsertion = true;
468 : }
469 : else
470 : {
471 : // The indexed value is not being written, so we can simply convert
472 : // v_expr[index_expr]
473 : // into
474 : // dyn_index(v_expr, index_expr)
475 : // If the index_expr is unsigned, we'll convert it to signed.
476 0 : ASSERT(!mRemoveIndexSideEffectsInSubtree);
477 0 : TIntermAggregate *indexingCall = CreateIndexFunctionCall(
478 0 : node, node->getLeft(), EnsureSignedInt(node->getRight()));
479 0 : queueReplacement(node, indexingCall, OriginalNode::IS_DROPPED);
480 : }
481 : }
482 : }
483 0 : return !mUsedTreeInsertion;
484 : }
485 :
486 0 : void RemoveDynamicIndexingTraverser::nextIteration()
487 : {
488 0 : mUsedTreeInsertion = false;
489 0 : mRemoveIndexSideEffectsInSubtree = false;
490 0 : nextTemporaryIndex();
491 0 : }
492 :
493 : } // namespace
494 :
495 0 : void RemoveDynamicIndexing(TIntermNode *root,
496 : unsigned int *temporaryIndex,
497 : const TSymbolTable &symbolTable,
498 : int shaderVersion)
499 : {
500 0 : RemoveDynamicIndexingTraverser traverser(symbolTable, shaderVersion);
501 0 : ASSERT(temporaryIndex != nullptr);
502 0 : traverser.useTemporaryIndex(temporaryIndex);
503 0 : do
504 : {
505 0 : traverser.nextIteration();
506 0 : root->traverse(&traverser);
507 0 : traverser.updateTree();
508 : } while (traverser.usedTreeInsertion());
509 0 : traverser.insertHelperDefinitions(root);
510 0 : traverser.updateTree();
511 0 : }
512 :
513 : } // namespace sh
|